新建回测系统,并提交
This commit is contained in:
416
optimization/grid_search.py
Normal file
416
optimization/grid_search.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""多进程网格搜索参数优化。
|
||||
|
||||
使用 multiprocessing.Pool 并行回测不同参数组合,
|
||||
提高参数优化效率。
|
||||
|
||||
示例:
|
||||
from optimization.grid_search import grid_search
|
||||
from strategies.ma_cross import MACrossStrategy
|
||||
|
||||
param_space = {
|
||||
"ma_short": range(3, 21, 2),
|
||||
"ma_long": range(20, 61, 5),
|
||||
"hold_days": range(3, 11),
|
||||
"stop_loss_pct": [0.03, 0.05, 0.08],
|
||||
"take_profit_pct": [0.10, 0.15, 0.20],
|
||||
}
|
||||
|
||||
results = grid_search(
|
||||
strategy_class=MACrossStrategy,
|
||||
data_dict=data_dict,
|
||||
calendar=calendar,
|
||||
param_space=param_space,
|
||||
initial_capital=1_000_000,
|
||||
n_jobs=4,
|
||||
metric="sharpe"
|
||||
)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import gc
|
||||
import importlib
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from optimization.param_space import estimate_combinations_count, generate_param_combinations, validate_param_space
|
||||
from strategies.base_strategy import BaseStrategy
|
||||
from utils.logger import setup_logger
|
||||
from utils.performance import calc_performance
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def _run_single_backtest(args: tuple) -> Dict[str, Any]:
|
||||
"""单次回测任务(供进程池调用)。
|
||||
|
||||
参数:
|
||||
args: (strategy_class, data_dict, calendar, params, initial_capital, idx, date_index_dict, stop_loss, take_profit)
|
||||
|
||||
返回:
|
||||
Dict: 包含参数和绩效的字典
|
||||
"""
|
||||
strategy_class, data_dict, calendar, params, initial_capital, idx, date_index_dict, stop_loss, take_profit = args
|
||||
|
||||
try:
|
||||
# 根据策略类型预计算买入信号索引
|
||||
buy_signal_index = {}
|
||||
strategy_name = strategy_class.__name__
|
||||
|
||||
if strategy_name == "MaCrossStrategy":
|
||||
# 均线交叉策略:预计算金叉信号
|
||||
ma_short = params.get("ma_short")
|
||||
ma_long = params.get("ma_long")
|
||||
|
||||
if ma_short is not None and ma_long is not None:
|
||||
for ts_code, df in data_dict.items():
|
||||
if df.empty:
|
||||
continue
|
||||
|
||||
# 使用预计算的均线列
|
||||
ma_short_col = f"ma_{ma_short}"
|
||||
ma_long_col = f"ma_{ma_long}"
|
||||
|
||||
if ma_short_col not in df.columns or ma_long_col not in df.columns:
|
||||
continue
|
||||
|
||||
# 计算金叉和放量
|
||||
df["ma_short_prev"] = df[ma_short_col].shift(1)
|
||||
df["ma_long_prev"] = df[ma_long_col].shift(1)
|
||||
df["golden_cross"] = (
|
||||
(df[ma_short_col] > df[ma_long_col]) &
|
||||
(df["ma_short_prev"] <= df["ma_long_prev"])
|
||||
)
|
||||
|
||||
# 计算死叉(卖出信号)
|
||||
df["death_cross"] = (
|
||||
(df[ma_short_col] < df[ma_long_col]) &
|
||||
(df["ma_short_prev"] >= df["ma_long_prev"])
|
||||
)
|
||||
|
||||
# 放量信号:成交量增幅>20%
|
||||
# 注意:vol_pct_change 应该已经在均线预计算时计算过了
|
||||
if "vol_pct_change" not in df.columns:
|
||||
df["vol_pct_change"] = df["vol"].pct_change()
|
||||
df["volume_surge"] = df["vol_pct_change"] > 0.2
|
||||
df["buy_signal"] = df["golden_cross"] & df["volume_surge"]
|
||||
|
||||
# 建立买入信号索引
|
||||
buy_dates = df[df["buy_signal"] == True]["trade_date"].tolist()
|
||||
for date in buy_dates:
|
||||
if date not in buy_signal_index:
|
||||
buy_signal_index[date] = []
|
||||
buy_signal_index[date].append(ts_code)
|
||||
|
||||
elif strategy_name == "OczStrategy":
|
||||
# OCZ策略:预计算回踩信号
|
||||
# 注意:只计算依赖当前参数的指标,其他基础指标可复用
|
||||
N = params.get("N", 30)
|
||||
B = params.get("B", 60.0)
|
||||
V1 = params.get("V1", 1.5)
|
||||
TOL = params.get("TOL", 1.5)
|
||||
R = params.get("R", 4.0)
|
||||
volatility_min = params.get("volatility_min", 2.5)
|
||||
volatility_max = params.get("volatility_max", 8.0)
|
||||
|
||||
for ts_code, df in data_dict.items():
|
||||
if df.empty or len(df) < N + 2:
|
||||
continue
|
||||
|
||||
# 基础指标:只在第一次计算,后续复用
|
||||
# body, range, body_pct, return_pct, volatility 不依赖于参数,可以复用
|
||||
if "body" not in df.columns:
|
||||
df["body"] = (df["close"] - df["open"]).abs()
|
||||
if "range" not in df.columns:
|
||||
df["range"] = df["high"] - df["low"]
|
||||
if "body_pct" not in df.columns:
|
||||
df["body_pct"] = df["body"] / df["range"] * 100
|
||||
if "return_pct" not in df.columns:
|
||||
df["return_pct"] = (df["close"] / df["close"].shift(1) - 1) * 100
|
||||
if "volatility" not in df.columns:
|
||||
df["volatility"] = (df["high"] - df["low"]).rolling(window=30, min_periods=30).mean() / df["close"] * 100
|
||||
|
||||
# 依赖参数N的指标:每次重新计算
|
||||
df["resistance"] = df["high"].rolling(window=N, min_periods=N).max()
|
||||
df["vol_ma_n"] = df["vol"].rolling(window=N, min_periods=N).mean()
|
||||
|
||||
# 识别突破信号(依赖B、R、V1参数)
|
||||
breakthrough = (
|
||||
(df["close"] > df["resistance"].shift(1)) &
|
||||
(df["body_pct"] > B) &
|
||||
(df["return_pct"] >= R) &
|
||||
(df["vol"] > df["vol_ma_n"] * V1) &
|
||||
(df["volatility"] >= volatility_min) &
|
||||
(df["volatility"] <= volatility_max)
|
||||
)
|
||||
df["breakthrough"] = breakthrough
|
||||
|
||||
# 计算距离突破的天数(向量化优化)
|
||||
breakthrough_indices = df.index[df["breakthrough"]].tolist()
|
||||
if not breakthrough_indices:
|
||||
df["bars_since_breakthrough"] = 999999
|
||||
else:
|
||||
breakthrough_indices_series = pd.Series(breakthrough_indices)
|
||||
df["bars_since_breakthrough"] = df.index.to_series().apply(
|
||||
lambda idx: idx - breakthrough_indices_series[breakthrough_indices_series <= idx].max()
|
||||
if breakthrough_indices_series[breakthrough_indices_series <= idx].any()
|
||||
else 999999
|
||||
).values
|
||||
|
||||
# 识别回踩信号(依赖TOL参数)
|
||||
df["resistance_2days_ago"] = df["resistance"].shift(2)
|
||||
df["vol_breakthrough"] = df["vol"].shift(1)
|
||||
|
||||
pullback_signal = (
|
||||
(df["bars_since_breakthrough"] == 1) &
|
||||
(df["low"] >= df["resistance_2days_ago"] * (1 - TOL / 100)) &
|
||||
(df["low"] <= df["resistance_2days_ago"] * (1 + TOL / 100)) &
|
||||
(df["close"] > df["resistance_2days_ago"]) &
|
||||
(df["vol"] < df["vol_breakthrough"])
|
||||
)
|
||||
df["pullback_signal"] = pullback_signal
|
||||
|
||||
# 建立买入信号索引
|
||||
signal_dates = df[df["pullback_signal"] == True]["trade_date"].tolist()
|
||||
for date in signal_dates:
|
||||
if date not in buy_signal_index:
|
||||
buy_signal_index[date] = []
|
||||
buy_signal_index[date].append(ts_code)
|
||||
|
||||
# 实例化策略(传入日期索引、买入信号索引、风险管理模块)
|
||||
strategy = strategy_class(
|
||||
initial_cash=initial_capital,
|
||||
stop_loss=stop_loss,
|
||||
take_profit=take_profit,
|
||||
position_sizer=None, # 禁用仓位管理器,使用策略自带逻辑
|
||||
date_index_dict=date_index_dict,
|
||||
buy_signal_index=buy_signal_index,
|
||||
**params
|
||||
)
|
||||
|
||||
# 运行回测
|
||||
equity_df = strategy.run_backtest(data_dict, calendar)
|
||||
|
||||
# 计算绩效
|
||||
perf = calc_performance(
|
||||
equity_df=equity_df,
|
||||
trade_count=strategy.trade_count,
|
||||
trade_history=strategy.trade_history,
|
||||
)
|
||||
|
||||
# 清理内存
|
||||
del strategy
|
||||
del equity_df
|
||||
gc.collect()
|
||||
|
||||
# 返回参数 + 绩效
|
||||
result = {
|
||||
"idx": idx,
|
||||
"params": params,
|
||||
"total_return": perf["cum_return"], # 注意:calc_performance 返回 cum_return
|
||||
"annual_return": perf["ann_return"], # 注意:calc_performance 返回 ann_return
|
||||
"sharpe": perf["sharpe"],
|
||||
"max_drawdown": perf["max_drawdown"],
|
||||
"avg_capital_utilization": perf["avg_capital_utilization"],
|
||||
"total_trades": perf["total_trades"],
|
||||
"avg_trades_per_year": perf["avg_trades_per_year"],
|
||||
"win_rate": perf.get("win_rate", 0.0),
|
||||
"profit_loss_ratio": perf.get("profit_loss_ratio", 0.0),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"参数 {params} 回测失败: {e}")
|
||||
return {
|
||||
"idx": idx,
|
||||
"params": params,
|
||||
"total_return": None,
|
||||
"annual_return": None,
|
||||
"sharpe": None,
|
||||
"max_drawdown": None,
|
||||
"avg_capital_utilization": None,
|
||||
"total_trades": None,
|
||||
"avg_trades_per_year": None,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
def grid_search(
|
||||
strategy_class: Type[BaseStrategy],
|
||||
data_dict: Dict[str, pd.DataFrame],
|
||||
calendar: List[str],
|
||||
param_space: Dict[str, Any],
|
||||
initial_capital: float = 1_000_000,
|
||||
n_jobs: int = 4,
|
||||
metric: str = "sharpe",
|
||||
top_n: int = 20,
|
||||
output_dir: Path = None,
|
||||
constraint_func: callable = None,
|
||||
date_index_dict: Dict[str, Dict[str, int]] = None,
|
||||
stop_loss: object = None,
|
||||
take_profit: object = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""网格搜索参数优化(多进程并行)。
|
||||
|
||||
参数:
|
||||
strategy_class: 策略类(未实例化)
|
||||
data_dict: 股票数据字典 {ts_code: DataFrame}
|
||||
calendar: 交易日列表
|
||||
param_space: 参数空间定义
|
||||
initial_capital: 初始资金
|
||||
n_jobs: 并行进程数
|
||||
metric: 排序指标,可选 "sharpe", "total_return", "max_drawdown", "annual_return"
|
||||
top_n: 保存前 N 个结果
|
||||
output_dir: 结果输出目录
|
||||
constraint_func: 参数约束函数,过滤无效的参数组合
|
||||
date_index_dict: 日期索引字典 {ts_code: {date: idx}},用于性能优化
|
||||
stop_loss: 止损管理器
|
||||
take_profit: 止盈管理器
|
||||
|
||||
返回:
|
||||
List[Dict]: 按 metric 排序的参数组合及绩效列表
|
||||
"""
|
||||
# 验证参数空间
|
||||
if not validate_param_space(param_space):
|
||||
logger.error("参数空间不合法")
|
||||
return []
|
||||
|
||||
# 生成参数组合
|
||||
from optimization.param_space import apply_param_constraints
|
||||
combinations = generate_param_combinations(param_space)
|
||||
|
||||
# 应用约束条件
|
||||
if constraint_func is not None:
|
||||
combinations = apply_param_constraints(combinations, constraint_func)
|
||||
|
||||
total_comb = len(combinations)
|
||||
|
||||
if total_comb == 0:
|
||||
logger.warning("参数组合数为 0,无需优化")
|
||||
return []
|
||||
|
||||
logger.info(f"开始参数扫描,共 {total_comb} 组参数,并行 {n_jobs} 核")
|
||||
logger.info(f"排序指标: {metric}, 保留前 {top_n} 组结果")
|
||||
|
||||
# 如果没有传入date_index_dict,预计算一个
|
||||
if date_index_dict is None:
|
||||
logger.info("预计算日期索引字典...")
|
||||
date_index_dict = {}
|
||||
for ts_code, df in data_dict.items():
|
||||
date_to_idx = {}
|
||||
for idx, trade_date in enumerate(df["trade_date"]):
|
||||
date_to_idx[trade_date] = idx
|
||||
date_index_dict[ts_code] = date_to_idx
|
||||
logger.info("日期索引预计算完成")
|
||||
|
||||
# 准备任务参数
|
||||
tasks = [
|
||||
(strategy_class, data_dict, calendar, params, initial_capital, idx, date_index_dict, stop_loss, take_profit)
|
||||
for idx, params in enumerate(combinations)
|
||||
]
|
||||
|
||||
# 多进程并行执行
|
||||
start_time = time.time()
|
||||
results = []
|
||||
|
||||
with mp.Pool(processes=n_jobs) as pool:
|
||||
# 使用 imap_unordered 配合 tqdm 显示进度
|
||||
for result in tqdm(
|
||||
pool.imap_unordered(_run_single_backtest, tasks),
|
||||
total=total_comb,
|
||||
desc="参数优化进度",
|
||||
unit="组"
|
||||
):
|
||||
results.append(result)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"参数扫描完成,耗时 {elapsed:.2f} 秒")
|
||||
|
||||
# 过滤失败的结果
|
||||
valid_results = [r for r in results if r.get(metric) is not None]
|
||||
failed_count = len(results) - len(valid_results)
|
||||
|
||||
if failed_count > 0:
|
||||
logger.warning(f"有 {failed_count} 组参数回测失败")
|
||||
|
||||
if not valid_results:
|
||||
logger.error("所有参数组合都失败,无法生成结果")
|
||||
return []
|
||||
|
||||
# 按指定指标排序
|
||||
reverse = True # 默认降序(越大越好)
|
||||
if metric == "max_drawdown":
|
||||
reverse = False # 最大回撤越小越好
|
||||
|
||||
sorted_results = sorted(valid_results, key=lambda x: x[metric], reverse=reverse)
|
||||
|
||||
# 取前 top_n
|
||||
top_results = sorted_results[:top_n]
|
||||
|
||||
# 输出结果到 CSV
|
||||
if output_dir is None:
|
||||
from config.settings import RESULTS_DIR
|
||||
output_dir = RESULTS_DIR / "optimization"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
csv_path = output_dir / f"grid_search_{timestamp}.csv"
|
||||
|
||||
_save_results_to_csv(top_results, csv_path)
|
||||
logger.info(f"参数优化结果已保存到: {csv_path}")
|
||||
|
||||
# 打印前 5 组结果
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"参数优化 Top {min(5, len(top_results))} 结果(按 {metric} 排序)")
|
||||
logger.info("=" * 80)
|
||||
|
||||
for i, result in enumerate(top_results[:5], 1):
|
||||
logger.info(f"第 {i} 名:")
|
||||
logger.info(f" 参数: {result['params']}")
|
||||
logger.info(f" 累计收益: {result['total_return']*100:+.2f}%")
|
||||
logger.info(f" 年化收益: {result['annual_return']*100:+.2f}%")
|
||||
logger.info(f" 夏普比率: {result['sharpe']:.4f}")
|
||||
logger.info(f" 最大回撤: {result['max_drawdown']*100:.2f}%")
|
||||
logger.info(f" 总交易次数: {result['total_trades']}")
|
||||
logger.info("-" * 80)
|
||||
|
||||
return top_results
|
||||
|
||||
|
||||
def _save_results_to_csv(results: List[Dict[str, Any]], csv_path: Path) -> None:
|
||||
"""将优化结果保存为 CSV 文件。
|
||||
|
||||
参数:
|
||||
results: 结果列表
|
||||
csv_path: CSV 文件路径
|
||||
"""
|
||||
if not results:
|
||||
logger.warning("结果为空,不保存 CSV")
|
||||
return
|
||||
|
||||
# 展开 params 字典为独立列
|
||||
rows = []
|
||||
for result in results:
|
||||
row = {}
|
||||
# 添加参数列
|
||||
if "params" in result:
|
||||
for param_name, param_value in result["params"].items():
|
||||
row[f"param_{param_name}"] = param_value
|
||||
|
||||
# 添加绩效列
|
||||
for key in ["total_return", "annual_return", "sharpe", "max_drawdown",
|
||||
"avg_capital_utilization", "total_trades", "avg_trades_per_year"]:
|
||||
row[key] = result.get(key)
|
||||
|
||||
rows.append(row)
|
||||
|
||||
# 写入 CSV
|
||||
df = pd.DataFrame(rows)
|
||||
df.to_csv(csv_path, index=False, encoding="utf-8-sig")
|
||||
logger.info(f"结果已保存,共 {len(rows)} 行")
|
||||
Reference in New Issue
Block a user