417 lines
16 KiB
Python
417 lines
16 KiB
Python
"""Multi-process grid search for strategy parameter optimisation.

Back-tests the different parameter combinations in parallel with
multiprocessing.Pool to make parameter optimisation faster.

Example:
    from optimization.grid_search import grid_search
    from strategies.ma_cross import MACrossStrategy

    param_space = {
        "ma_short": range(3, 21, 2),
        "ma_long": range(20, 61, 5),
        "hold_days": range(3, 11),
        "stop_loss_pct": [0.03, 0.05, 0.08],
        "take_profit_pct": [0.10, 0.15, 0.20],
    }

    results = grid_search(
        strategy_class=MACrossStrategy,
        data_dict=data_dict,
        calendar=calendar,
        param_space=param_space,
        initial_capital=1_000_000,
        n_jobs=4,
        metric="sharpe"
    )
"""
|
||
from __future__ import annotations
|
||
|
||
import csv
|
||
import gc
|
||
import importlib
|
||
import multiprocessing as mp
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Type
|
||
|
||
import pandas as pd
|
||
from tqdm import tqdm
|
||
|
||
from optimization.param_space import estimate_combinations_count, generate_param_combinations, validate_param_space
|
||
from strategies.base_strategy import BaseStrategy
|
||
from utils.logger import setup_logger
|
||
from utils.performance import calc_performance
|
||
|
||
# Module-level logger used by every function in this file.
logger = setup_logger(__name__)
|
||
|
||
|
||
# Value stored in "bars_since_breakthrough" for rows that have no breakthrough
# at or before them.
_NO_BREAKTHROUGH = 999999


def _index_signal_dates(
    buy_signal_index: Dict[Any, List[str]],
    ts_code: str,
    signal_dates: List[Any],
) -> None:
    """Append ``ts_code`` to the index entry of every date in ``signal_dates``."""
    for signal_date in signal_dates:
        buy_signal_index.setdefault(signal_date, []).append(ts_code)


def _bars_since_last_breakthrough(
    index_labels: List[Any],
    breakthrough_labels: List[Any],
) -> List[Any]:
    """Return, for each index label, its distance to the latest breakthrough label.

    "Latest" means the largest breakthrough label that is <= the row's label.
    Rows without such a label get ``_NO_BREAKTHROUGH``. A binary search replaces
    the former per-row boolean filtering, which also mis-handled a breakthrough
    whose label is 0 (``Series.any()`` is False for a lone 0).
    """
    from bisect import bisect_right

    ordered = sorted(breakthrough_labels)
    distances = []
    for label in index_labels:
        n_prior = bisect_right(ordered, label)  # breakthroughs with label <= this row
        distances.append(label - ordered[n_prior - 1] if n_prior else _NO_BREAKTHROUGH)
    return distances


def _precompute_ma_cross_signals(
    data_dict: Dict[str, pd.DataFrame],
    params: Dict[str, Any],
    buy_signal_index: Dict[Any, List[str]],
) -> None:
    """Compute golden-cross buy signals and record them in ``buy_signal_index``.

    Signal columns are written onto the DataFrames of ``data_dict`` in place.
    NOTE(review): the strategy may read these columns during the back-test;
    confirm before turning them into temporaries.
    """
    ma_short = params.get("ma_short")
    ma_long = params.get("ma_long")
    if ma_short is None or ma_long is None:
        return

    # The moving-average columns are expected to be precomputed by the caller.
    ma_short_col = f"ma_{ma_short}"
    ma_long_col = f"ma_{ma_long}"

    for ts_code, df in data_dict.items():
        if df.empty:
            continue
        if ma_short_col not in df.columns or ma_long_col not in df.columns:
            continue

        df["ma_short_prev"] = df[ma_short_col].shift(1)
        df["ma_long_prev"] = df[ma_long_col].shift(1)

        # Golden cross: short MA is above the long MA today and was not yesterday.
        df["golden_cross"] = (
            (df[ma_short_col] > df[ma_long_col])
            & (df["ma_short_prev"] <= df["ma_long_prev"])
        )
        # Death cross (sell signal): the mirror condition.
        df["death_cross"] = (
            (df[ma_short_col] < df[ma_long_col])
            & (df["ma_short_prev"] >= df["ma_long_prev"])
        )

        # Volume surge: volume grew by more than 20% versus the previous bar.
        # The change column is reused when it is already present.
        if "vol_pct_change" not in df.columns:
            df["vol_pct_change"] = df["vol"].pct_change()
        df["volume_surge"] = df["vol_pct_change"] > 0.2
        df["buy_signal"] = df["golden_cross"] & df["volume_surge"]

        buy_dates = df.loc[df["buy_signal"], "trade_date"].tolist()
        _index_signal_dates(buy_signal_index, ts_code, buy_dates)


