Files
backtrader/parameter_optimizer.py
2026-01-17 21:21:30 +08:00

450 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
参数优化模块
用于优化交易策略的参数,支持多种优化算法
"""
import pandas as pd
import numpy as np
import itertools
import random
import logging
import time
import sys
import os
from typing import Dict, List, Tuple, Callable, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
from backtest.backtest_engine import BacktestEngine
from strategy.yiyang_strategy import YiYangStrategy
from utils.logger import setup_logger
# 设置日志
logger = setup_logger('parameter_optimizer', 'logs/parameter_optimizer.log')
# 中文指标映射
METRIC_NAME_MAP = {
'total_return': '总收益率',
'annual_return': '年化收益率',
'sharpe_ratio': '夏普比率',
'max_drawdown': '最大回撤',
'win_rate': '胜率',
'avg_win_pct': '平均盈利%',
'avg_loss_pct': '平均亏损%',
'profit_factor': '盈亏比',
'total_trades': '交易次数',
'avg_holding_period': '平均持有周期',
'min_pct_change': '最小涨幅%',
'min_volume_ratio': '最小量比',
'max_ma_gap_pct': '均线最大间距%',
'min_entity_ratio': '实体最小比例',
'confirm_days': '确认天数'
}
class ParameterOptimizer:
"""
参数优化器类
用于优化交易策略的参数,支持网格搜索和随机搜索算法
"""
def __init__(self, data_fetcher, strategy_class=YiYangStrategy, max_workers=4):
"""
初始化参数优化器
Args:
data_fetcher: 数据获取器实例
strategy_class: 策略类默认为YiYangStrategy
max_workers: 多线程池大小
"""
self.data_fetcher = data_fetcher
self.strategy_class = strategy_class
self.max_workers = max_workers
self.results = [] # 保存优化结果
def define_parameter_space(self, parameter_ranges: Dict[str, Union[List, range]]) -> Dict:
"""
定义参数搜索空间
Args:
parameter_ranges: 参数范围字典键为参数名值为参数范围列表或range对象
Returns:
Dict: 参数空间定义
"""
return parameter_ranges
def grid_search(self, parameter_ranges: Dict, symbol: str, start_date: str, end_date: str,
initial_capital: float = 100000, metric: str = 'total_return') -> Tuple[Dict, float]:
"""
网格搜索优化(支持多线程)
Args:
parameter_ranges: 参数范围字典
symbol: 股票代码
start_date: 开始日期格式YYYYMMDD
end_date: 结束日期格式YYYYMMDD
initial_capital: 初始资金
metric: 优化指标
Returns:
Tuple[Dict, float]: 最佳参数组合和最佳指标值
"""
start_time = time.time()
metric_cn = METRIC_NAME_MAP.get(metric, metric)
logger.info(f"开始网格搜索优化,参数范围: {parameter_ranges}")
# 获取所有参数组合
param_names = list(parameter_ranges.keys())
param_values = [list(parameter_ranges[name]) for name in param_names]
param_combinations = list(itertools.product(*param_values))
total_combinations = len(param_combinations)
logger.info(f"总共有 {total_combinations} 种参数组合,使用 {self.max_workers} 个线程")
# 【关键修复】使用多线程执行
best_params = None
best_metric = -float('inf') if metric != 'max_drawdown' else float('inf')
completed = 0
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 提交所有任务
future_to_params = {}
for params in param_combinations:
param_dict = dict(zip(param_names, params))
future = executor.submit(self._run_backtest, param_dict, symbol,
start_date, end_date, initial_capital)
future_to_params[future] = param_dict
# 处理完成的任务
for future in as_completed(future_to_params):
completed += 1
param_dict = future_to_params[future]
# 显示进度
elapsed_time = time.time() - start_time
remaining_time = (elapsed_time / completed) * (total_combinations - completed)
sys.stdout.write(f"\r网格搜索进度: {completed}/{total_combinations} "
f"({completed/total_combinations*100:.1f}%) | "
f"已耗时: {elapsed_time:.2f}s | 预计剩余: {remaining_time:.2f}s")
sys.stdout.flush()
try:
result = future.result()
if result:
# 保存结果
numeric_result = {k: v for k, v in result.items() if not isinstance(v, pd.DataFrame)}
self.results.append({
'params': param_dict,
'result': numeric_result
})
# 更新最佳参数
current_metric = result[metric]
if (metric != 'max_drawdown' and current_metric > best_metric) or \
(metric == 'max_drawdown' and current_metric < best_metric):
best_metric = current_metric
best_params = param_dict
logger.info(f"找到更优参数: {best_params}, {metric_cn}: {best_metric:.4f}")
except Exception as e:
logger.error(f"参数组合 {param_dict} 失败: {str(e)}")
# 显示总耗时统计
total_time = time.time() - start_time
avg_time = total_time / total_combinations
print(f"\n\n" + "=" * 60)
print("网格搜索优化完成!")
print(f"总耗时: {total_time:.2f}")
print(f"参数组合总数: {total_combinations}")
print(f"平均每组耗时: {avg_time:.2f}")
print(f"多线程数: {self.max_workers}")
print(f"加速比: {(avg_time * total_combinations / total_time):.2f}x")
print("=" * 60)
logger.info(f"网格搜索完成,最佳参数: {best_params}, 最佳{metric_cn}: {best_metric}")
return best_params, best_metric
def random_search(self, parameter_ranges: Dict, symbol: str, start_date: str, end_date: str,
initial_capital: float = 100000, n_iter: int = 50,
metric: str = 'total_return') -> Tuple[Dict, float]:
"""
随机搜索优化
Args:
parameter_ranges: 参数范围字典
symbol: 股票代码
start_date: 开始日期格式YYYYMMDD
end_date: 结束日期格式YYYYMMDD
initial_capital: 初始资金
n_iter: 迭代次数
metric: 优化指标,可选'total_return', 'annualized_return', 'sharpe_ratio', 'max_drawdown'
Returns:
Tuple[Dict, float]: 最佳参数组合和最佳指标值
"""
start_time = time.time()
logger.info(f"开始随机搜索优化,参数范围: {parameter_ranges}, 迭代次数: {n_iter}")
param_names = list(parameter_ranges.keys())
# 运行随机参数组合
best_params = None
best_metric = -float('inf') if metric != 'max_drawdown' else float('inf')
iteration_times = []
for i in range(n_iter):
# 显示进度和实时耗时
elapsed_time = time.time() - start_time
iteration_start = time.time()
remaining_time = (elapsed_time / (i + 1)) * (n_iter - (i + 1)) if i > 0 else 0
sys.stdout.write(f"\r随机搜索进度: {i+1}/{n_iter} ({(i+1)/n_iter*100:.1f}%) | "
f"已耗时: {elapsed_time:.2f}s | 预计剩余: {remaining_time:.2f}s")
sys.stdout.flush()
# 随机生成参数组合
param_dict = {}
for name in param_names:
values = list(parameter_ranges[name])
param_dict[name] = random.choice(values)
logger.info(f"正在测试第 {i+1}/{n_iter} 组参数: {param_dict}")
# 运行回测
result = self._run_backtest(param_dict, symbol, start_date, end_date, initial_capital)
iteration_time = time.time() - iteration_start
iteration_times.append(iteration_time)
if result:
# 只保留数值指标,排除复杂数据结构
numeric_result = {k: v for k, v in result.items() if not isinstance(v, pd.DataFrame)}
# 保存结果
self.results.append({
'params': param_dict,
'result': numeric_result,
'run_time': iteration_time
})
# 更新最佳参数
current_metric = result[metric]
if (metric != 'max_drawdown' and current_metric > best_metric) or \
(metric == 'max_drawdown' and current_metric < best_metric):
best_metric = current_metric
best_params = param_dict
logger.info(f"找到更优参数: {best_params}, {metric}: {best_metric}, 耗时: {iteration_time:.2f}s")
# 显示总耗时统计
total_time = time.time() - start_time
avg_iteration_time = np.mean(iteration_times) if iteration_times else 0
max_iteration_time = np.max(iteration_times) if iteration_times else 0
min_iteration_time = np.min(iteration_times) if iteration_times else 0
print(f"\n\n" + "=" * 60)
print("随机搜索优化完成!")
print(f"总耗时: {total_time:.2f}")
print(f"迭代次数: {n_iter}")
print(f"平均每次迭代耗时: {avg_iteration_time:.2f}")
print(f"最快/最慢迭代耗时: {min_iteration_time:.2f}s / {max_iteration_time:.2f}s")
print("=" * 60)
logger.info(f"随机搜索完成,最佳参数: {best_params}, 最佳{metric}: {best_metric}")
return best_params, best_metric
def _run_backtest(self, params: Dict, symbol: str, start_date: str, end_date: str,
initial_capital: float) -> Dict:
"""
运行单个参数组合的回测
Args:
params: 参数组合
symbol: 股票代码
start_date: 开始日期
end_date: 结束日期
initial_capital: 初始资金
Returns:
Dict: 回测结果
"""
try:
# 获取数据
data = self.data_fetcher.get_daily_data(ts_code=symbol, start_date=start_date, end_date=end_date)
if data.empty:
logger.warning(f"无法获取 {symbol} 的数据")
return None
# 创建策略实例
strategy_config = params.copy() # 将参数作为策略配置
strategy = self.strategy_class(strategy_config)
# 创建回测引擎配置
backtest_config = {
'initial_capital': initial_capital,
'commission_rate': 0.0003,
'stamp_tax_rate': 0.001,
'slippage_rate': 0.001,
'position_ratio': 0.2,
'stop_loss': 0.08,
'take_profit': 0.20
}
# 运行回测
engine = BacktestEngine(backtest_config)
data_dict = {symbol: data} # 回测引擎需要字典格式的数据
results = engine.run_backtest(data_dict, strategy)
return results
except Exception as e:
logger.error(f"回测失败,参数: {params},错误: {str(e)}")
return None
def get_results(self) -> List[Dict]:
"""
获取所有优化结果
Returns:
List[Dict]: 优化结果列表
"""
return self.results
def get_best_params(self, metric: str = 'total_return') -> Tuple[Dict, float]:
"""
获取最佳参数组合
Args:
metric: 优化指标
Returns:
Tuple[Dict, float]: 最佳参数组合和最佳指标值
"""
if not self.results:
return None, None
best_result = None
best_metric_value = -float('inf') if metric != 'max_drawdown' else float('inf')
for result in self.results:
current_metric = result['result'][metric]
if (metric != 'max_drawdown' and current_metric > best_metric_value) or \
(metric == 'max_drawdown' and current_metric < best_metric_value):
best_metric_value = current_metric
best_result = result['params']
return best_result, best_metric_value
def export_results(self, filename: str = 'optimization_results.csv', strategy_name: str = None) -> None:
"""
导出优化结果到CSV文件使用中文标题
Args:
filename: 文件名
strategy_name: 策略名称,用于创建子目录
"""
if not self.results:
logger.warning("没有结果可导出")
return
# 创建目录结构
if strategy_name:
optimization_dir = os.path.join('optimization_results', strategy_name)
if not os.path.exists(optimization_dir):
os.makedirs(optimization_dir)
filepath = os.path.join(optimization_dir, filename)
else:
filepath = filename
# 转换结果为DataFrame
data = []
for result in self.results:
row = result['params'].copy()
row.update(result['result'])
data.append(row)
df = pd.DataFrame(data)
# 【关键修复】将列名替换为中文
df.rename(columns=METRIC_NAME_MAP, inplace=True)
df.to_csv(filepath, index=False, encoding='utf-8-sig')
logger.info(f"优化结果已导出到 {filepath}")
def analyze_results(self) -> None:
"""
分析优化结果
"""
if not self.results:
logger.warning("没有结果可分析")
return
# 转换结果为DataFrame
data = []
for result in self.results:
row = result['params'].copy()
row.update(result['result'])
data.append(row)
df = pd.DataFrame(data)
# 打印基本统计信息
logger.info("优化结果统计:")
logger.info(f"总测试参数组合数: {len(df)}")
logger.info("\n主要指标统计:")
logger.info(df[['total_return', 'annual_return', 'sharpe_ratio', 'max_drawdown']].describe())
# 打印相关性分析
logger.info("\n参数与指标相关性:")
param_cols = [col for col in df.columns if col not in ['total_return', 'annual_return', 'sharpe_ratio',
'max_drawdown', 'win_rate', 'avg_win_pct',
'avg_loss_pct', 'profit_factor', 'total_trades',
'avg_holding_period']]
corr_matrix = df[param_cols + ['total_return', 'annual_return', 'sharpe_ratio', 'max_drawdown']].corr()
logger.info(corr_matrix.to_string())
# 示例用法
if __name__ == "__main__":
from config import Config
from data.tushare_data import TushareDataFetcher
# 加载配置
config = Config()
# 初始化数据获取器
data_fetcher = TushareDataFetcher(
token=config.TUSHARE_TOKEN,
call_interval=config.TUSHARE_CALL_INTERVAL
)
# 创建参数优化器
optimizer = ParameterOptimizer(data_fetcher)
# 定义参数搜索空间
param_space = {
'min_pct_change': [2.0, 2.5, 3.0],
'min_volume_ratio': [1.0, 1.5],
'max_ma_gap_pct': [4.0, 5.0],
'min_entity_ratio': [0.6, 0.7],
'confirm_days': [2, 3]
}
# 运行网格搜索
best_params, best_return = optimizer.grid_search(
parameter_ranges=param_space,
symbol='600252.SH',
start_date='20200101',
end_date='20231231',
initial_capital=100000,
metric='sharpe_ratio'
)
# 打印结果
logger.info(f"最佳参数组合: {best_params}")
logger.info(f"最佳夏普比率: {best_return}")
# 导出结果
optimizer.export_results('optimization_results.csv')
# 分析结果
optimizer.analyze_results()