Files
strategy_backtest/main.py

985 lines
38 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""回测系统入口。
功能:
- 读取全局配置;
- 加载股票列表与行情数据;
- 预处理并计算均线指标;
- 实例化策略并运行回测;
- 输出资金曲线 CSV、绩效指标与资金曲线图
- 支持参数优化模式(--optimize
- 捕获主流程异常并记录日志。
使用:
# 单策略回测
python main.py --strategy ma_cross
# 参数优化4核并行保存前20组
python main.py --optimize --strategy ma_cross --jobs 4 --top 20
"""
from __future__ import annotations
import argparse
import gc
import importlib
import os
import sys
from pathlib import Path
from typing import Dict, List
import numpy as np
import pandas as pd
import tushare as ts
from tqdm import tqdm
from config.settings import (
DATA_DAY_DIR,
END_DATE,
INITIAL_CASH,
MIN_LISTING_DAYS,
OPTIMIZATION_METRIC,
OPTIMIZATION_N_JOBS,
OPTIMIZATION_TOP_N,
POSITION_METHOD,
RESULTS_DIR,
START_DATE,
STOCK_CODE_FILE,
STOP_LOSS_METHOD,
STOP_LOSS_PCT,
STRATEGY,
TAKE_PROFIT_PCT,
TRADING_DAYS_PER_YEAR,
TUSHARE_CALENDAR_EXCHANGE,
TUSHARE_TOKEN,
ATR_PERIOD,
ATR_MULTIPLIER,
TRAILING_PCT,
KELLY_RISK_FREE,
KELLY_MAX_FRACTION,
VOLATILITY_TARGET,
VOLATILITY_WINDOW,
BENCHMARK_FILE,
BENCHMARK_NAME,
)
from benchmark.benchmark_loader import load_benchmark, calc_benchmark_return
from utils.data_loader import load_single_stock
from utils.logger import setup_logger
from utils.performance import calc_performance
from utils.plotter import plot_equity_curve
# Module-level logger, created via the project's utils.logger.setup_logger helper.
logger = setup_logger(__name__)
def _is_valid_stock(ts_code: str) -> bool:
"""判断股票代码是否符合回测要求。
过滤规则:
- 科创板688xxx.SH
- 北交所8xxxxx.BJ / 4xxxxx.BJ
- ST 股票:代码中含 ST 标识(需要名称判断,这里暂时仅通过代码过滤)
"""
if not ts_code:
return False
# 科创板688xxx.SH
if ts_code.startswith("688") and ts_code.endswith(".SH"):
logger.debug(f"过滤科创板: {ts_code}")
return False
# 北交所8xxxxx.BJ 或 4xxxxx.BJ
if ts_code.endswith(".BJ"):
logger.debug(f"过滤北交所: {ts_code}")
return False
return True
def _filter_new_stocks(
    stock_universe: List[str],
    data_dir: str,
    end_date: str,
    min_days: int = 60,
) -> List[str]:
    """Drop newly listed stocks that have too little trading history.

    Listing age is approximated by the number of rows in the local daily file
    up to ``end_date``. These are trading days, not calendar days. The earliest
    row's ``trade_date`` is reported as the listing date.

    Args:
        stock_universe: Stock codes to check.
        data_dir: Directory holding the daily quote files.
        end_date: Backtest end date (YYYYMMDD); later rows are not loaded.
        min_days: Minimum number of trading-day rows required to keep a stock.

    Returns:
        Codes, in input order, that have at least ``min_days`` rows. Codes whose
        data loads as an empty frame are dropped too, and are counted separately.
    """
    valid_stocks: List[str] = []
    filtered_count = 0
    missing_count = 0
    for ts_code in tqdm(stock_universe, desc="过滤新股", unit=""):
        # The full history up to end_date is needed to measure listing age.
        df = load_single_stock(data_dir, ts_code, start_date=None, end_date=end_date)
        if df.empty:
            # No usable data means the stock cannot be backtested; drop it, but count it.
            missing_count += 1
            continue
        trading_days = len(df)
        if trading_days < min_days:
            first_date_str = df["trade_date"].iloc[0]
            logger.debug(f"过滤新股: {ts_code}, 上市日={first_date_str}, 交易天数={trading_days}")
            filtered_count += 1
            continue
        valid_stocks.append(ts_code)
    if filtered_count > 0:
        logger.info(f"过滤新股(上市<{min_days}天): {filtered_count}")
    if missing_count > 0:
        logger.warning(f"过滤新股时无行情数据: {missing_count} 只")
    return valid_stocks
def _load_stock_universe() -> List[str]:
    """Build the sorted stock universe for the backtest.

    Source of codes:
    - If ``STOCK_CODE_FILE`` exists, it is read. Each line holds a code,
      optionally followed by a TAB and the security name. Codes rejected by
      ``_is_valid_stock`` are skipped, and so are names containing "ST"
      (case-insensitive).
    - Otherwise, every ``<code>_daily_data.txt`` file in ``DATA_DAY_DIR`` is
      used, filtered only by ``_is_valid_stock`` because no names are available.

    The codes are then de-duplicated and sorted, and passed through
    ``_filter_new_stocks`` with ``MIN_LISTING_DAYS``.

    Returns an empty list when the data directory is missing or when reading the
    code source raises.
    """
    suffix = "_daily_data.txt"
    collected: List[str] = []
    try:
        if STOCK_CODE_FILE.exists():
            with STOCK_CODE_FILE.open("r", encoding="utf-8") as fh:
                for raw in fh:
                    stripped = raw.strip()
                    if not stripped:
                        continue
                    # Line layout: code<TAB>name. Only the first two columns are used.
                    fields = stripped.split("\t")
                    code = fields[0].strip()
                    sec_name = fields[1].strip() if len(fields) > 1 else ""
                    if not _is_valid_stock(code):
                        continue
                    # ST / *ST stocks can only be recognised from their name.
                    if "ST" in sec_name.upper():
                        logger.debug(f"过滤ST股票: {code} {sec_name}")
                        continue
                    collected.append(code)
            logger.info(f"{STOCK_CODE_FILE} 加载股票代码 {len(collected)} 只(已过滤科创板/北交所/ST")
        else:
            day_dir = str(DATA_DAY_DIR)
            if not os.path.isdir(day_dir):
                logger.error(f"行情目录不存在: {day_dir}")
                return []
            for file_name in os.listdir(day_dir):
                if not file_name.endswith(suffix):
                    continue
                ts_code = file_name.replace(suffix, "")
                if _is_valid_stock(ts_code):
                    collected.append(ts_code)
            logger.info(f"从目录 {day_dir} 扫描股票代码 {len(collected)} 只(已过滤科创板/北交所)")
    except Exception as e:  # noqa: BLE001
        logger.error(f"加载股票代码列表失败: {e}")
        return []

    unique_sorted = sorted(set(collected))
    logger.info("开始过滤新股...")
    survivors = _filter_new_stocks(
        stock_universe=unique_sorted,
        data_dir=str(DATA_DAY_DIR),
        end_date=END_DATE,
        min_days=MIN_LISTING_DAYS,
    )
    logger.info(f"最终股票池: {len(survivors)}")
    return survivors
def _load_all_data(stock_universe: List[str]) -> Dict[str, pd.DataFrame]:
    """Load daily quotes in [START_DATE, END_DATE] for every code in the universe.

    Codes whose loader returns an empty frame are skipped and counted, and the
    count is reported as a warning.

    Returns:
        ``{ts_code: DataFrame}``. The dict is empty when nothing could be loaded,
        in which case an error is logged.
    """
    day_dir = str(DATA_DAY_DIR)
    loaded: Dict[str, pd.DataFrame] = {}
    skipped = 0
    for code in tqdm(stock_universe, desc="加载行情数据", unit=""):
        frame = load_single_stock(day_dir, code, START_DATE, END_DATE)
        if frame.empty:
            skipped += 1
        else:
            loaded[code] = frame

    if loaded:
        logger.info(f"共加载 {len(loaded)} 只股票数据")
    else:
        logger.error("无可用行情数据,退出")
    if skipped > 0:
        logger.warning(f"加载失败: {skipped} 只股票(文件不存在或格式错误)")
    return loaded
def _prepare_calendar(data_dict: Dict[str, pd.DataFrame]) -> List[str]:
    """Build the unified, ascending trading calendar for [START_DATE, END_DATE].

    The Tushare ``trade_cal`` API (open days on ``TUSHARE_CALENDAR_EXCHANGE``)
    is tried first when a token is available. The token comes from the
    ``TUSHARE_TOKEN`` setting, or else from the ``TUSHARE_TOKEN`` environment
    variable.

    The calendar is built from local data instead when:
    - no token is configured,
    - the API call raises, or
    - the API returns no rows.
    In that case it is the union of the ``trade_date`` values in ``data_dict``.

    Returns:
        Sorted list of trade dates. Empty if neither source yields any date.
    """
    try:
        token = TUSHARE_TOKEN or os.getenv("TUSHARE_TOKEN", "")
        if token:
            pro = ts.pro_api(token)
            df_cal = pro.trade_cal(
                exchange=TUSHARE_CALENDAR_EXCHANGE,
                start_date=START_DATE,
                end_date=END_DATE,
                is_open="1",
            )
            if df_cal is not None and not df_cal.empty:
                calendar = sorted(df_cal["cal_date"].astype(str).tolist())
                logger.info(f"通过 Tushare 获取交易日历共 {len(calendar)} 个交易日")
                return calendar
            # Make the fallback visible instead of silently switching data sources.
            logger.warning("Tushare 返回的交易日历为空,将使用本地行情日期构建交易日历")
        else:
            logger.warning("TUSHARE_TOKEN 未配置,将使用本地行情日期构建交易日历")
    except Exception as e:  # noqa: BLE001
        logger.error(f"通过 Tushare 获取交易日历失败,将使用本地行情日期构建: {e}")

    # Fallback: derive the calendar from the locally loaded quotes.
    all_dates: List[str] = []
    for df in data_dict.values():
        all_dates.extend(df["trade_date"].tolist())
    if not all_dates:
        return []
    # np.unique both de-duplicates and sorts ascending.
    calendar = np.unique(np.array(all_dates)).tolist()
    logger.info(f"统一交易日历共 {len(calendar)} 个交易日(本地构建)")
    return calendar
def _precompute_ma(data_dict: Dict[str, pd.DataFrame], ma_short: int, ma_long: int) -> None:
    """Add moving-average and volume-change columns to every frame, in place.

    For each window in ``{ma_short, ma_long}``, a ``ma_<window>`` column is added
    unless it already exists. It holds the rolling mean of ``close`` with
    ``min_periods=1``.

    ``vol_pct_change`` (``vol.pct_change()``) is added when ``vol`` is present and
    the column does not exist yet.

    Existing columns are never overwritten. Frames are mutated in place, so
    ``data_dict`` keeps pointing at the same objects.
    """
    windows = {ma_short, ma_long}
    for frame in data_dict.values():
        missing = [w for w in windows if f"ma_{w}" not in frame.columns]
        for w in missing:
            frame[f"ma_{w}"] = frame["close"].rolling(window=w, min_periods=1).mean()
        # Volume change is shared by the signal step; compute it once per frame.
        needs_vol_change = "vol" in frame.columns and "vol_pct_change" not in frame.columns
        if needs_vol_change:
            frame["vol_pct_change"] = frame["vol"].pct_change()
    logger.info("已为所有股票预先计算均线指标和成交量变化")
def _precompute_date_index(data_dict: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, int]]:
    """Map each stock's trade dates to their row positions.

    The stored values are 0-based positional offsets (from ``enumerate``), so rows
    should be fetched with ``df.iloc[pos]``. They coincide with ``df.loc`` labels
    only when the frame has a default RangeIndex. If a date occurs more than once,
    its last position wins.

    Returns:
        ``{ts_code: {trade_date: position}}``.
    """
    index_by_code = {
        code: {trade_date: pos for pos, trade_date in enumerate(frame["trade_date"])}
        for code, frame in data_dict.items()
    }
    logger.info("已为所有股票预先建立日期索引映射")
    return index_by_code
def _precompute_ocz_signals(
    data_dict: Dict[str, pd.DataFrame],
    N: int,
    B: float,
    V1: float,
    TOL: float,
    R: float,
    volatility_min: float,
    volatility_max: float,
) -> Dict[str, List[str]]:
    """Precompute OCZ breakout-then-pullback buy signals for every stock.

    Strategy logic, per stock, on daily bars:
    1. Resistance is the highest ``high`` of the last ``N`` bars.
    2. A breakout bar must meet all of these conditions:
       - it closes above the previous bar's resistance;
       - its candle body is more than ``B`` % of the high-low range;
       - its day return is at least ``R`` %;
       - its volume exceeds ``V1`` x the N-bar average volume;
       - the 30-bar mean range / close (in %) lies within
         ``[volatility_min, volatility_max]``.
    3. A pullback (buy) bar is the bar right after a breakout that meets all of
       these conditions:
       - its low lies within ±``TOL`` % of the resistance level that the
         breakout cleared;
       - it closes above that level;
       - it trades on lower volume than the breakout bar.

    Side effects:
        Each evaluable frame in ``data_dict`` gains the helper columns
        ``resistance``, ``body``, ``range``, ``body_pct``, ``return_pct``,
        ``vol_ma_n``, ``volatility``, ``breakthrough``,
        ``bars_since_breakthrough``, ``resistance_2days_ago`` and
        ``vol_breakthrough``, plus ``pullback_signal``.

        Frames lacking a required column, or with fewer than ``N + 2`` rows, only
        get ``pullback_signal = False``.

    Returns:
        ``{trade_date: [ts_code, ...]}`` for every date with at least one pullback signal.
    """
    logger.info(f"开始预计算OCZ策略信号N={N}, B={B}%, V1={V1}, TOL={TOL}%, R={R}%...")
    required_cols = ["open", "high", "low", "close", "vol", "trade_date"]
    buy_signal_index: Dict[str, List[str]] = {}
    for ts_code, df in data_dict.items():
        # Frames that cannot be evaluated still get the column the strategy reads.
        if not all(col in df.columns for col in required_cols) or len(df) < N + 2:
            df["pullback_signal"] = False
            continue

        # 1. Resistance: highest high of the last N bars (NaN until N bars exist).
        df["resistance"] = df["high"].rolling(window=N, min_periods=N).max()
        # 2. Candle body and total range; body_pct is the body as % of the range.
        df["body"] = (df["close"] - df["open"]).abs()
        df["range"] = df["high"] - df["low"]
        df["body_pct"] = df["body"] / df["range"] * 100
        # 3. Day-over-day return in %.
        df["return_pct"] = (df["close"] / df["close"].shift(1) - 1) * 100
        # 4. N-bar average volume.
        df["vol_ma_n"] = df["vol"].rolling(window=N, min_periods=N).mean()
        # 5. Volatility: 30-bar mean high-low range as % of the current close.
        df["volatility"] = (df["high"] - df["low"]).rolling(window=30, min_periods=30).mean() / df["close"] * 100

        # 6. Breakout bars. Comparisons against NaN are False, so warm-up bars never qualify.
        breakthrough = (
            (df["close"] > df["resistance"].shift(1))
            & (df["body_pct"] > B)
            & (df["return_pct"] >= R)
            & (df["vol"] > df["vol_ma_n"] * V1)
            & (df["volatility"] >= volatility_min)
            & (df["volatility"] <= volatility_max)
        )
        df["breakthrough"] = breakthrough

        # 7. Bars since the most recent breakout (0 on the breakout bar itself;
        #    999999 when no breakout has occurred yet).
        #    This is computed positionally with a forward-fill: it is O(n), it is
        #    correct for any index type, and a breakout in the first row counts
        #    like any other.
        positions = np.arange(len(df))
        last_breakout_pos = (
            pd.Series(np.where(breakthrough.to_numpy(dtype=bool), positions, np.nan))
            .ffill()
            .to_numpy()
        )
        bars_since = positions - last_breakout_pos
        df["bars_since_breakthrough"] = np.where(np.isnan(bars_since), 999999, bars_since).astype(np.int64)

        # 8. Pullback on the bar right after the breakout.
        #    resistance.shift(2) is the same level the breakout bar had to clear
        #    (resistance.shift(1) as seen from the breakout bar).
        df["resistance_2days_ago"] = df["resistance"].shift(2)
        df["vol_breakthrough"] = df["vol"].shift(1)  # volume of the breakout bar
        pullback_signal = (
            (df["bars_since_breakthrough"] == 1)
            & (df["low"] >= df["resistance_2days_ago"] * (1 - TOL / 100))
            & (df["low"] <= df["resistance_2days_ago"] * (1 + TOL / 100))
            & (df["close"] > df["resistance_2days_ago"])
            & (df["vol"] < df["vol_breakthrough"])
        )
        df["pullback_signal"] = pullback_signal

        # 9. Index signal dates -> codes.
        for date in df.loc[pullback_signal.to_numpy(dtype=bool), "trade_date"].tolist():
            buy_signal_index.setdefault(date, []).append(ts_code)

    logger.info(f"OCZ信号预计算完成{len(buy_signal_index)} 个交易日有回踩信号")
    return buy_signal_index
def _precompute_signals(data_dict: Dict[str, pd.DataFrame], ma_short: int, ma_long: int) -> Dict[str, List[str]]:
    """Precompute MA-cross entry/exit signal columns for every stock.

    Columns written in place on each frame:
    - ``golden_cross``: the previous bar's short MA <= long MA, and the current
      short MA > long MA;
    - ``volume_surge``: ``vol_pct_change`` > 0.2, i.e. volume is up more than
      20 % on the previous bar;
    - ``buy_signal``: ``golden_cross & volume_surge``;
    - ``death_cross``: the previous short MA >= long MA, and the current
      short MA < long MA.

    Frames missing ``ma_<ma_short>``, ``ma_<ma_long>`` or ``vol`` get all four
    columns set to False, so downstream code can read them unconditionally.
    ``vol_pct_change`` is computed here if ``_precompute_ma`` did not add it.

    Returns:
        ``{trade_date: [ts_code, ...]}`` for every date with at least one buy signal.
    """
    logger.info("开始预计算交易信号(金叉+放量)...")
    ma_s_col = f"ma_{ma_short}"
    ma_l_col = f"ma_{ma_long}"
    signal_cols = ("buy_signal", "golden_cross", "volume_surge", "death_cross")
    buy_signal_index: Dict[str, List[str]] = {}
    for ts_code, df in data_dict.items():
        if ma_s_col not in df.columns or ma_l_col not in df.columns or "vol" not in df.columns:
            # Include death_cross as well, so exit-side lookups never hit a missing column.
            for col in signal_cols:
                df[col] = False
            continue

        prev_ma_s = df[ma_s_col].shift(1)
        prev_ma_l = df[ma_l_col].shift(1)
        curr_ma_s = df[ma_s_col]
        curr_ma_l = df[ma_l_col]

        # 1. Golden cross: short MA crosses above long MA on this bar.
        golden_cross = (prev_ma_s <= prev_ma_l) & (curr_ma_s > curr_ma_l)
        df["golden_cross"] = golden_cross

        # 2. Volume surge: volume up more than 20 % (same threshold as the optimiser).
        if "vol_pct_change" not in df.columns:
            df["vol_pct_change"] = df["vol"].pct_change()
        volume_surge = df["vol_pct_change"] > 0.2
        df["volume_surge"] = volume_surge

        # 3. Entry signal: a golden cross on a volume surge.
        df["buy_signal"] = golden_cross & volume_surge

        # 4. Death cross, used on the exit side.
        df["death_cross"] = (prev_ma_s >= prev_ma_l) & (curr_ma_s < curr_ma_l)

        # 5. Index only the dates that carry a buy signal.
        for date in df.loc[df["buy_signal"].to_numpy(dtype=bool), "trade_date"].tolist():
            buy_signal_index.setdefault(date, []).append(ts_code)

    logger.info(f"交易信号预计算完成,共 {len(buy_signal_index)} 个交易日有买入信号")
    return buy_signal_index
def _create_risk_modules(risk_config: dict | None = None):
    """Build the stop-loss, take-profit and position-sizing modules.

    Args:
        risk_config: Optional risk-control dict with ``"stop_loss"`` and
            ``"take_profit"`` sub-dicts. The recognised keys are ``method``,
            ``stop_pct``, ``atr_period``, ``atr_multiplier`` and ``trailing_pct``.

            When None, the global settings are used, and the take-profit module
            reuses ``STOP_LOSS_METHOD`` together with ``TAKE_PROFIT_PCT``.

            Missing keys, and sub-dicts that are absent or None, fall back to the
            literal defaults below.

    Returns:
        ``(stop_loss, take_profit, position_sizer)``:
        - two ``StopLoss`` instances (take-profit reuses the class with its own
          parameters);
        - a ``PositionSizing`` built from the global settings.
    """
    from risk.stop_loss import StopLoss
    from risk.position_sizing import PositionSizing

    if risk_config is None:
        risk_config = {
            "stop_loss": {
                "method": STOP_LOSS_METHOD,
                "stop_pct": STOP_LOSS_PCT,
                "atr_period": ATR_PERIOD,
                "atr_multiplier": ATR_MULTIPLIER,
                "trailing_pct": TRAILING_PCT,
            },
            "take_profit": {
                "method": STOP_LOSS_METHOD,  # take-profit deliberately reuses the stop-loss method
                "stop_pct": TAKE_PROFIT_PCT,
                "atr_period": ATR_PERIOD,
                "atr_multiplier": ATR_MULTIPLIER,
                "trailing_pct": TRAILING_PCT,
            },
        }

    # ``or {}`` also covers a sub-config explicitly set to None, which would
    # otherwise raise AttributeError on ``.get``.
    stop_loss_cfg = risk_config.get("stop_loss") or {}
    stop_loss = StopLoss(
        method=stop_loss_cfg.get("method", "fixed_pct"),
        stop_pct=stop_loss_cfg.get("stop_pct", 0.05),
        atr_multiplier=stop_loss_cfg.get("atr_multiplier", 2.0),
        atr_period=stop_loss_cfg.get("atr_period", 14),
        trailing_pct=stop_loss_cfg.get("trailing_pct", 0.10),
    )

    take_profit_cfg = risk_config.get("take_profit") or {}
    take_profit = StopLoss(
        method=take_profit_cfg.get("method", "fixed_pct"),
        stop_pct=take_profit_cfg.get("stop_pct", 0.10),
        atr_multiplier=take_profit_cfg.get("atr_multiplier", 3.0),
        atr_period=take_profit_cfg.get("atr_period", 14),
        trailing_pct=take_profit_cfg.get("trailing_pct", 0.15),
    )

    # NOTE(review): max_positions is read from the global STRATEGY params, not
    # from the strategy whose risk_config is passed in. Confirm this is intended.
    position_sizer = PositionSizing(
        method=POSITION_METHOD,
        max_positions=STRATEGY["params"].get("max_positions", 2),
        kelly_risk_free=KELLY_RISK_FREE,
        kelly_max_fraction=KELLY_MAX_FRACTION,
        volatility_target=VOLATILITY_TARGET,
        volatility_window=VOLATILITY_WINDOW,
    )
    return stop_loss, take_profit, position_sizer
def _format_ratio_as_pct(value) -> str:
    """Render a fraction such as 0.05 as "5.0%"; return "N/A" when it is missing or non-numeric."""
    try:
        return f"{float(value) * 100:.1f}%"
    except (TypeError, ValueError):
        return "N/A"


def _precompute_strategy_signals(
    class_name: str,
    params: dict,
    data_dict: Dict[str, pd.DataFrame],
) -> Dict[str, List[str]]:
    """Precompute the indicators/signals that ``class_name`` needs.

    Returns ``{trade_date: [ts_code, ...]}``, or ``{}`` for unknown strategies.
    """
    if class_name == "MaCrossStrategy":
        ma_short = params.get("ma_short", 5)
        ma_long = params.get("ma_long", 20)
        _precompute_ma(data_dict, ma_short=ma_short, ma_long=ma_long)
        return _precompute_signals(data_dict, ma_short=ma_short, ma_long=ma_long)
    if class_name == "OczStrategy":
        return _precompute_ocz_signals(
            data_dict,
            N=params.get("N", 30),
            B=params.get("B", 60.0),
            V1=params.get("V1", 1.5),
            TOL=params.get("TOL", 1.5),
            R=params.get("R", 4.0),
            volatility_min=params.get("volatility_min", 2.5),
            volatility_max=params.get("volatility_max", 8.0),
        )
    logger.warning(f"未知策略类型 {class_name},不预计算信号")
    return {}


def _log_strategy_setup(params: dict, risk_control: dict | None) -> None:
    """Log the strategy parameters, risk-control settings, initial capital and backtest window."""
    logger.info("当前策略参数配置:")
    for param_name, param_value in params.items():
        logger.info(f" {param_name}: {param_value}")
    if risk_control:
        stop_cfg = risk_control.get("stop_loss") or {}
        profit_cfg = risk_control.get("take_profit") or {}
        logger.info("当前策略风控配置:")
        logger.info(f" 止损方法: {stop_cfg.get('method', 'N/A')}")
        # stop_pct may be absent (e.g. custom risk control), so format it
        # defensively rather than multiplying a placeholder string.
        logger.info(f" 止损比例: {_format_ratio_as_pct(stop_cfg.get('stop_pct'))}")
        logger.info(f" 止盈方法: {profit_cfg.get('method', 'N/A')}")
        logger.info(f" 止盈比例: {_format_ratio_as_pct(profit_cfg.get('stop_pct'))}")
    logger.info(f"初始资金: {INITIAL_CASH:,.0f}")
    logger.info(f"回测区间: {START_DATE} ~ {END_DATE}")
    logger.info("=" * 60)


def _save_trade_history(trade_history: list, class_name: str, params: dict, results_dir: Path) -> None:
    """Write the trades to ``results_dir/trades/<timestamp>_<class>_<params>.csv`` and log the trade count."""
    import datetime

    trades_dir = results_dir / "trades"
    trades_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # Only these parameters are encoded in the file name, in params' own order.
    param_keys = ("ma_short", "ma_long", "N", "B", "hold_days", "position_pct_per_stock")
    param_str = "_".join(f"{k}{v}" for k, v in params.items() if k in param_keys)
    trades_path = trades_dir / f"{timestamp}_{class_name}_{param_str}.csv"
    pd.DataFrame(trade_history).to_csv(trades_path, index=False, encoding="utf-8")
    logger.info(f"交易记录 CSV 保存到 {trades_path}")
    logger.info(f"总交易次数: {len(trade_history)}")


def _load_and_log_benchmark(calendar: List[str]) -> pd.DataFrame:
    """Load the benchmark index, log its statistics and return its curve.

    Returns an empty DataFrame when ``BENCHMARK_FILE`` does not exist or the
    loaded benchmark data is empty.
    """
    if not BENCHMARK_FILE.exists():
        return pd.DataFrame()
    logger.info(f"加载基准指数: {BENCHMARK_NAME}")
    benchmark_data = load_benchmark(BENCHMARK_FILE, START_DATE, END_DATE)
    if benchmark_data.empty:
        return pd.DataFrame()
    benchmark_df, benchmark_stats = calc_benchmark_return(benchmark_data, calendar)
    separator = "=" * 60
    logger.info(separator)
    logger.info(f"基准指数:{BENCHMARK_NAME}")
    logger.info(separator)
    logger.info(f"累计收益率: {benchmark_stats['cum_return']*100:+.2f}%")
    logger.info(f"年化收益率: {benchmark_stats['ann_return']*100:+.2f}%")
    logger.info(f"最大回撤: {benchmark_stats['max_drawdown']*100:.2f}%")
    logger.info(f"夏普比率: {benchmark_stats['sharpe']:.4f}")
    logger.info(separator)
    return benchmark_df


def run_single_strategy_backtest(strategy_config: dict, data_dict: Dict[str, pd.DataFrame], calendar: List[str], date_index_dict: Dict[str, Dict[str, int]]):
    """Run one configured strategy end to end and write its results.

    Args:
        strategy_config: ``{"name": class name, "module": import path,
            "params": strategy kwargs, "risk_control": optional risk dict}``.
        data_dict: ``{ts_code: DataFrame}``. Frames gain indicator/signal
            columns in place.
        calendar: Ascending trading dates to iterate.
        date_index_dict: ``{ts_code: {trade_date: row position}}``.

    Outputs, all under ``RESULTS_DIR``:
    - ``equity/<name>_equity.csv``;
    - ``trades/<timestamp>_<name>_<params>.csv``, only when trades exist;
    - ``plots/<name>_curve.png``.
    Performance and benchmark statistics are logged.

    Any exception is logged with its traceback and swallowed, so the caller can
    continue with the next strategy.
    """
    try:
        class_name = strategy_config["name"]
        module_name = strategy_config["module"]
        params = strategy_config["params"]
        risk_control = strategy_config.get("risk_control", None)

        logger.info("\n" + "=" * 60)
        logger.info(f"开始回测策略: {class_name}")
        logger.info("=" * 60)

        # Resolve the strategy class dynamically from its module path.
        strategy_cls = getattr(importlib.import_module(module_name), class_name)
        buy_signal_index = _precompute_strategy_signals(class_name, params, data_dict)

        # Strategies whose stop-loss method is "custom" handle exits internally.
        stop_loss_cfg = (risk_control or {}).get("stop_loss") or {}
        if stop_loss_cfg.get("method") == "custom":
            stop_loss = None
            take_profit = None
            logger.info("该策略使用内部自定义止损止盈逻辑")
        else:
            stop_loss, take_profit, _ = _create_risk_modules(risk_config=risk_control)
        # The generic position sizer stays disabled on every path; strategies size positions themselves.
        position_sizer = None

        _log_strategy_setup(params, risk_control)

        strategy = strategy_cls(
            initial_cash=INITIAL_CASH,
            stop_loss=stop_loss,
            take_profit=take_profit,
            position_sizer=position_sizer,
            date_index_dict=date_index_dict,
            buy_signal_index=buy_signal_index,
            **params,
        )
        equity_df = strategy.run_backtest(data_dict, calendar)
        if equity_df.empty:
            logger.error(f"{class_name} 资金曲线为空,回测可能失败")
            return

        # Equity curve CSV.
        results_dir = Path(RESULTS_DIR)
        equity_dir = results_dir / "equity"
        equity_dir.mkdir(parents=True, exist_ok=True)
        csv_path = equity_dir / f"{class_name}_equity.csv"
        equity_df.to_csv(csv_path, index=False, encoding="utf-8")
        logger.info(f"资金曲线 CSV 保存到 {csv_path}")

        if strategy.trade_history:
            _save_trade_history(strategy.trade_history, class_name, params, results_dir)

        calc_performance(
            equity_df,
            trade_count=strategy.trade_count,
            trade_history=strategy.trade_history,
            trading_days_per_year=TRADING_DAYS_PER_YEAR,
        )

        benchmark_df = _load_and_log_benchmark(calendar)

        plots_dir = results_dir / "plots"
        plots_dir.mkdir(parents=True, exist_ok=True)
        plot_path = plots_dir / f"{class_name}_curve.png"
        plot_equity_curve(
            str(csv_path),
            str(plot_path),
            benchmark_df=benchmark_df if not benchmark_df.empty else None,
            benchmark_name=BENCHMARK_NAME,
        )
        logger.info(f"{class_name} 回测完成\n")

        # Release the strategy's state before the next strategy runs.
        del strategy
        del equity_df
        gc.collect()
    except Exception as e:
        logger.error(f"{strategy_config.get('name', 'Unknown')} 回测异常: {e}", exc_info=True)
def run_backtest_mode(args):
    """Backtest every strategy switched on in ``STRATEGY_SWITCHES``.

    ``args`` is accepted for symmetry with ``run_optimize_mode`` but is not read.
    The stock universe, quotes, calendar and date index are built once and shared
    by all enabled strategies. Each strategy then runs through
    ``run_single_strategy_backtest``.

    Any exception is logged with its traceback, not raised.
    """
    try:
        universe = _load_stock_universe()
        if not universe:
            logger.error("股票池为空,无法回测")
            return

        data_dict = _load_all_data(universe)
        if not data_dict:
            return

        calendar = _prepare_calendar(data_dict)
        if not calendar:
            logger.error("无法构建交易日历,退出")
            return

        # Shared positional date lookup for every strategy.
        date_index_dict = _precompute_date_index(data_dict)

        from config.settings import STRATEGIES, STRATEGY_SWITCHES

        enabled = [
            {
                "name": name,
                "module": STRATEGIES[name]["module"],
                "params": STRATEGIES[name]["params"],
                "risk_control": STRATEGIES[name].get("risk_control", None),
            }
            for name, switch in STRATEGY_SWITCHES.items()
            if switch == 1 and name in STRATEGIES
        ]
        if not enabled:
            logger.error("没有任何策略开启,请检查 STRATEGY_SWITCHES 配置")
            return

        separator = "=" * 60
        logger.info(separator)
        logger.info(f"共有 {len(enabled)} 个策略需要回测")
        for position, cfg in enumerate(enabled, start=1):
            logger.info(f" {position}. {cfg['name']}")
        logger.info(separator)

        for cfg in enabled:
            run_single_strategy_backtest(cfg, data_dict, calendar, date_index_dict)

        logger.info("所有策略回测完成,清理内存...")
        del data_dict
        gc.collect()
        logger.info("内存清理完成")
    except Exception as e:
        logger.error(f"回测模式异常: {e}", exc_info=True)
def run_optimize_mode(args):
    """Grid-search the parameters of the strategy configured in ``STRATEGY``.

    Steps:
    - load the stock universe, quotes and trading calendar;
    - precompute the indicators the strategy needs (every candidate MA window,
      for ``MaCrossStrategy``);
    - build the date index and the risk modules;
    - run ``grid_search`` with the parameter space and constraints registered for
      ``args.strategy``.

    CLI values (``--jobs``, ``--top``, ``--metric``) override the configured
    defaults when they are truthy.

    Note: the strategy class comes from ``STRATEGY``, while the parameter space and
    constraints are looked up by ``args.strategy``. The two must describe the same
    strategy.

    Any exception is logged with its traceback, not raised.
    """
    try:
        from optimization.grid_search import grid_search

        # 1. Stock universe.
        stock_universe = _load_stock_universe()
        if not stock_universe:
            logger.error("股票池为空,无法优化")
            return
        # 2. Quotes.
        data_dict = _load_all_data(stock_universe)
        if not data_dict:
            return
        # 3. Unified trading calendar.
        calendar = _prepare_calendar(data_dict)
        if not calendar:
            logger.error("无法构建交易日历,退出")
            return

        # 4. Resolve the strategy class and its risk-control config.
        module_name = STRATEGY["module"]
        class_name = STRATEGY["name"]
        strategy_cls = getattr(importlib.import_module(module_name), class_name)
        from config.settings import STRATEGIES
        risk_control = STRATEGIES.get(class_name, {}).get("risk_control", None)

        # 5. Parameter space and strategy-specific precomputation.
        if class_name == "MaCrossStrategy":
            logger.info("预计算所有可能的均线周期...")
            param_space = get_param_space(args.strategy)
            # NOTE(review): list/tuple values are treated as explicit candidate
            # windows. The built-in fallback space (e.g. [3, 20, 2]) looks like
            # [start, stop, step]; if grid_search reads lists that way, windows
            # missing here must be computed there. Confirm against grid_search.
            all_ma_periods = set()
            for key in ("ma_short", "short_window", "ma_long", "long_window"):
                if key not in param_space:
                    continue
                values = param_space[key]
                if isinstance(values, (range, list, tuple)):
                    all_ma_periods.update(values)
            # Frames are mutated in place, so data_dict needs no reassignment.
            for df in data_dict.values():
                for window in all_ma_periods:
                    col = f"ma_{window}"
                    if col not in df.columns:
                        df[col] = df["close"].rolling(window=window, min_periods=1).mean()
                if "vol" in df.columns and "vol_pct_change" not in df.columns:
                    df["vol_pct_change"] = df["vol"].pct_change()
            logger.info(f"已预计算 {len(all_ma_periods)} 个均线周期和成交量变化")
        elif class_name == "OczStrategy":
            # OCZ signals depend on the searched parameters; they are computed during the search.
            logger.info("OCZ策略不需要预计算均线将在优化过程中动态计算信号")
            param_space = get_param_space(args.strategy)
        else:
            logger.warning(f"未知策略类型 {class_name},跳过指标预计算")
            param_space = get_param_space(args.strategy)

        # 6. Positional date index.
        logger.info("预计算日期索引字典...")
        date_index_dict = _precompute_date_index(data_dict)

        # 7. Risk modules, decided the same way as in run_single_strategy_backtest.
        #    Strategies whose stop-loss method is "custom" manage exits internally,
        #    so no generic StopLoss is built for them.
        stop_loss_cfg = (risk_control or {}).get("stop_loss") or {}
        if stop_loss_cfg.get("method") == "custom":
            stop_loss = None
            take_profit = None
            logger.info("该策略使用内部自定义止损止盈逻辑")
        else:
            # The position sizer is discarded; strategies use their own sizing logic.
            stop_loss, take_profit, _ = _create_risk_modules(risk_config=risk_control)

        # 8. Grid search.
        n_jobs = args.jobs or OPTIMIZATION_N_JOBS
        top_n = args.top or OPTIMIZATION_TOP_N
        metric = args.metric or OPTIMIZATION_METRIC

        from config.settings import PARAM_CONSTRAINTS
        constraint_func = PARAM_CONSTRAINTS.get(args.strategy, None)
        if constraint_func is not None:
            logger.info(f"使用策略 {args.strategy} 的约束条件")

        grid_search(
            strategy_class=strategy_cls,
            data_dict=data_dict,
            calendar=calendar,
            param_space=param_space,
            initial_capital=INITIAL_CASH,
            n_jobs=n_jobs,
            metric=metric,
            top_n=top_n,
            constraint_func=constraint_func,
            date_index_dict=date_index_dict,
            stop_loss=stop_loss,
            take_profit=take_profit,
        )

        # 9. Release memory.
        logger.info("参数优化完成,清理内存...")
        del data_dict
        gc.collect()
        logger.info("内存清理完成")
    except Exception as e:  # noqa: BLE001
        logger.error(f"优化模式异常: {e}", exc_info=True)
def get_param_space(strategy_name: str) -> dict:
    """Return the optimisation parameter space for ``strategy_name``.

    The name is looked up in ``config.settings.PARAM_SPACES`` first. If it is not
    there, a warning is logged and a built-in fallback is returned instead:
    - a fixed space for ``"ma_cross"``, kept for backward compatibility;
    - ``{}`` for any other name.
    """
    from config.settings import PARAM_SPACES

    if strategy_name not in PARAM_SPACES:
        logger.warning(f"未在配置文件中找到策略 {strategy_name} 的参数空间,使用默认值")
        if strategy_name != "ma_cross":
            return {}
        # Legacy fallback. NOTE(review): each value is a 3-number list, possibly
        # [start, stop, step]. Confirm how grid_search interprets it.
        return {
            "ma_short": [3, 20, 2],
            "ma_long": [20, 60, 5],
            "hold_days": [3, 10, 1],
        }

    configured = PARAM_SPACES[strategy_name]
    logger.info(f"使用配置文件中的参数空间: {strategy_name}")
    return configured
def parse_args():
    """Parse the command-line options of the backtest entry point.

    Options:
        --optimize  run parameter optimisation instead of config-driven backtests;
        --strategy  strategy key used by optimisation (default "ma_cross");
        --jobs      number of parallel worker processes (optimisation);
        --top       number of best parameter sets to keep (optimisation);
        --metric    ranking metric: sharpe / total_return / max_drawdown / annual_return.

    Returns:
        argparse.Namespace with the attributes optimize, strategy, jobs, top and metric.
    """
    parser = argparse.ArgumentParser(description="A股回测系统")
    option_specs = (
        ("--optimize", {"action": "store_true", "help": "启用参数优化模式"}),
        ("--strategy", {
            "type": str,
            "default": "ma_cross",
            "help": "策略名称(仅优化模式使用,回测模式从配置文件读取)",
        }),
        ("--jobs", {"type": int, "default": None, "help": "并行进程数(优化模式)"}),
        ("--top", {"type": int, "default": None, "help": "保存前N个最优结果优化模式"}),
        ("--metric", {
            "type": str,
            "default": None,
            "choices": ["sharpe", "total_return", "max_drawdown", "annual_return"],
            "help": "排序指标(优化模式)",
        }),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """Entry point: parse the CLI and dispatch to optimisation or batch-backtest mode."""
    args = parse_args()
    rule = "=" * 60
    if args.optimize:
        banner, runner = "启动参数优化模式", run_optimize_mode
    else:
        banner, runner = "启动批量回测模式(从配置文件读取 STRATEGY_SWITCHES", run_backtest_mode
    logger.info(rule)
    logger.info(banner)
    logger.info(rule)
    runner(args)


if __name__ == "__main__":
    main()