def _precompute_ocz_signals(
    data_dict: Dict[str, pd.DataFrame],
    params: Dict[str, Any],
    buy_signal_index: Dict[Any, List[str]],
) -> None:
    """Compute breakthrough / pull-back signals and record them in ``buy_signal_index``.

    Indicator and signal columns are written onto the DataFrames of
    ``data_dict`` in place.
    NOTE(review): the bar distance is computed on index labels, so it assumes
    an integer index - confirm against the data loader.
    """
    lookback = params.get("N", 30)
    min_body_pct = params.get("B", 60.0)
    volume_mult = params.get("V1", 1.5)
    tolerance_pct = params.get("TOL", 1.5)
    min_return_pct = params.get("R", 4.0)
    volatility_min = params.get("volatility_min", 2.5)
    volatility_max = params.get("volatility_max", 8.0)

    for ts_code, df in data_dict.items():
        if df.empty or len(df) < lookback + 2:
            continue

        # Parameter-independent indicators: computed only when the column is
        # missing, so precomputed values are reused.
        if "body" not in df.columns:
            df["body"] = (df["close"] - df["open"]).abs()
        if "range" not in df.columns:
            df["range"] = df["high"] - df["low"]
        if "body_pct" not in df.columns:
            df["body_pct"] = df["body"] / df["range"] * 100
        if "return_pct" not in df.columns:
            df["return_pct"] = (df["close"] / df["close"].shift(1) - 1) * 100
        if "volatility" not in df.columns:
            # Fixed 30-bar window, deliberately independent of N.
            df["volatility"] = (
                (df["high"] - df["low"]).rolling(window=30, min_periods=30).mean()
                / df["close"] * 100
            )

        # Indicators that depend on N are recomputed for every parameter set.
        df["resistance"] = df["high"].rolling(window=lookback, min_periods=lookback).max()
        df["vol_ma_n"] = df["vol"].rolling(window=lookback, min_periods=lookback).mean()

        # Breakthrough bar (depends on B, R, V1 and the volatility band).
        df["breakthrough"] = (
            (df["close"] > df["resistance"].shift(1))
            & (df["body_pct"] > min_body_pct)
            & (df["return_pct"] >= min_return_pct)
            & (df["vol"] > df["vol_ma_n"] * volume_mult)
            & (df["volatility"] >= volatility_min)
            & (df["volatility"] <= volatility_max)
        )

        breakthrough_labels = df.index[df["breakthrough"]].tolist()
        if not breakthrough_labels:
            df["bars_since_breakthrough"] = _NO_BREAKTHROUGH
        else:
            df["bars_since_breakthrough"] = _bars_since_last_breakthrough(
                df.index.tolist(), breakthrough_labels
            )

        # Pull-back bar (depends on TOL): the bar right after a breakthrough
        # whose low is within the tolerance band around the resistance of two
        # bars ago, that closes above it, on lower volume than the
        # breakthrough bar.
        df["resistance_2days_ago"] = df["resistance"].shift(2)
        df["vol_breakthrough"] = df["vol"].shift(1)
        df["pullback_signal"] = (
            (df["bars_since_breakthrough"] == 1)
            & (df["low"] >= df["resistance_2days_ago"] * (1 - tolerance_pct / 100))
            & (df["low"] <= df["resistance_2days_ago"] * (1 + tolerance_pct / 100))
            & (df["close"] > df["resistance_2days_ago"])
            & (df["vol"] < df["vol_breakthrough"])
        )

        signal_dates = df.loc[df["pullback_signal"], "trade_date"].tolist()
        _index_signal_dates(buy_signal_index, ts_code, signal_dates)


