Files
strategy_backtest/optimization/grid_search.py

417 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""多进程网格搜索参数优化。
使用 multiprocessing.Pool 并行回测不同参数组合,
提高参数优化效率。
示例:
from optimization.grid_search import grid_search
from strategies.ma_cross import MACrossStrategy
param_space = {
"ma_short": range(3, 21, 2),
"ma_long": range(20, 61, 5),
"hold_days": range(3, 11),
"stop_loss_pct": [0.03, 0.05, 0.08],
"take_profit_pct": [0.10, 0.15, 0.20],
}
results = grid_search(
strategy_class=MACrossStrategy,
data_dict=data_dict,
calendar=calendar,
param_space=param_space,
initial_capital=1_000_000,
n_jobs=4,
metric="sharpe"
)
"""
from __future__ import annotations
import csv
import gc
import importlib
import multiprocessing as mp
import time
from pathlib import Path
from typing import Any, Dict, List, Type
import pandas as pd
from tqdm import tqdm
from optimization.param_space import estimate_combinations_count, generate_param_combinations, validate_param_space
from strategies.base_strategy import BaseStrategy
from utils.logger import setup_logger
from utils.performance import calc_performance
# Module-level logger, configured via utils.logger.setup_logger.
logger = setup_logger(__name__)
def _bars_since_last_true(mask: pd.Series, missing: int = 999999):
    """Return, per row, the label distance to the most recent True row of *mask*.

    For every label ``i`` of ``mask.index`` the value is ``i - j``, where ``j`` is
    the largest label ``<= i`` at which ``mask`` is True. When no such label
    exists the value is ``missing``. A binary search over the sorted True labels
    is used (O(n log k)) instead of scanning all True labels for every row.

    Assumes numeric index labels (e.g. the default RangeIndex) -- TODO confirm
    for callers that use a custom index.

    Args:
        mask: Boolean Series (e.g. the per-bar breakthrough flags).
        missing: Value used for rows that have no True label at or before them.

    Returns:
        An array aligned positionally with ``mask``.
    """
    labels = mask.index
    hits = labels[mask.to_numpy(dtype=bool)].sort_values()
    if len(hits) == 0:
        return pd.Series(missing, index=labels).to_numpy()
    # The count of True labels <= each label, minus one, gives the position of the
    # most recent True label; -1 means there is none yet.
    last_pos = hits.searchsorted(labels, side="right") - 1
    gaps = labels.to_numpy() - hits.to_numpy()[last_pos.clip(min=0)]
    gaps[last_pos < 0] = missing
    return gaps


def _ma_cross_signal_index(
    data_dict: Dict[str, pd.DataFrame], params: Dict[str, Any]
) -> Dict[Any, List[str]]:
    """Build ``{trade_date: [ts_code, ...]}`` for MA golden-cross + volume-surge buys.

    Reads the precomputed ``ma_<ma_short>`` / ``ma_<ma_long>`` columns. Empty
    DataFrames and stocks lacking either column are skipped. Intermediate signal
    columns (``golden_cross``, ``death_cross``, ``volume_surge``, ``buy_signal``, ...)
    are written onto each DataFrame in place.
    """
    signal_index: Dict[Any, List[str]] = {}
    ma_short = params.get("ma_short")
    ma_long = params.get("ma_long")
    if ma_short is None or ma_long is None:
        return signal_index
    short_col = f"ma_{ma_short}"
    long_col = f"ma_{ma_long}"
    for ts_code, df in data_dict.items():
        if df.empty or short_col not in df.columns or long_col not in df.columns:
            continue
        df["ma_short_prev"] = df[short_col].shift(1)
        df["ma_long_prev"] = df[long_col].shift(1)
        # Golden cross: the short MA moves from <= the long MA to above it.
        df["golden_cross"] = (
            (df[short_col] > df[long_col]) &
            (df["ma_short_prev"] <= df["ma_long_prev"])
        )
        # Death cross (sell signal): the short MA moves from >= the long MA to below it.
        df["death_cross"] = (
            (df[short_col] < df[long_col]) &
            (df["ma_short_prev"] >= df["ma_long_prev"])
        )
        # Volume surge: volume up more than 20% vs. the previous bar. vol_pct_change is
        # expected to be precomputed together with the MAs; derive it only when missing.
        if "vol_pct_change" not in df.columns:
            df["vol_pct_change"] = df["vol"].pct_change()
        df["volume_surge"] = df["vol_pct_change"] > 0.2
        df["buy_signal"] = df["golden_cross"] & df["volume_surge"]
        for trade_date in df.loc[df["buy_signal"], "trade_date"].tolist():
            signal_index.setdefault(trade_date, []).append(ts_code)
    return signal_index


def _ocz_signal_index(
    data_dict: Dict[str, pd.DataFrame], params: Dict[str, Any]
) -> Dict[Any, List[str]]:
    """Build ``{trade_date: [ts_code, ...]}`` for OCZ breakout-then-pullback buys.

    Stocks that are empty or have fewer than ``N + 2`` rows are skipped.
    Intermediate indicator columns are written onto each DataFrame in place.
    """
    signal_index: Dict[Any, List[str]] = {}
    N = params.get("N", 30)
    B = params.get("B", 60.0)
    V1 = params.get("V1", 1.5)
    TOL = params.get("TOL", 1.5)
    R = params.get("R", 4.0)
    volatility_min = params.get("volatility_min", 2.5)
    volatility_max = params.get("volatility_max", 8.0)
    for ts_code, df in data_dict.items():
        if df.empty or len(df) < N + 2:
            continue
        # Parameter-independent columns are derived only when absent, so columns
        # already present on the DataFrame are reused as-is.
        if "body" not in df.columns:
            df["body"] = (df["close"] - df["open"]).abs()
        if "range" not in df.columns:
            df["range"] = df["high"] - df["low"]
        if "body_pct" not in df.columns:
            df["body_pct"] = df["body"] / df["range"] * 100
        if "return_pct" not in df.columns:
            df["return_pct"] = (df["close"] / df["close"].shift(1) - 1) * 100
        if "volatility" not in df.columns:
            df["volatility"] = (df["high"] - df["low"]).rolling(window=30, min_periods=30).mean() / df["close"] * 100
        # Columns that depend on N are recomputed on every call.
        df["resistance"] = df["high"].rolling(window=N, min_periods=N).max()
        df["vol_ma_n"] = df["vol"].rolling(window=N, min_periods=N).mean()
        # Breakout bar: the close exceeds the prior N-bar high, body_pct > B,
        # return_pct >= R, volume > V1 x the N-bar mean volume, and volatility is
        # within [volatility_min, volatility_max].
        df["breakthrough"] = (
            (df["close"] > df["resistance"].shift(1)) &
            (df["body_pct"] > B) &
            (df["return_pct"] >= R) &
            (df["vol"] > df["vol_ma_n"] * V1) &
            (df["volatility"] >= volatility_min) &
            (df["volatility"] <= volatility_max)
        )
        df["bars_since_breakthrough"] = _bars_since_last_true(df["breakthrough"])
        # Pullback bar: the bar right after a breakout whose low is within +/-TOL% of
        # the level broken (resistance two bars back). It must close above that level
        # on lower volume than the breakout bar.
        df["resistance_2days_ago"] = df["resistance"].shift(2)
        df["vol_breakthrough"] = df["vol"].shift(1)
        df["pullback_signal"] = (
            (df["bars_since_breakthrough"] == 1) &
            (df["low"] >= df["resistance_2days_ago"] * (1 - TOL / 100)) &
            (df["low"] <= df["resistance_2days_ago"] * (1 + TOL / 100)) &
            (df["close"] > df["resistance_2days_ago"]) &
            (df["vol"] < df["vol_breakthrough"])
        )
        for trade_date in df.loc[df["pullback_signal"], "trade_date"].tolist():
            signal_index.setdefault(trade_date, []).append(ts_code)
    return signal_index


# Strategy class name -> buy-signal-index builder (matched by exact class name).
# NOTE(review): the module docstring example imports ``MACrossStrategy`` while this
# key is ``MaCrossStrategy``. A class whose name matches neither key receives an empty
# signal index -- confirm the real class name.
_SIGNAL_INDEX_BUILDERS = {
    "MaCrossStrategy": _ma_cross_signal_index,
    "OczStrategy": _ocz_signal_index,
}


def _run_single_backtest(args: tuple) -> Dict[str, Any]:
    """Run one backtest for a single parameter combination (process-pool task).

    Args:
        args: ``(strategy_class, data_dict, calendar, params, initial_capital, idx,
            date_index_dict, stop_loss, take_profit)``.

    Returns:
        Dict with ``idx``, ``params`` and the performance metrics. On any exception
        every metric value is ``None`` and an ``error`` key holds the message, so a
        single failing combination does not abort the whole pool.
    """
    strategy_class, data_dict, calendar, params, initial_capital, idx, date_index_dict, stop_loss, take_profit = args
    try:
        # Precompute the buy-signal index for strategies that have a builder;
        # any other strategy receives an empty index.
        builder = _SIGNAL_INDEX_BUILDERS.get(strategy_class.__name__)
        buy_signal_index = builder(data_dict, params) if builder is not None else {}
        strategy = strategy_class(
            initial_cash=initial_capital,
            stop_loss=stop_loss,
            take_profit=take_profit,
            position_sizer=None,  # disable the position sizer; use the strategy's own logic
            date_index_dict=date_index_dict,
            buy_signal_index=buy_signal_index,
            **params
        )
        equity_df = strategy.run_backtest(data_dict, calendar)
        perf = calc_performance(
            equity_df=equity_df,
            trade_count=strategy.trade_count,
            trade_history=strategy.trade_history,
        )
        # Release the backtest objects before the worker picks up its next task.
        del strategy
        del equity_df
        gc.collect()
        return {
            "idx": idx,
            "params": params,
            "total_return": perf["cum_return"],  # calc_performance names it cum_return
            "annual_return": perf["ann_return"],  # calc_performance names it ann_return
            "sharpe": perf["sharpe"],
            "max_drawdown": perf["max_drawdown"],
            "avg_capital_utilization": perf["avg_capital_utilization"],
            "total_trades": perf["total_trades"],
            "avg_trades_per_year": perf["avg_trades_per_year"],
            "win_rate": perf.get("win_rate", 0.0),
            "profit_loss_ratio": perf.get("profit_loss_ratio", 0.0),
        }
    except Exception as e:
        # logger.exception keeps the traceback, which is otherwise lost in the worker.
        logger.exception(f"参数 {params} 回测失败: {e}")
        # Same keys as the success dict so downstream code sees one schema.
        return {
            "idx": idx,
            "params": params,
            "total_return": None,
            "annual_return": None,
            "sharpe": None,
            "max_drawdown": None,
            "avg_capital_utilization": None,
            "total_trades": None,
            "avg_trades_per_year": None,
            "win_rate": None,
            "profit_loss_ratio": None,
            "error": str(e),
        }
# Result keys produced by _run_single_backtest that can be used to rank combinations.
_RANKABLE_METRICS = (
    "total_return",
    "annual_return",
    "sharpe",
    "max_drawdown",
    "avg_capital_utilization",
    "total_trades",
    "avg_trades_per_year",
    "win_rate",
    "profit_loss_ratio",
)


def _build_date_index_dict(data_dict: Dict[str, pd.DataFrame]) -> Dict[str, Dict[Any, int]]:
    """Map each ts_code to ``{trade_date: positional row index}`` (last duplicate wins)."""
    return {
        ts_code: {trade_date: pos for pos, trade_date in enumerate(df["trade_date"])}
        for ts_code, df in data_dict.items()
    }


def _log_top_results(top_results: List[Dict[str, Any]], metric: str) -> None:
    """Log a summary of the first (at most 5) ranked results."""
    logger.info("=" * 80)
    logger.info(f"参数优化 Top {min(5, len(top_results))} 结果(按 {metric} 排序)")
    logger.info("=" * 80)
    for i, result in enumerate(top_results[:5], 1):
        logger.info(f"{i} 名:")
        logger.info(f" 参数: {result['params']}")
        logger.info(f" 累计收益: {result['total_return']*100:+.2f}%")
        logger.info(f" 年化收益: {result['annual_return']*100:+.2f}%")
        logger.info(f" 夏普比率: {result['sharpe']:.4f}")
        logger.info(f" 最大回撤: {result['max_drawdown']*100:.2f}%")
        logger.info(f" 总交易次数: {result['total_trades']}")
        logger.info("-" * 80)


def grid_search(
    strategy_class: Type[BaseStrategy],
    data_dict: Dict[str, pd.DataFrame],
    calendar: List[str],
    param_space: Dict[str, Any],
    initial_capital: float = 1_000_000,
    n_jobs: int = 4,
    metric: str = "sharpe",
    top_n: int = 20,
    output_dir: Path = None,
    constraint_func: callable = None,
    date_index_dict: Dict[str, Dict[str, int]] = None,
    stop_loss: object = None,
    take_profit: object = None,
) -> List[Dict[str, Any]]:
    """Grid-search strategy parameters, running backtests in a process pool.

    Args:
        strategy_class: Strategy class (not an instance).
        data_dict: Stock data, ``{ts_code: DataFrame}``.
        calendar: List of trading days.
        param_space: Parameter-space definition.
        initial_capital: Starting capital for every backtest.
        n_jobs: Number of worker processes. ``None`` or a value <= 0 means "use all
            CPUs". It is never larger than the number of combinations.
        metric: Ranking metric, one of ``_RANKABLE_METRICS`` (e.g. "sharpe",
            "total_return", "max_drawdown", "annual_return"). "max_drawdown" is
            ranked ascending and all others descending.
        top_n: Number of top results to keep and save.
        output_dir: Directory (Path or str) for the CSV. It defaults to
            ``RESULTS_DIR / "optimization"``.
        constraint_func: Optional predicate used to drop invalid combinations.
        date_index_dict: ``{ts_code: {date: idx}}``. It is built from ``data_dict``
            when omitted.
        stop_loss: Stop-loss manager passed to each strategy.
        take_profit: Take-profit manager passed to each strategy.

    Returns:
        The top ``top_n`` results sorted by ``metric``. An empty list is returned when
        the parameter space or metric is invalid, or no combination produced a usable
        metric value.
    """
    import os

    if not validate_param_space(param_space):
        logger.error("参数空间不合法")
        return []
    # Reject unknown metrics before running any backtest. Otherwise every result
    # would be filtered out afterwards and reported as a failure.
    if metric not in _RANKABLE_METRICS:
        logger.error(f"不支持的排序指标: {metric}, 可选: {', '.join(_RANKABLE_METRICS)}")
        return []

    from optimization.param_space import apply_param_constraints
    combinations = generate_param_combinations(param_space)
    if constraint_func is not None:
        combinations = apply_param_constraints(combinations, constraint_func)
    total_comb = len(combinations)
    if total_comb == 0:
        logger.warning("参数组合数为 0无需优化")
        return []

    # mp.Pool rejects processes < 1; treat None/<=0 as "all CPUs" and do not start
    # more workers than there are tasks.
    processes = n_jobs if n_jobs is not None and n_jobs > 0 else (os.cpu_count() or 1)
    processes = min(processes, total_comb)

    logger.info(f"开始参数扫描,共 {total_comb} 组参数,并行 {processes}")
    logger.info(f"排序指标: {metric}, 保留前 {top_n} 组结果")

    if date_index_dict is None:
        logger.info("预计算日期索引字典...")
        date_index_dict = _build_date_index_dict(data_dict)
        logger.info("日期索引预计算完成")

    # NOTE(review): every task tuple carries data_dict and date_index_dict, so they are
    # pickled and sent to a worker once per combination. For large universes a Pool
    # initializer that ships them once per worker would cut IPC cost. It would also
    # make in-place DataFrame mutations persist across tasks within a worker -- check
    # that strategies tolerate that before switching.
    tasks = [
        (strategy_class, data_dict, calendar, params, initial_capital, idx, date_index_dict, stop_loss, take_profit)
        for idx, params in enumerate(combinations)
    ]

    start_time = time.time()
    with mp.Pool(processes=processes) as pool:
        # imap_unordered + tqdm shows progress as tasks complete.
        results = list(tqdm(
            pool.imap_unordered(_run_single_backtest, tasks),
            total=total_comb,
            desc="参数优化进度",
            unit=""
        ))
    elapsed = time.time() - start_time
    logger.info(f"参数扫描完成,耗时 {elapsed:.2f}")

    # imap_unordered yields in completion order. Restore submission order so that ties
    # in the metric are ranked deterministically (sorted() below is stable).
    results.sort(key=lambda r: r["idx"])

    # None marks a failed backtest. NaN (e.g. an undefined Sharpe ratio) would break
    # the ordering of sorted(), so it is excluded from ranking as well.
    valid_results = []
    failed_count = 0
    nan_count = 0
    for result in results:
        value = result.get(metric)
        if value is None:
            failed_count += 1
        elif pd.isna(value):
            nan_count += 1
        else:
            valid_results.append(result)
    if failed_count > 0:
        logger.warning(f"{failed_count} 组参数回测失败")
    if nan_count > 0:
        logger.warning(f"{nan_count} 组参数的 {metric} 为 NaN,已从排序中排除")
    if not valid_results:
        logger.error("所有参数组合都失败,无法生成结果")
        return []

    # Larger is better for every metric except max_drawdown.
    descending = metric != "max_drawdown"
    top_results = sorted(valid_results, key=lambda x: x[metric], reverse=descending)[:top_n]

    if output_dir is None:
        from config.settings import RESULTS_DIR
        output_dir = RESULTS_DIR / "optimization"
    output_dir = Path(output_dir)  # also accept plain string paths
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    csv_path = output_dir / f"grid_search_{timestamp}.csv"
    _save_results_to_csv(top_results, csv_path)
    logger.info(f"参数优化结果已保存到: {csv_path}")

    _log_top_results(top_results, metric)
    return top_results
def _save_results_to_csv(results: List[Dict[str, Any]], csv_path: Path) -> None:
    """Write optimization results to a CSV file, one row per parameter combination.

    Each entry of ``result["params"]`` becomes a ``param_<name>`` column. The
    performance columns follow, including ``win_rate`` and ``profit_loss_ratio``.
    Missing values are written as empty cells. The file is written as UTF-8 with
    a BOM (``utf-8-sig``).

    Args:
        results: Result dicts, typically the ranked output of grid_search.
        csv_path: Destination file path.
    """
    if not results:
        logger.warning("结果为空,不保存 CSV")
        return
    metric_columns = (
        "total_return",
        "annual_return",
        "sharpe",
        "max_drawdown",
        "avg_capital_utilization",
        "total_trades",
        "avg_trades_per_year",
        "win_rate",
        "profit_loss_ratio",
    )
    rows = []
    for result in results:
        # Prefix parameter names so they cannot collide with metric column names.
        row = {f"param_{name}": value for name, value in (result.get("params") or {}).items()}
        row.update({key: result.get(key) for key in metric_columns})
        rows.append(row)
    pd.DataFrame(rows).to_csv(csv_path, index=False, encoding="utf-8-sig")
    logger.info(f"结果已保存,共 {len(rows)}")