985 lines
38 KiB
Python
985 lines
38 KiB
Python
"""Backtesting system entry point.

Responsibilities:
- read the global configuration;
- load the stock universe and market data;
- preprocess the data and compute moving-average indicators;
- instantiate strategies and run the backtest;
- write the equity-curve CSV, performance metrics and equity-curve plot;
- support a parameter-optimization mode (``--optimize``);
- catch exceptions raised by the main flow and log them.

Usage:
    # Backtest mode: the strategies to run are the ones enabled in
    # STRATEGY_SWITCHES (config.settings); --strategy is only used by --optimize.
    python main.py --strategy ma_cross

    # Parameter optimization (4 parallel workers, keep the top 20 parameter sets)
    python main.py --optimize --strategy ma_cross --jobs 4 --top 20
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import gc
|
||
import importlib
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Dict, List
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import tushare as ts
|
||
from tqdm import tqdm
|
||
|
||
from config.settings import (
|
||
DATA_DAY_DIR,
|
||
END_DATE,
|
||
INITIAL_CASH,
|
||
MIN_LISTING_DAYS,
|
||
OPTIMIZATION_METRIC,
|
||
OPTIMIZATION_N_JOBS,
|
||
OPTIMIZATION_TOP_N,
|
||
POSITION_METHOD,
|
||
RESULTS_DIR,
|
||
START_DATE,
|
||
STOCK_CODE_FILE,
|
||
STOP_LOSS_METHOD,
|
||
STOP_LOSS_PCT,
|
||
STRATEGY,
|
||
TAKE_PROFIT_PCT,
|
||
TRADING_DAYS_PER_YEAR,
|
||
TUSHARE_CALENDAR_EXCHANGE,
|
||
TUSHARE_TOKEN,
|
||
ATR_PERIOD,
|
||
ATR_MULTIPLIER,
|
||
TRAILING_PCT,
|
||
KELLY_RISK_FREE,
|
||
KELLY_MAX_FRACTION,
|
||
VOLATILITY_TARGET,
|
||
VOLATILITY_WINDOW,
|
||
BENCHMARK_FILE,
|
||
BENCHMARK_NAME,
|
||
)
|
||
from benchmark.benchmark_loader import load_benchmark, calc_benchmark_return
|
||
from utils.data_loader import load_single_stock
|
||
from utils.logger import setup_logger
|
||
from utils.performance import calc_performance
|
||
from utils.plotter import plot_equity_curve
|
||
|
||
# Module-level logger used by every function in this file.
logger = setup_logger(__name__)
|
||
|
||
|
||
def _is_valid_stock(ts_code: str) -> bool:
    """Return whether *ts_code* may enter the backtest universe.

    Only rules decidable from the code itself are applied:

    - STAR Market (科创板): codes starting with ``688`` and ending in ``.SH``;
    - Beijing Stock Exchange (北交所): any code ending in ``.BJ``.

    Empty/falsy codes are rejected. ST stocks are NOT filtered here (that needs
    the security name, which this function does not receive).
    """
    if not ts_code:
        return False

    is_star_market = ts_code.startswith("688") and ts_code.endswith(".SH")
    if is_star_market:
        logger.debug(f"过滤科创板: {ts_code}")
        return False

    is_beijing_exchange = ts_code.endswith(".BJ")
    if is_beijing_exchange:
        logger.debug(f"过滤北交所: {ts_code}")
        return False

    return True
|
||
|
||
|
||
def _filter_new_stocks(
    stock_universe: List[str],
    data_dir: str,
    end_date: str,
    min_days: int = 60,
) -> List[str]:
    """Drop recently listed stocks from *stock_universe*.

    A stock's listing age is measured as the number of rows in its local daily
    data up to *end_date*, i.e. trading days, not calendar days. The first
    available row is treated as the listing date. Stocks whose data comes back
    empty are dropped as well, and are counted in a warning.

    Args:
        stock_universe: Stock codes to check.
        data_dir: Directory holding the daily market-data files.
        end_date: Backtest end date (YYYYMMDD); later rows are ignored.
        min_days: Minimum number of trading-day rows required to keep a stock.

    Returns:
        The codes with at least *min_days* rows, in input order.
    """
    valid_stocks: List[str] = []
    filtered_count = 0
    empty_count = 0

    for ts_code in tqdm(stock_universe, desc="过滤新股", unit="股"):
        # No start date: load the whole history so the first row approximates the listing date.
        df = load_single_stock(data_dir, ts_code, start_date=None, end_date=end_date)
        if df.empty:
            empty_count += 1
            continue

        trading_days = len(df)
        if trading_days < min_days:
            first_date_str = df["trade_date"].iloc[0]
            logger.debug(f"过滤新股: {ts_code}, 上市日={first_date_str}, 交易天数={trading_days}")
            filtered_count += 1
            continue

        valid_stocks.append(ts_code)

    if filtered_count > 0:
        logger.info(f"过滤新股(上市<{min_days}天): {filtered_count} 只")
    if empty_count > 0:
        # These stocks were previously dropped without any trace in the log.
        logger.warning(f"无行情数据,已剔除: {empty_count} 只")

    return valid_stocks
|
||
|
||
|
||
def _load_stock_universe() -> List[str]:
    """Build the sorted, de-duplicated stock universe for the backtest.

    Codes are read from ``STOCK_CODE_FILE`` when it exists. Each line is
    ``code<TAB>name``, and the name column is optional. Otherwise codes are
    inferred from the ``*_daily_data.txt`` file names under ``DATA_DAY_DIR``.

    Filters applied:
    - STAR Market / Beijing Stock Exchange codes (via ``_is_valid_stock``);
    - ST stocks, by name; only possible when reading the code file;
    - recently listed stocks (fewer than ``MIN_LISTING_DAYS`` trading days).

    Returns an empty list when the data directory is missing or when reading
    the code source raises.
    """
    collected: List[str] = []
    try:
        if not STOCK_CODE_FILE.exists():
            # Fallback: derive codes from the market-data file names.
            day_dir = str(DATA_DAY_DIR)
            if not os.path.isdir(day_dir):
                logger.error(f"行情目录不存在: {day_dir}")
                return []
            suffix = "_daily_data.txt"
            for file_name in os.listdir(day_dir):
                if not file_name.endswith(suffix):
                    continue
                candidate = file_name.replace(suffix, "")
                if _is_valid_stock(candidate):
                    collected.append(candidate)
            logger.info(f"从目录 {day_dir} 扫描股票代码 {len(collected)} 只(已过滤科创板/北交所)")
        else:
            with STOCK_CODE_FILE.open("r", encoding="utf-8") as handle:
                for raw_line in handle:
                    stripped = raw_line.strip()
                    if not stripped:
                        continue
                    # "code<TAB>name"; only the code is kept, the name drives the ST filter.
                    fields = stripped.split("\t")
                    code = fields[0].strip()
                    name = fields[1].strip() if len(fields) > 1 else ""
                    if not _is_valid_stock(code):
                        continue
                    if "ST" in name.upper():
                        logger.debug(f"过滤ST股票: {code} {name}")
                        continue
                    collected.append(code)
            logger.info(f"从 {STOCK_CODE_FILE} 加载股票代码 {len(collected)} 只(已过滤科创板/北交所/ST)")
    except Exception as e:  # noqa: BLE001
        logger.error(f"加载股票代码列表失败: {e}")
        return []

    unique_codes = sorted(set(collected))

    logger.info("开始过滤新股...")
    final_codes = _filter_new_stocks(
        stock_universe=unique_codes,
        data_dir=str(DATA_DAY_DIR),
        end_date=END_DATE,
        min_days=MIN_LISTING_DAYS,
    )
    logger.info(f"最终股票池: {len(final_codes)} 只")
    return final_codes
|
||
|
||
|
||
def _load_all_data(stock_universe: List[str]) -> Dict[str, pd.DataFrame]:
    """Load daily bars for every code in *stock_universe*.

    Each stock is read from ``DATA_DAY_DIR`` for ``START_DATE``..``END_DATE``.
    Codes whose data comes back empty are skipped.

    Returns:
        Mapping ``ts_code -> DataFrame``. It is empty when nothing could be
        loaded, in which case an error is logged. Otherwise the loaded count is
        logged, plus a warning with the number of skipped codes.
    """
    source_dir = str(DATA_DAY_DIR)
    loaded: Dict[str, pd.DataFrame] = {}
    missing = 0

    for code in tqdm(stock_universe, desc="加载行情数据", unit="股"):
        frame = load_single_stock(source_dir, code, START_DATE, END_DATE)
        if frame.empty:
            missing += 1
        else:
            loaded[code] = frame

    if not loaded:
        logger.error("无可用行情数据,退出")
        return loaded

    logger.info(f"共加载 {len(loaded)} 只股票数据")
    if missing > 0:
        logger.warning(f"加载失败: {missing} 只股票(文件不存在或格式错误)")
    return loaded
|
||
|
||
|
||
def _prepare_calendar(data_dict: Dict[str, pd.DataFrame]) -> List[str]:
    """Build the unified trading calendar as an ascending list of dates.

    The Tushare ``trade_cal`` endpoint is tried first: open days of
    ``TUSHARE_CALENDAR_EXCHANGE`` between ``START_DATE`` and ``END_DATE``.
    The calendar is instead rebuilt from the union of ``trade_date`` values in
    *data_dict* when any of these holds:
    - no token is configured;
    - the request raises;
    - the request returns nothing.

    NOTE(review): the Tushare branch returns ``cal_date`` as strings; this
    assumes local ``trade_date`` values use the same YYYYMMDD string format so
    that calendar dates match the per-stock date index. Confirm.

    Returns:
        Sorted unique trading dates, or an empty list if neither source yields any.
    """
    try:
        token = TUSHARE_TOKEN or os.getenv("TUSHARE_TOKEN", "")
        if token:
            pro = ts.pro_api(token)
            df_cal = pro.trade_cal(
                exchange=TUSHARE_CALENDAR_EXCHANGE,
                start_date=START_DATE,
                end_date=END_DATE,
                is_open="1",
            )
            if df_cal is not None and not df_cal.empty:
                calendar = sorted(df_cal["cal_date"].astype(str).tolist())
                logger.info(f"通过 Tushare 获取交易日历共 {len(calendar)} 个交易日")
                return calendar
            # An empty response would otherwise fall back silently; make it visible.
            logger.warning("Tushare 返回的交易日历为空,将使用本地行情日期构建交易日历")
        else:
            logger.warning("TUSHARE_TOKEN 未配置,将使用本地行情日期构建交易日历")
    except Exception as e:  # noqa: BLE001
        # Network/auth/permission failures are recoverable: fall back to local dates.
        logger.error(f"通过 Tushare 获取交易日历失败,将使用本地行情日期构建: {e}")

    # Fallback: union of every locally available trade date.
    all_dates: List[str] = []
    for df in data_dict.values():
        all_dates.extend(df["trade_date"].tolist())
    if not all_dates:
        return []

    calendar = sorted(np.unique(np.array(all_dates)).tolist())
    logger.info(f"统一交易日历共 {len(calendar)} 个交易日(本地构建)")
    return calendar
|
||
|
||
|
||
def _precompute_ma(data_dict: Dict[str, pd.DataFrame], ma_short: int, ma_long: int) -> None:
    """Add moving-average and volume-change columns to each DataFrame in place.

    For each of *ma_short* and *ma_long*, a ``ma_<window>`` column is added
    unless it already exists. It is the rolling mean of ``close`` with
    ``min_periods=1``, so early rows average fewer bars.

    A ``vol_pct_change`` column (``vol.pct_change()``) is added when ``vol``
    exists and the column is missing. Existing columns are never recomputed.
    """
    windows = {ma_short, ma_long}
    for code in list(data_dict):
        frame = data_dict[code]

        pending = [w for w in windows if f"ma_{w}" not in frame.columns]
        for w in pending:
            frame[f"ma_{w}"] = frame["close"].rolling(window=w, min_periods=1).mean()

        needs_volume_change = "vol" in frame.columns and "vol_pct_change" not in frame.columns
        if needs_volume_change:
            frame["vol_pct_change"] = frame["vol"].pct_change()

        data_dict[code] = frame

    logger.info("已为所有股票预先计算均线指标和成交量变化")
|
||
|
||
|
||
def _precompute_date_index(data_dict: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, int]]:
    """Map every stock's trade dates to their row positions.

    Strategies can then locate a day's bar with a dict lookup instead of
    filtering the DataFrame by ``trade_date`` on every bar.

    The stored values are 0-based *positions* (from ``enumerate``), so they
    index the frame via ``df.iloc[...]``. They only equal ``df.loc[...]``
    labels when the frame has a default RangeIndex. If a date occurs more than
    once, its last position wins.

    Returns:
        ``{ts_code: {trade_date: position}}``.
    """
    index_by_code: Dict[str, Dict[str, int]] = {
        code: {trade_date: position for position, trade_date in enumerate(frame["trade_date"])}
        for code, frame in data_dict.items()
    }
    logger.info("已为所有股票预先建立日期索引映射")
    return index_by_code
|
||
|
||
|
||
def _bars_since_true(mask: pd.Series, missing: int = 999999) -> np.ndarray:
    """Return, per row, how many rows have passed since *mask* was last True.

    Counting is positional: 0 on a True row, 1 on the next row, and so on. It
    therefore does not depend on the frame's index labels. Rows before the
    first True get *missing*.
    """
    flags = np.asarray(mask, dtype=bool)
    positions = np.arange(len(flags))
    # Carry the position of the most recent True row forward; NaN until the first one.
    last_true = pd.Series(np.where(flags, positions, np.nan)).ffill().to_numpy()
    bars = positions - last_true
    return np.where(np.isnan(bars), missing, bars).astype(np.int64)


def _precompute_ocz_signals(data_dict: Dict[str, pd.DataFrame], N: int, B: float, V1: float, TOL: float, R: float, volatility_min: float, volatility_max: float) -> Dict[str, List[str]]:
    """Precompute the OCZ strategy's pullback-after-breakout buy signals.

    Columns written in place on each stock's DataFrame:

    1. ``resistance``: rolling max of ``high`` over the last *N* bars.
    2. ``body`` / ``range`` / ``body_pct``: candle body as a % of high-low range.
    3. ``return_pct``: close-to-close change in %.
    4. ``vol_ma_n``: rolling mean of ``vol`` over *N* bars.
    5. ``volatility``: 30-bar mean of (high - low) divided by the current close, in %.
    6. ``breakthrough``: all of the following hold:
       - close above the previous bar's resistance;
       - ``body_pct > B``;
       - ``return_pct >= R``;
       - ``vol > vol_ma_n * V1``;
       - volatility within ``[volatility_min, volatility_max]``.
    7. ``bars_since_breakthrough``: rows since the last breakthrough
       (999999 if there has been none yet).
    8. ``pullback_signal``: set on the bar right after a breakthrough when:
       - its low is within ±TOL% of the level the breakout cleared (``resistance`` two bars back);
       - it closes above that level;
       - its volume is lower than the breakout bar's.

    Stocks missing a required column, or with fewer than ``N + 2`` rows, only
    get ``pullback_signal = False``.

    Returns:
        ``{trade_date: [ts_code, ...]}`` for the days with a pullback signal.
    """
    logger.info(f"开始预计算OCZ策略信号(N={N}, B={B}%, V1={V1}, TOL={TOL}%, R={R}%)...")

    buy_signal_index: Dict[str, List[str]] = {}
    required_cols = ["open", "high", "low", "close", "vol", "trade_date"]

    for ts_code, df in data_dict.items():
        if not all(col in df.columns for col in required_cols) or len(df) < N + 2:
            df["pullback_signal"] = False
            continue

        # 1. Resistance: highest high of the last N bars.
        df["resistance"] = df["high"].rolling(window=N, min_periods=N).max()

        # 2. Body relative to the bar's range. A zero range yields NaN/inf here;
        # NaN comparisons below evaluate to False.
        df["body"] = (df["close"] - df["open"]).abs()
        df["range"] = df["high"] - df["low"]
        df["body_pct"] = df["body"] / df["range"] * 100

        # 3. Daily return in %.
        df["return_pct"] = (df["close"] / df["close"].shift(1) - 1) * 100

        # 4. N-bar average volume.
        df["vol_ma_n"] = df["vol"].rolling(window=N, min_periods=N).mean()

        # 5. 30-bar average range as a % of the current close.
        df["volatility"] = (df["high"] - df["low"]).rolling(window=30, min_periods=30).mean() / df["close"] * 100

        # 6. Breakout bars.
        df["breakthrough"] = (
            (df["close"] > df["resistance"].shift(1))
            & (df["body_pct"] > B)
            & (df["return_pct"] >= R)
            & (df["vol"] > df["vol_ma_n"] * V1)
            & (df["volatility"] >= volatility_min)
            & (df["volatility"] <= volatility_max)
        )

        # 7. Bars since the latest breakout. Vectorized and positional; this replaces a
        # per-row apply that was O(n * breakouts) and treated a breakout at label 0 as "none".
        df["bars_since_breakthrough"] = _bars_since_true(df["breakthrough"])

        # 8. Pullback on the bar right after a breakout. The breakout compared
        # against resistance.shift(1), so from the pullback bar that level is
        # resistance.shift(2); the breakout bar's volume is vol.shift(1).
        df["resistance_2days_ago"] = df["resistance"].shift(2)
        df["vol_breakthrough"] = df["vol"].shift(1)
        lower_band = df["resistance_2days_ago"] * (1 - TOL / 100)
        upper_band = df["resistance_2days_ago"] * (1 + TOL / 100)
        df["pullback_signal"] = (
            (df["bars_since_breakthrough"] == 1)
            & (df["low"] >= lower_band)
            & (df["low"] <= upper_band)
            & (df["close"] > df["resistance_2days_ago"])
            & (df["vol"] < df["vol_breakthrough"])
        )

        # 9. Index signal dates -> stock codes.
        for date in df.loc[df["pullback_signal"], "trade_date"].tolist():
            buy_signal_index.setdefault(date, []).append(ts_code)

    logger.info(f"OCZ信号预计算完成,共 {len(buy_signal_index)} 个交易日有回踩信号")
    return buy_signal_index
|
||
|
||
|
||
def _precompute_signals(data_dict: Dict[str, pd.DataFrame], ma_short: int, ma_long: int) -> Dict[str, List[str]]:
    """Precompute MA-cross trading signals as boolean columns (in place).

    Columns written on every stock's DataFrame:

    - ``golden_cross``: short MA crosses above long MA (yesterday ``<=``, today ``>``);
    - ``volume_surge``: ``vol_pct_change > 0.2``, i.e. volume up more than 20%
      versus the previous bar;
    - ``buy_signal``: ``golden_cross & volume_surge``;
    - ``death_cross``: short MA crosses below long MA (yesterday ``>=``, today ``<``).

    ``vol_pct_change`` is computed here if ``_precompute_ma`` has not already
    added it. Stocks lacking ``ma_<ma_short>``, ``ma_<ma_long>`` or ``vol`` get
    all four signal columns set to False, so every frame has the same signal
    columns.

    Returns:
        ``{trade_date: [ts_code, ...]}`` for the days with a buy signal.
    """
    logger.info("开始预计算交易信号(金叉+放量)...")

    ma_s_col = f"ma_{ma_short}"
    ma_l_col = f"ma_{ma_long}"
    buy_signal_index: Dict[str, List[str]] = {}

    for ts_code, df in data_dict.items():
        if ma_s_col not in df.columns or ma_l_col not in df.columns or "vol" not in df.columns:
            # Keep the column set uniform (including death_cross) so later reads never KeyError.
            for col in ("buy_signal", "golden_cross", "volume_surge", "death_cross"):
                df[col] = False
            continue

        prev_ma_s = df[ma_s_col].shift(1)
        prev_ma_l = df[ma_l_col].shift(1)
        curr_ma_s = df[ma_s_col]
        curr_ma_l = df[ma_l_col]

        # 1. Golden cross. The first row has NaN previous values and evaluates to False.
        golden_cross = (prev_ma_s <= prev_ma_l) & (curr_ma_s > curr_ma_l)
        df["golden_cross"] = golden_cross

        # 2. Volume surge: >20% day-over-day increase.
        if "vol_pct_change" not in df.columns:
            df["vol_pct_change"] = df["vol"].pct_change()
        volume_surge = df["vol_pct_change"] > 0.2
        df["volume_surge"] = volume_surge

        # 3. Buy signal.
        df["buy_signal"] = golden_cross & volume_surge

        # 4. Death cross (used by strategies for exits).
        df["death_cross"] = (prev_ma_s >= prev_ma_l) & (curr_ma_s < curr_ma_l)

        # 5. Index signal dates -> stock codes.
        for date in df.loc[df["buy_signal"], "trade_date"].tolist():
            buy_signal_index.setdefault(date, []).append(ts_code)

    logger.info(f"交易信号预计算完成,共 {len(buy_signal_index)} 个交易日有买入信号")
    return buy_signal_index
|
||
|
||
|
||
def _create_risk_modules(risk_config: dict = None):
    """Instantiate the stop-loss, take-profit and position-sizing helpers.

    Args:
        risk_config: Optional mapping with ``"stop_loss"`` and ``"take_profit"``
            sub-dicts. Their keys are ``method``, ``stop_pct``, ``atr_period``,
            ``atr_multiplier`` and ``trailing_pct``.
            ``None`` builds both sides from the global settings. In that case
            take-profit reuses ``STOP_LOSS_METHOD`` together with ``TAKE_PROFIT_PCT``.
            Keys missing from a sub-dict fall back to hard-coded defaults
            (method, stop_pct, atr_multiplier, atr_period, trailing_pct):
            stop-loss fixed_pct/0.05/2.0/14/0.10;
            take-profit fixed_pct/0.10/3.0/14/0.15.

    Returns:
        ``(stop_loss, take_profit, position_sizer)``.
        - Take-profit is a second ``StopLoss`` instance with its own parameters.
        - The position sizer always uses the global settings, with
          ``max_positions`` read from ``STRATEGY["params"]`` (default 2).
    """
    from risk.stop_loss import StopLoss
    from risk.position_sizing import PositionSizing

    if risk_config is None:
        common = {
            "atr_period": ATR_PERIOD,
            "atr_multiplier": ATR_MULTIPLIER,
            "trailing_pct": TRAILING_PCT,
        }
        risk_config = {
            "stop_loss": {"method": STOP_LOSS_METHOD, "stop_pct": STOP_LOSS_PCT, **common},
            # Take-profit deliberately reuses the stop-loss method.
            "take_profit": {"method": STOP_LOSS_METHOD, "stop_pct": TAKE_PROFIT_PCT, **common},
        }

    def _make(cfg: dict, default_pct: float, default_mult: float, default_trailing: float):
        # Both sides are StopLoss instances; only their fallback defaults differ.
        return StopLoss(
            method=cfg.get("method", "fixed_pct"),
            stop_pct=cfg.get("stop_pct", default_pct),
            atr_multiplier=cfg.get("atr_multiplier", default_mult),
            atr_period=cfg.get("atr_period", 14),
            trailing_pct=cfg.get("trailing_pct", default_trailing),
        )

    stop_loss = _make(risk_config.get("stop_loss", {}), 0.05, 2.0, 0.10)
    take_profit = _make(risk_config.get("take_profit", {}), 0.10, 3.0, 0.15)

    position_sizer = PositionSizing(
        method=POSITION_METHOD,
        max_positions=STRATEGY["params"].get("max_positions", 2),
        kelly_risk_free=KELLY_RISK_FREE,
        kelly_max_fraction=KELLY_MAX_FRACTION,
        volatility_target=VOLATILITY_TARGET,
        volatility_window=VOLATILITY_WINDOW,
    )
    return stop_loss, take_profit, position_sizer
|
||
|
||
|
||
def _precompute_signal_index(class_name: str, params: dict, data_dict: Dict[str, pd.DataFrame]) -> Dict[str, List[str]]:
    """Run the signal precomputation matching *class_name*; return ``{date: [codes]}``."""
    if class_name == "MaCrossStrategy":
        ma_short = params.get("ma_short", 5)
        ma_long = params.get("ma_long", 20)
        _precompute_ma(data_dict, ma_short=ma_short, ma_long=ma_long)
        return _precompute_signals(data_dict, ma_short=ma_short, ma_long=ma_long)
    if class_name == "OczStrategy":
        return _precompute_ocz_signals(
            data_dict,
            N=params.get("N", 30),
            B=params.get("B", 60.0),
            V1=params.get("V1", 1.5),
            TOL=params.get("TOL", 1.5),
            R=params.get("R", 4.0),
            volatility_min=params.get("volatility_min", 2.5),
            volatility_max=params.get("volatility_max", 8.0),
        )
    logger.warning(f"未知策略类型 {class_name},不预计算信号")
    return {}


def _format_pct(value) -> str:
    """Format a fraction as a one-decimal percentage ("5.0%"), or "N/A" if not numeric."""
    try:
        return f"{float(value) * 100:.1f}%"
    except (TypeError, ValueError):
        return "N/A"


def _save_trade_history(trade_history, class_name: str, params: dict, results_dir: Path) -> None:
    """Write *trade_history* to ``<results_dir>/trades/<timestamp>_<class>_<params>.csv``."""
    import datetime

    trades_dir = results_dir / "trades"
    trades_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # Only a few well-known parameters go into the file name, in params' own order.
    param_keys = ["ma_short", "ma_long", "N", "B", "hold_days", "position_pct_per_stock"]
    param_str = "_".join(f"{k}{v}" for k, v in params.items() if k in param_keys)
    trades_path = trades_dir / f"{timestamp}_{class_name}_{param_str}.csv"
    pd.DataFrame(trade_history).to_csv(trades_path, index=False, encoding="utf-8")
    logger.info(f"交易记录 CSV 保存到 {trades_path}")
    logger.info(f"总交易次数: {len(trade_history)} 笔")


def _load_benchmark_curve(calendar: List[str]) -> pd.DataFrame:
    """Load the benchmark index, log its statistics and return its curve (empty if unavailable)."""
    if not BENCHMARK_FILE.exists():
        return pd.DataFrame()
    logger.info(f"加载基准指数: {BENCHMARK_NAME}")
    benchmark_data = load_benchmark(BENCHMARK_FILE, START_DATE, END_DATE)
    if benchmark_data.empty:
        return pd.DataFrame()

    benchmark_df, benchmark_stats = calc_benchmark_return(benchmark_data, calendar)
    logger.info("=" * 60)
    logger.info(f"基准指数:{BENCHMARK_NAME}")
    logger.info("=" * 60)
    logger.info(f"累计收益率: {benchmark_stats['cum_return']*100:+.2f}%")
    logger.info(f"年化收益率: {benchmark_stats['ann_return']*100:+.2f}%")
    logger.info(f"最大回撤: {benchmark_stats['max_drawdown']*100:.2f}%")
    logger.info(f"夏普比率: {benchmark_stats['sharpe']:.4f}")
    logger.info("=" * 60)
    return benchmark_df


def run_single_strategy_backtest(strategy_config: dict, data_dict: Dict[str, pd.DataFrame], calendar: List[str], date_index_dict: Dict[str, Dict[str, int]]):
    """Backtest one strategy and write its results.

    Args:
        strategy_config: Dict with these keys:
            - ``"name"``: strategy class name;
            - ``"module"``: import path of the module defining it;
            - ``"params"``: constructor kwargs;
            - ``"risk_control"``: optional risk dict.
        data_dict: ``{ts_code: DataFrame}`` market data. Signal columns are
            added to these frames in place.
        calendar: Trading dates to iterate over.
        date_index_dict: ``{ts_code: {trade_date: row position}}``.

    Side effects:
        - writes ``<RESULTS_DIR>/equity/<name>_equity.csv``;
        - writes a timestamped trade log under ``trades/`` when trades exist;
        - writes ``plots/<name>_curve.png``.
        Exceptions are logged, not raised.
    """
    try:
        class_name = strategy_config["name"]
        module_name = strategy_config["module"]
        params = strategy_config["params"]
        risk_control = strategy_config.get("risk_control", None)

        logger.info("\n" + "=" * 60)
        logger.info(f"开始回测策略: {class_name}")
        logger.info("=" * 60)

        # Resolve the strategy class dynamically from its module path.
        module = importlib.import_module(module_name)
        strategy_cls = getattr(module, class_name)

        buy_signal_index = _precompute_signal_index(class_name, params, data_dict)

        # The generic position sizer is never handed to strategies (they size positions
        # themselves). It is assigned before the branch so it is also defined on the
        # custom-risk path; previously that path left it unbound and instantiation failed.
        position_sizer = None
        if risk_control and risk_control.get("stop_loss", {}).get("method") == "custom":
            # The strategy implements its own exits; skip the generic modules.
            stop_loss = None
            take_profit = None
            logger.info("该策略使用内部自定义止损止盈逻辑")
        else:
            stop_loss, take_profit, _ = _create_risk_modules(risk_config=risk_control)

        logger.info("当前策略参数配置:")
        for param_name, param_value in params.items():
            logger.info(f" {param_name}: {param_value}")

        if risk_control:
            stop_cfg = risk_control.get("stop_loss", {})
            profit_cfg = risk_control.get("take_profit", {})
            logger.info("当前策略风控配置:")
            logger.info(f" 止损方法: {stop_cfg.get('method', 'N/A')}")
            # stop_pct may be absent (e.g. custom risk logic); 'N/A'*100 with :.1f used to raise.
            logger.info(f" 止损比例: {_format_pct(stop_cfg.get('stop_pct'))}")
            logger.info(f" 止盈方法: {profit_cfg.get('method', 'N/A')}")
            logger.info(f" 止盈比例: {_format_pct(profit_cfg.get('stop_pct'))}")

        logger.info(f"初始资金: {INITIAL_CASH:,.0f}")
        logger.info(f"回测区间: {START_DATE} ~ {END_DATE}")
        logger.info("=" * 60)

        strategy = strategy_cls(
            initial_cash=INITIAL_CASH,
            stop_loss=stop_loss,
            take_profit=take_profit,
            position_sizer=position_sizer,
            date_index_dict=date_index_dict,
            buy_signal_index=buy_signal_index,
            **params,
        )

        equity_df = strategy.run_backtest(data_dict, calendar)
        if equity_df.empty:
            logger.error(f"{class_name} 资金曲线为空,回测可能失败")
            return

        # Equity curve CSV.
        results_dir = Path(RESULTS_DIR)
        equity_dir = results_dir / "equity"
        equity_dir.mkdir(parents=True, exist_ok=True)
        csv_path = equity_dir / f"{class_name}_equity.csv"
        equity_df.to_csv(csv_path, index=False, encoding="utf-8")
        logger.info(f"资金曲线 CSV 保存到 {csv_path}")

        if strategy.trade_history:
            _save_trade_history(strategy.trade_history, class_name, params, results_dir)

        calc_performance(
            equity_df,
            trade_count=strategy.trade_count,
            trade_history=strategy.trade_history,
            trading_days_per_year=TRADING_DAYS_PER_YEAR,
        )

        benchmark_df = _load_benchmark_curve(calendar)

        plots_dir = results_dir / "plots"
        plots_dir.mkdir(parents=True, exist_ok=True)
        plot_path = plots_dir / f"{class_name}_curve.png"
        plot_equity_curve(
            str(csv_path),
            str(plot_path),
            benchmark_df=benchmark_df if not benchmark_df.empty else None,
            benchmark_name=BENCHMARK_NAME,
        )

        logger.info(f"{class_name} 回测完成\n")

        # Drop the large per-strategy objects before the next strategy runs.
        del strategy
        del equity_df
        gc.collect()

    except Exception as e:
        logger.error(f"{strategy_config.get('name', 'Unknown')} 回测异常: {e}", exc_info=True)
|
||
|
||
|
||
def run_backtest_mode(args):
    """Backtest every strategy enabled in ``STRATEGY_SWITCHES`` on shared data.

    The stock universe, market data, trading calendar and per-stock date index
    are built once. Each enabled strategy is then backtested in turn; a
    strategy counts as enabled when its switch is 1 and it exists in
    ``STRATEGIES``. Any exception is logged rather than raised.

    Args:
        args: Parsed CLI arguments; not used in this mode.
    """
    try:
        universe = _load_stock_universe()
        if not universe:
            logger.error("股票池为空,无法回测")
            return

        data_dict = _load_all_data(universe)
        if not data_dict:
            return

        calendar = _prepare_calendar(data_dict)
        if not calendar:
            logger.error("无法构建交易日历,退出")
            return

        date_index_dict = _precompute_date_index(data_dict)

        from config.settings import STRATEGIES, STRATEGY_SWITCHES

        enabled = [
            {
                "name": name,
                "module": STRATEGIES[name]["module"],
                "params": STRATEGIES[name]["params"],
                "risk_control": STRATEGIES[name].get("risk_control", None),
            }
            for name, switch in STRATEGY_SWITCHES.items()
            if switch == 1 and name in STRATEGIES
        ]
        if not enabled:
            logger.error("没有任何策略开启,请检查 STRATEGY_SWITCHES 配置")
            return

        logger.info("=" * 60)
        logger.info(f"共有 {len(enabled)} 个策略需要回测")
        for position, cfg in enumerate(enabled, start=1):
            logger.info(f" {position}. {cfg['name']}")
        logger.info("=" * 60)

        for cfg in enabled:
            run_single_strategy_backtest(cfg, data_dict, calendar, date_index_dict)

        logger.info("所有策略回测完成,清理内存...")
        del data_dict
        gc.collect()
        logger.info("内存清理完成")

    except Exception as e:
        logger.error(f"回测模式异常: {e}", exc_info=True)
|
||
|
||
|
||
def _collect_ma_periods(param_space: dict) -> set:
    """Gather every moving-average window listed in *param_space*.

    Reads the ``ma_short``/``short_window`` and ``ma_long``/``long_window``
    entries. ``range`` objects, lists and tuples contribute their values;
    anything else is ignored.

    NOTE(review): the fallback space in ``get_param_space`` uses 3-element
    lists such as ``[3, 20, 2]``, which look like ``[start, stop, step]``
    specs. If ``grid_search`` expands them as ranges, only those three literal
    numbers are collected here. Confirm the list format against grid_search.
    """
    periods = set()
    for key in ("ma_short", "short_window", "ma_long", "long_window"):
        values = param_space.get(key)
        if isinstance(values, (range, list, tuple)):
            periods.update(values)
    return periods


def run_optimize_mode(args):
    """Grid-search the parameters of the strategy configured in ``STRATEGY``.

    Steps:
    - load the universe, market data and trading calendar;
    - precompute the indicators the strategy needs;
    - build the risk modules;
    - run ``optimization.grid_search.grid_search``.

    Job count, top-N and metric come from *args* when set, otherwise from the
    ``OPTIMIZATION_*`` settings. Exceptions are logged, not raised.

    Args:
        args: Parsed CLI arguments (uses ``strategy``, ``jobs``, ``top``, ``metric``).
    """
    try:
        from optimization.grid_search import grid_search

        # 1. Stock universe.
        stock_universe = _load_stock_universe()
        if not stock_universe:
            logger.error("股票池为空,无法优化")
            return

        # 2. Market data.
        data_dict = _load_all_data(stock_universe)
        if not data_dict:
            return

        # 3. Trading calendar.
        calendar = _prepare_calendar(data_dict)
        if not calendar:
            logger.error("无法构建交易日历,退出")
            return

        # 4. Strategy class and its risk-control settings.
        # NOTE(review): the class comes from the global STRATEGY setting while the
        # parameter space and constraints are looked up by args.strategy; both must
        # describe the same strategy.
        module_name = STRATEGY["module"]
        class_name = STRATEGY["name"]
        module = importlib.import_module(module_name)
        strategy_cls = getattr(module, class_name)

        from config.settings import STRATEGIES

        risk_control = STRATEGIES.get(class_name, {}).get("risk_control", None)

        # 5. Parameter space and strategy-specific indicator precomputation.
        if class_name == "MaCrossStrategy":
            # Every MA window the search may use is computed once, up front.
            logger.info("预计算所有可能的均线周期...")
            param_space = get_param_space(args.strategy)
            all_ma_periods = _collect_ma_periods(param_space)
            for df in data_dict.values():
                for window in all_ma_periods:
                    col = f"ma_{window}"
                    if col not in df.columns:
                        df[col] = df["close"].rolling(window=window, min_periods=1).mean()
                if "vol" in df.columns and "vol_pct_change" not in df.columns:
                    df["vol_pct_change"] = df["vol"].pct_change()
            logger.info(f"已预计算 {len(all_ma_periods)} 个均线周期和成交量变化")
        elif class_name == "OczStrategy":
            logger.info("OCZ策略不需要预计算均线,将在优化过程中动态计算信号")
            param_space = get_param_space(args.strategy)
        else:
            logger.warning(f"未知策略类型 {class_name},跳过指标预计算")
            param_space = get_param_space(args.strategy)

        # 6. Per-stock date -> row position index.
        logger.info("预计算日期索引字典...")
        date_index_dict = _precompute_date_index(data_dict)

        # 7. Risk modules. Same rule as run_single_strategy_backtest: a "custom"
        # stop-loss method means the strategy handles exits itself. The generic
        # position sizer is not used; strategies size positions themselves.
        if risk_control and risk_control.get("stop_loss", {}).get("method") == "custom":
            stop_loss = None
            take_profit = None
            logger.info("该策略使用内部自定义止损止盈逻辑")
        else:
            stop_loss, take_profit, _ = _create_risk_modules(risk_config=risk_control)

        # 8. Grid search (CLI options override the configured defaults when truthy).
        n_jobs = args.jobs if args.jobs else OPTIMIZATION_N_JOBS
        top_n = args.top if args.top else OPTIMIZATION_TOP_N
        metric = args.metric if args.metric else OPTIMIZATION_METRIC

        from config.settings import PARAM_CONSTRAINTS

        constraint_func = PARAM_CONSTRAINTS.get(args.strategy, None)
        if constraint_func is not None:
            logger.info(f"使用策略 {args.strategy} 的约束条件")

        grid_search(
            strategy_class=strategy_cls,
            data_dict=data_dict,
            calendar=calendar,
            param_space=param_space,
            initial_capital=INITIAL_CASH,
            n_jobs=n_jobs,
            metric=metric,
            top_n=top_n,
            constraint_func=constraint_func,
            date_index_dict=date_index_dict,
            stop_loss=stop_loss,
            take_profit=take_profit,
        )

        # 9. Release memory.
        logger.info("参数优化完成,清理内存...")
        del data_dict
        gc.collect()
        logger.info("内存清理完成")

    except Exception as e:  # noqa: BLE001
        logger.error(f"优化模式异常: {e}", exc_info=True)
|
||
|
||
|
||
def get_param_space(strategy_name: str) -> dict:
    """Return the optimization parameter space for *strategy_name*.

    Uses ``config.settings.PARAM_SPACES`` when it has an entry for the
    strategy. Otherwise a warning is logged and a legacy fallback is returned:
    a fixed space for ``"ma_cross"``, or an empty dict for anything else.
    """
    from config.settings import PARAM_SPACES

    if strategy_name not in PARAM_SPACES:
        logger.warning(f"未在配置文件中找到策略 {strategy_name} 的参数空间,使用默认值")
        # Legacy defaults kept for backward compatibility with older configs.
        legacy_spaces = {
            "ma_cross": {
                "ma_short": [3, 20, 2],
                "ma_long": [20, 60, 5],
                "hold_days": [3, 10, 1],
            },
        }
        return legacy_spaces.get(strategy_name, {})

    configured = PARAM_SPACES[strategy_name]
    logger.info(f"使用配置文件中的参数空间: {strategy_name}")
    return configured
|
||
|
||
|
||
def parse_args(argv: "List[str] | None" = None):
    """Parse the command-line options of the backtest entry point.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing callers
            are unaffected. Passing a list lets tests or other scripts drive
            the parser.

    Returns:
        ``argparse.Namespace`` with these attributes:
        - ``optimize``: bool;
        - ``strategy``: str, default ``"ma_cross"``;
        - ``jobs`` / ``top``: int or None;
        - ``metric``: one of the listed choices, or None.
    """
    parser = argparse.ArgumentParser(description="A股回测系统")

    parser.add_argument(
        "--optimize",
        action="store_true",
        help="启用参数优化模式",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="ma_cross",
        help="策略名称(仅优化模式使用,回测模式从配置文件读取)",
    )
    parser.add_argument(
        "--jobs",
        type=int,
        default=None,
        help="并行进程数(优化模式)",
    )
    parser.add_argument(
        "--top",
        type=int,
        default=None,
        help="保存前N个最优结果(优化模式)",
    )
    parser.add_argument(
        "--metric",
        type=str,
        default=None,
        choices=["sharpe", "total_return", "max_drawdown", "annual_return"],
        help="排序指标(优化模式)",
    )

    return parser.parse_args(argv)
|
||
|
||
|
||
def main() -> None:
    """CLI entry point: dispatch to optimization mode or batch-backtest mode."""
    args = parse_args()
    separator = "=" * 60

    if args.optimize:
        headline, runner = "启动参数优化模式", run_optimize_mode
    else:
        headline, runner = "启动批量回测模式(从配置文件读取 STRATEGY_SWITCHES)", run_backtest_mode

    logger.info(separator)
    logger.info(headline)
    logger.info(separator)
    runner(args)
|
||
|
||
|
||
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|