def _run_single_backtest(args: tuple) -> Dict[str, Any]:
    """Run one back-test for one parameter combination (process-pool task).

    Args:
        args: Tuple of (strategy_class, data_dict, calendar, params,
            initial_capital, idx, date_index_dict, stop_loss, take_profit).

    Returns:
        Dict with "idx", "params" and the performance fields. If anything in
        the back-test raises, the performance fields are None and an "error"
        key carries the exception text.
    """
    (
        strategy_class,
        data_dict,
        calendar,
        params,
        initial_capital,
        idx,
        date_index_dict,
        stop_loss,
        take_profit,
    ) = args

    # Boundary of the worker task: one failing combination must not abort the
    # whole scan, so every error is logged and reported in the result instead.
    try:
        # Pre-compute the {trade_date: [ts_code, ...]} buy-signal index for the
        # strategies that have one. Matching ignores case so that both the
        # "MaCrossStrategy" and "MACrossStrategy" spellings are recognised.
        buy_signal_index: Dict[Any, List[str]] = {}
        strategy_key = strategy_class.__name__.lower()
        if strategy_key == "macrossstrategy":
            _precompute_ma_cross_signals(data_dict, params, buy_signal_index)
        elif strategy_key == "oczstrategy":
            _precompute_ocz_signals(data_dict, params, buy_signal_index)

        strategy = strategy_class(
            initial_cash=initial_capital,
            stop_loss=stop_loss,
            take_profit=take_profit,
            position_sizer=None,  # position sizer disabled; the strategy's own logic is used
            date_index_dict=date_index_dict,
            buy_signal_index=buy_signal_index,
            **params
        )

        equity_df = strategy.run_backtest(data_dict, calendar)

        perf = calc_performance(
            equity_df=equity_df,
            trade_count=strategy.trade_count,
            trade_history=strategy.trade_history,
        )

        # Release the large objects before the worker picks up its next task.
        del strategy
        del equity_df
        gc.collect()

        return {
            "idx": idx,
            "params": params,
            "total_return": perf["cum_return"],  # calc_performance calls it cum_return
            "annual_return": perf["ann_return"],  # calc_performance calls it ann_return
            "sharpe": perf["sharpe"],
            "max_drawdown": perf["max_drawdown"],
            "avg_capital_utilization": perf["avg_capital_utilization"],
            "total_trades": perf["total_trades"],
            "avg_trades_per_year": perf["avg_trades_per_year"],
            "win_rate": perf.get("win_rate", 0.0),
            "profit_loss_ratio": perf.get("profit_loss_ratio", 0.0),
        }

    except Exception as e:
        logger.error(f"参数 {params} 回测失败: {e}")
        # Same keys as the success result so consumers see one schema.
        return {
            "idx": idx,
            "params": params,
            "total_return": None,
            "annual_return": None,
            "sharpe": None,
            "max_drawdown": None,
            "avg_capital_utilization": None,
            "total_trades": None,
            "avg_trades_per_year": None,
            "win_rate": None,
            "profit_loss_ratio": None,
            "error": str(e),
        }
|
||
|
||
|
||
# Performance fields of a `_run_single_backtest` result that can be ranked on.
_RANKABLE_METRICS = (
    "total_return",
    "annual_return",
    "sharpe",
    "max_drawdown",
    "avg_capital_utilization",
    "total_trades",
    "avg_trades_per_year",
    "win_rate",
    "profit_loss_ratio",
)


def grid_search(
    strategy_class: Type[BaseStrategy],
    data_dict: Dict[str, pd.DataFrame],
    calendar: List[str],
    param_space: Dict[str, Any],
    initial_capital: float = 1_000_000,
    n_jobs: int = 4,
    metric: str = "sharpe",
    top_n: int = 20,
    output_dir: Path = None,
    constraint_func: callable = None,
    date_index_dict: Dict[str, Dict[str, int]] = None,
    stop_loss: object = None,
    take_profit: object = None,
) -> List[Dict[str, Any]]:
    """Grid-search strategy parameters with a pool of worker processes.

    Args:
        strategy_class: Strategy class (not an instance).
        data_dict: Stock data, {ts_code: DataFrame}.
        calendar: List of trading days.
        param_space: Parameter space definition.
        initial_capital: Starting capital.
        n_jobs: Number of worker processes. None or a value below 1 means
            "use every CPU"; it is never larger than the number of tasks.
        metric: Field to rank by, one of "sharpe", "total_return",
            "max_drawdown", "annual_return", "avg_capital_utilization",
            "total_trades", "avg_trades_per_year", "win_rate",
            "profit_loss_ratio".
        top_n: Number of best results to keep.
        output_dir: Directory for the result CSV; defaults to
            RESULTS_DIR / "optimization".
        constraint_func: Optional filter that drops invalid combinations.
        date_index_dict: {ts_code: {date: idx}} lookup; built here when None.
        stop_loss: Stop-loss manager handed to the strategy.
        take_profit: Take-profit manager handed to the strategy.

    Returns:
        The top ``top_n`` results (parameters plus performance) ordered by
        ``metric``, or an empty list when the parameter space or metric is
        invalid, there is nothing to scan, or every back-test failed.
    """
    if not validate_param_space(param_space):
        logger.error("参数空间不合法")
        return []

    # An unknown metric would only surface after the whole scan, as
    # "every combination failed"; reject it before doing any work.
    if metric not in _RANKABLE_METRICS:
        supported = ", ".join(_RANKABLE_METRICS)
        logger.error(f"不支持的排序指标: {metric},可选: {supported}")
        return []

    from optimization.param_space import apply_param_constraints

    combinations = generate_param_combinations(param_space)
    if constraint_func is not None:
        combinations = apply_param_constraints(combinations, constraint_func)

    total_comb = len(combinations)
    if total_comb == 0:
        logger.warning("参数组合数为 0,无需优化")
        return []

    # Normalise the worker count: fall back to all CPUs for None / < 1 and do
    # not start more processes than there are tasks.
    n_workers = mp.cpu_count() if n_jobs is None or n_jobs < 1 else n_jobs
    n_workers = min(n_workers, total_comb)

    logger.info(f"开始参数扫描,共 {total_comb} 组参数,并行 {n_workers} 核")
    logger.info(f"排序指标: {metric}, 保留前 {top_n} 组结果")

    # Build the {ts_code: {trade_date: row position}} lookup when not supplied.
    if date_index_dict is None:
        logger.info("预计算日期索引字典...")
        date_index_dict = {}
        for ts_code, df in data_dict.items():
            date_index_dict[ts_code] = {
                trade_date: row for row, trade_date in enumerate(df["trade_date"])
            }
        logger.info("日期索引预计算完成")

    # NOTE(review): every task tuple carries data_dict, so it is serialised
    # for each task sent to the pool; a pool initializer holding the data
    # would avoid that if it becomes a bottleneck.
    tasks = [
        (strategy_class, data_dict, calendar, params, initial_capital,
         task_idx, date_index_dict, stop_loss, take_profit)
        for task_idx, params in enumerate(combinations)
    ]

    start_time = time.time()
    with mp.Pool(processes=n_workers) as pool:
        # imap_unordered yields results as they finish, which drives the bar.
        results = list(
            tqdm(
                pool.imap_unordered(_run_single_backtest, tasks),
                total=total_comb,
                desc="参数优化进度",
                unit="组",
            )
        )

    elapsed = time.time() - start_time
    logger.info(f"参数扫描完成,耗时 {elapsed:.2f} 秒")

    # Failed back-tests report None for every performance field.
    valid_results = [r for r in results if r.get(metric) is not None]
    failed_count = len(results) - len(valid_results)
    if failed_count > 0:
        logger.warning(f"有 {failed_count} 组参数回测失败")

    if not valid_results:
        logger.error("所有参数组合都失败,无法生成结果")
        return []

    # Larger is better for every metric except max drawdown.
    # NOTE(review): ascending order assumes max_drawdown is reported as a
    # non-negative magnitude - confirm against calc_performance.
    reverse = metric != "max_drawdown"

    # NaN compares False with everything and would make the sort order
    # undefined, so NaN results are ranked after all comparable ones.
    ranked = [r for r in valid_results if not pd.isna(r[metric])]
    unranked = [r for r in valid_results if pd.isna(r[metric])]
    sorted_results = sorted(ranked, key=lambda r: r[metric], reverse=reverse) + unranked

    top_results = sorted_results[:top_n]

    # Write the kept results to a timestamped CSV.
    if output_dir is None:
        from config.settings import RESULTS_DIR
        output_dir = RESULTS_DIR / "optimization"
    else:
        output_dir = Path(output_dir)  # also accept a plain string path

    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    csv_path = output_dir / f"grid_search_{timestamp}.csv"

    _save_results_to_csv(top_results, csv_path)
    logger.info(f"参数优化结果已保存到: {csv_path}")

    # Log the five best results.
    logger.info("=" * 80)
    logger.info(f"参数优化 Top {min(5, len(top_results))} 结果(按 {metric} 排序)")
    logger.info("=" * 80)

    for rank, result in enumerate(top_results[:5], 1):
        logger.info(f"第 {rank} 名:")
        logger.info(f"  参数: {result['params']}")
        logger.info(f"  累计收益: {result['total_return']*100:+.2f}%")
        logger.info(f"  年化收益: {result['annual_return']*100:+.2f}%")
        logger.info(f"  夏普比率: {result['sharpe']:.4f}")
        logger.info(f"  最大回撤: {result['max_drawdown']*100:.2f}%")
        logger.info(f"  总交易次数: {result['total_trades']}")
        logger.info("-" * 80)

    return top_results
|
||
|
||
|
||
def _save_results_to_csv(results: List[Dict[str, Any]], csv_path: Path) -> None:
    """Write optimisation results to a CSV file, one row per result.

    Each entry of a result's "params" dict becomes its own ``param_<name>``
    column, followed by the performance columns. "win_rate" and
    "profit_loss_ratio" are appended as trailing columns when at least one
    result carries them, so results without them produce the previous layout.

    Args:
        results: Result dicts as produced by the back-test workers.
        csv_path: Destination CSV path.
    """
    if not results:
        logger.warning("结果为空,不保存 CSV")
        return

    metric_columns = [
        "total_return",
        "annual_return",
        "sharpe",
        "max_drawdown",
        "avg_capital_utilization",
        "total_trades",
        "avg_trades_per_year",
    ]
    # Newer performance fields: only emitted when present in the results.
    metric_columns += [
        column
        for column in ("win_rate", "profit_loss_ratio")
        if any(column in result for result in results)
    ]

    rows = []
    for result in results:
        # Flatten the params dict; a missing or None "params" adds no columns.
        row = {
            f"param_{param_name}": param_value
            for param_name, param_value in (result.get("params") or {}).items()
        }
        for column in metric_columns:
            row[column] = result.get(column)
        rows.append(row)

    # utf-8-sig writes a BOM at the start of the file.
    pd.DataFrame(rows).to_csv(csv_path, index=False, encoding="utf-8-sig")
    logger.info(f"结果已保存,共 {len(rows)} 行")
|