# -*- coding: utf-8 -*-
"""
Parameter optimization module.

Optimizes the parameters of a trading strategy and supports multiple
optimization algorithms.
"""

import itertools
import logging
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd

from backtest.backtest_engine import BacktestEngine
from strategy.yiyang_strategy import YiYangStrategy
from utils.logger import setup_logger

# Set up logging for this module.
logger = setup_logger(
    'parameter_optimizer',
    'logs/parameter_optimizer.log',
)

# Chinese display labels keyed by the internal metric / strategy-parameter name.
METRIC_NAME_MAP = dict(
    # Backtest performance metrics
    total_return='总收益率',
    annual_return='年化收益率',
    sharpe_ratio='夏普比率',
    max_drawdown='最大回撤',
    win_rate='胜率',
    avg_win_pct='平均盈利%',
    avg_loss_pct='平均亏损%',
    profit_factor='盈亏比',
    total_trades='交易次数',
    avg_holding_period='平均持有周期',
    # Strategy parameters
    min_pct_change='最小涨幅%',
    min_volume_ratio='最小量比',
    max_ma_gap_pct='均线最大间距%',
    min_entity_ratio='实体最小比例',
    confirm_days='确认天数',
)


class ParameterOptimizer:
    """
    Search for the strategy parameters that give the best backtest metric.

    Two algorithms are provided: an exhaustive grid search that runs the
    backtests on a thread pool, and a sequential random search. Every
    successful backtest is appended to ``self.results`` so the runs can be
    ranked, exported to CSV or analysed afterwards.

    For every metric except ``'max_drawdown'`` a larger value is considered
    better; ``'max_drawdown'`` is minimised.
    NOTE(review): minimising assumes the engine reports max_drawdown as a
    positive magnitude - confirm against BacktestEngine.
    """

    # Result keys treated as performance metrics (not strategy parameters)
    # when analysing the results.
    _METRIC_COLUMNS = (
        'total_return', 'annual_return', 'sharpe_ratio', 'max_drawdown',
        'win_rate', 'avg_win_pct', 'avg_loss_pct', 'profit_factor',
        'total_trades', 'avg_holding_period',
    )

    # Metrics summarised and correlated by analyze_results().
    _KEY_METRICS = ('total_return', 'annual_return', 'sharpe_ratio', 'max_drawdown')

    # Engine settings used for every backtest unless overridden in __init__.
    DEFAULT_BACKTEST_CONFIG = {
        'commission_rate': 0.0003,
        'stamp_tax_rate': 0.001,
        'slippage_rate': 0.001,
        'position_ratio': 0.2,
        'stop_loss': 0.08,
        'take_profit': 0.20,
    }

    def __init__(self, data_fetcher, strategy_class=YiYangStrategy, max_workers=4,
                 backtest_config: Optional[Dict] = None):
        """
        Initialise the optimizer.

        Args:
            data_fetcher: Object providing
                ``get_daily_data(ts_code=..., start_date=..., end_date=...)``.
            strategy_class: Strategy class; it is instantiated with a dict of
                parameters for each backtest. Defaults to YiYangStrategy.
            max_workers: Size of the thread pool used by grid_search.
            backtest_config: Optional overrides merged on top of
                DEFAULT_BACKTEST_CONFIG (commission, tax, slippage, ...).
        """
        self.data_fetcher = data_fetcher
        self.strategy_class = strategy_class
        self.max_workers = max_workers
        self.backtest_config = dict(self.DEFAULT_BACKTEST_CONFIG)
        if backtest_config:
            self.backtest_config.update(backtest_config)
        self.results = []  # one entry per successful backtest
        self._data_cache = {}  # (symbol, start_date, end_date) -> fetched data

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _initial_best(metric: str) -> float:
        """Return the sentinel that any real value of ``metric`` improves on."""
        return float('inf') if metric == 'max_drawdown' else -float('inf')

    @staticmethod
    def _is_better(metric: str, candidate, incumbent) -> bool:
        """Return True when ``candidate`` beats ``incumbent`` for ``metric``."""
        if metric == 'max_drawdown':
            return candidate < incumbent
        return candidate > incumbent

    @staticmethod
    def _metric_value(result: Dict, metric: str, params: Dict):
        """Return ``result[metric]``, or None (with a warning) when it is absent."""
        value = result.get(metric)
        if value is None:
            logger.warning(f"回测结果中缺少指标 {metric},参数: {params}")
        return value

    def _load_data(self, symbol: str, start_date: str, end_date: str):
        """
        Fetch daily data for ``symbol``, reusing earlier fetches of the same range.

        Every parameter combination is tested on the same data, so it is
        fetched once and each caller receives its own copy. Missing or empty
        responses are returned as-is and never cached, so a later call retries.
        """
        cache = getattr(self, '_data_cache', None)
        if cache is None:
            cache = self._data_cache = {}

        key = (symbol, start_date, end_date)
        data = cache.get(key)
        if data is None:
            # Workers that start simultaneously may each miss and fetch once;
            # that is harmless, the last writer simply wins.
            data = self.data_fetcher.get_daily_data(
                ts_code=symbol, start_date=start_date, end_date=end_date)
            if data is None or data.empty:
                return data
            cache[key] = data
        return data.copy()

    def _timed_backtest(self, params: Dict, symbol: str, start_date: str, end_date: str,
                        initial_capital: float) -> Tuple[Optional[Dict], float]:
        """Run ``_run_backtest`` and return ``(result, elapsed_seconds)``."""
        started = time.time()
        result = self._run_backtest(params, symbol, start_date, end_date, initial_capital)
        return result, time.time() - started

    def _results_to_frame(self) -> pd.DataFrame:
        """Flatten ``self.results`` into one row per run (parameters + result values)."""
        rows = []
        for entry in self.results:
            row = entry['params'].copy()
            row.update(entry['result'])
            rows.append(row)
        return pd.DataFrame(rows)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def define_parameter_space(self, parameter_ranges: Dict[str, Union[List, range]]) -> Dict:
        """
        Define the parameter search space.

        Args:
            parameter_ranges: Maps each parameter name to its candidate values
                (a list or a range object).

        Returns:
            Dict: The same mapping, returned unchanged.
        """
        return parameter_ranges

    def grid_search(self, parameter_ranges: Dict, symbol: str, start_date: str, end_date: str,
                    initial_capital: float = 100000,
                    metric: str = 'total_return') -> Tuple[Optional[Dict], float]:
        """
        Exhaustive grid search, with the backtests run on a thread pool.

        Args:
            parameter_ranges: Maps each parameter name to its candidate values.
            symbol: Stock code.
            start_date: Start date, format YYYYMMDD.
            end_date: End date, format YYYYMMDD.
            initial_capital: Initial capital.
            metric: Key of the backtest result to optimise.

        Returns:
            Tuple[Dict, float]: Best parameter combination and its metric
            value. When no backtest produced a usable result the parameters
            are None and the value is the initial +/-inf sentinel.
        """
        start_time = time.time()
        metric_cn = METRIC_NAME_MAP.get(metric, metric)
        logger.info(f"开始网格搜索优化,参数范围: {parameter_ranges}")

        # Cartesian product of all candidate values.
        param_names = list(parameter_ranges.keys())
        param_values = [list(parameter_ranges[name]) for name in param_names]
        param_combinations = list(itertools.product(*param_values))

        total_combinations = len(param_combinations)
        logger.info(f"总共有 {total_combinations} 种参数组合,使用 {self.max_workers} 个线程")

        best_params = None
        best_metric = self._initial_best(metric)
        completed = 0
        busy_time = 0.0  # sum of the individual backtest durations

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit every combination up front.
            future_to_params = {}
            for params in param_combinations:
                param_dict = dict(zip(param_names, params))
                future = executor.submit(self._timed_backtest, param_dict, symbol,
                                         start_date, end_date, initial_capital)
                future_to_params[future] = param_dict

            # Results are consumed in this (the calling) thread, so
            # self.results and the running best are only touched here.
            for future in as_completed(future_to_params):
                completed += 1
                param_dict = future_to_params[future]

                # Progress line with a linear estimate of the remaining time.
                elapsed_time = time.time() - start_time
                remaining_time = (elapsed_time / completed) * (total_combinations - completed)
                sys.stdout.write(f"\r网格搜索进度: {completed}/{total_combinations} "
                                 f"({completed/total_combinations*100:.1f}%) | "
                                 f"已耗时: {elapsed_time:.2f}s | 预计剩余: {remaining_time:.2f}s")
                sys.stdout.flush()

                try:
                    result, run_time = future.result()
                    busy_time += run_time

                    if not result:
                        continue

                    # Keep scalar values only; DataFrames are dropped.
                    numeric_result = {k: v for k, v in result.items()
                                      if not isinstance(v, pd.DataFrame)}
                    self.results.append({
                        'params': param_dict,
                        'result': numeric_result,
                        'run_time': run_time
                    })

                    current_metric = self._metric_value(result, metric, param_dict)
                    if current_metric is None:
                        continue

                    if self._is_better(metric, current_metric, best_metric):
                        best_metric = current_metric
                        best_params = param_dict
                        logger.info(f"找到更优参数: {best_params}, {metric_cn}: {best_metric:.4f}")
                except Exception as e:
                    # Best effort: one bad combination must not abort the search.
                    logger.error(f"参数组合 {param_dict} 失败: {str(e)}")

        # Summary. Both divisions are guarded: an empty candidate list yields
        # zero combinations.
        total_time = time.time() - start_time
        avg_time = total_time / total_combinations if total_combinations else 0.0
        # Summed task time over wall-clock time: an approximation of the gain
        # from running the backtests concurrently.
        speedup = busy_time / total_time if total_time > 0 else 0.0

        print("\n\n" + "=" * 60)
        print("网格搜索优化完成!")
        print(f"总耗时: {total_time:.2f}秒")
        print(f"参数组合总数: {total_combinations}")
        print(f"平均每组耗时: {avg_time:.2f}秒")
        print(f"多线程数: {self.max_workers}")
        print(f"加速比: {speedup:.2f}x")
        print("=" * 60)

        logger.info(f"网格搜索完成,最佳参数: {best_params}, 最佳{metric_cn}: {best_metric}")
        return best_params, best_metric

    def random_search(self, parameter_ranges: Dict, symbol: str, start_date: str, end_date: str,
                      initial_capital: float = 100000, n_iter: int = 50,
                      metric: str = 'total_return') -> Tuple[Optional[Dict], float]:
        """
        Random search: test ``n_iter`` randomly drawn parameter combinations.

        Args:
            parameter_ranges: Maps each parameter name to its candidate values.
            symbol: Stock code.
            start_date: Start date, format YYYYMMDD.
            end_date: End date, format YYYYMMDD.
            initial_capital: Initial capital.
            n_iter: Number of iterations.
            metric: Key of the backtest result to optimise, e.g.
                'total_return', 'annual_return', 'sharpe_ratio', 'max_drawdown'.

        Returns:
            Tuple[Dict, float]: Best parameter combination and its metric
            value. When no backtest produced a usable result the parameters
            are None and the value is the initial +/-inf sentinel.
        """
        start_time = time.time()
        metric_cn = METRIC_NAME_MAP.get(metric, metric)
        logger.info(f"开始随机搜索优化,参数范围: {parameter_ranges}, 迭代次数: {n_iter}")

        best_params = None
        best_metric = self._initial_best(metric)

        # Materialise the candidates once instead of on every iteration.
        candidates = {name: list(values) for name, values in parameter_ranges.items()}
        empty_names = [name for name, values in candidates.items() if not values]
        if empty_names:
            # Nothing can be drawn from an empty list; treat the search space
            # as empty, as grid_search does.
            logger.warning(f"参数 {empty_names} 没有可选值,无法进行随机搜索")
            return best_params, best_metric

        iteration_times = []

        for i in range(n_iter):
            elapsed_time = time.time() - start_time
            iteration_start = time.time()

            # i iterations are finished at this point, so that is the divisor.
            remaining_time = (elapsed_time / i) * (n_iter - i) if i > 0 else 0
            sys.stdout.write(f"\r随机搜索进度: {i+1}/{n_iter} ({(i+1)/n_iter*100:.1f}%) | "
                             f"已耗时: {elapsed_time:.2f}s | 预计剩余: {remaining_time:.2f}s")
            sys.stdout.flush()

            # Draw one value per parameter.
            param_dict = {name: random.choice(values) for name, values in candidates.items()}
            logger.info(f"正在测试第 {i+1}/{n_iter} 组参数: {param_dict}")

            result = self._run_backtest(param_dict, symbol, start_date, end_date, initial_capital)

            iteration_time = time.time() - iteration_start
            iteration_times.append(iteration_time)

            if not result:
                continue

            # Keep scalar values only; DataFrames are dropped.
            numeric_result = {k: v for k, v in result.items()
                              if not isinstance(v, pd.DataFrame)}
            self.results.append({
                'params': param_dict,
                'result': numeric_result,
                'run_time': iteration_time
            })

            current_metric = self._metric_value(result, metric, param_dict)
            if current_metric is None:
                continue

            if self._is_better(metric, current_metric, best_metric):
                best_metric = current_metric
                best_params = param_dict
                logger.info(f"找到更优参数: {best_params}, {metric_cn}: {best_metric}, "
                            f"耗时: {iteration_time:.2f}s")

        # Timing summary.
        total_time = time.time() - start_time
        avg_iteration_time = np.mean(iteration_times) if iteration_times else 0
        max_iteration_time = np.max(iteration_times) if iteration_times else 0
        min_iteration_time = np.min(iteration_times) if iteration_times else 0

        print("\n\n" + "=" * 60)
        print("随机搜索优化完成!")
        print(f"总耗时: {total_time:.2f}秒")
        print(f"迭代次数: {n_iter}")
        print(f"平均每次迭代耗时: {avg_iteration_time:.2f}秒")
        print(f"最快/最慢迭代耗时: {min_iteration_time:.2f}s / {max_iteration_time:.2f}s")
        print("=" * 60)

        logger.info(f"随机搜索完成,最佳参数: {best_params}, 最佳{metric_cn}: {best_metric}")
        return best_params, best_metric

    def _run_backtest(self, params: Dict, symbol: str, start_date: str, end_date: str,
                      initial_capital: float) -> Optional[Dict]:
        """
        Run the backtest for a single parameter combination.

        Args:
            params: Parameter combination; a copy is passed to the strategy
                class as its configuration.
            symbol: Stock code.
            start_date: Start date.
            end_date: End date.
            initial_capital: Initial capital.

        Returns:
            Dict: Whatever ``BacktestEngine.run_backtest`` returns, or None
            when no data is available or any exception occurs (the exception
            is logged, never propagated).
        """
        try:
            data = self._load_data(symbol, start_date, end_date)
            if data is None or data.empty:
                logger.warning(f"无法获取 {symbol} 的数据")
                return None

            # The parameters double as the strategy configuration.
            strategy = self.strategy_class(params.copy())

            engine_config = getattr(self, 'backtest_config', self.DEFAULT_BACKTEST_CONFIG)
            backtest_config = {'initial_capital': initial_capital}
            backtest_config.update(engine_config)
            # The explicit argument always wins over a configured value.
            backtest_config['initial_capital'] = initial_capital

            engine = BacktestEngine(backtest_config)
            # The engine is given the data as a {symbol: data} dict.
            return engine.run_backtest({symbol: data}, strategy)

        except Exception as e:
            logger.error(f"回测失败,参数: {params},错误: {str(e)}")
            return None

    def get_results(self) -> List[Dict]:
        """
        Return every recorded optimisation result.

        Returns:
            List[Dict]: One dict per successful backtest with the keys
            'params' and 'result' (and 'run_time' where it was recorded).
        """
        return self.results

    def get_best_params(self, metric: str = 'total_return') -> Tuple[Optional[Dict], Optional[float]]:
        """
        Pick the best recorded parameter combination for ``metric``.

        Args:
            metric: Key of the backtest result to rank by.

        Returns:
            Tuple[Dict, float]: Best parameters and their metric value;
            ``(None, None)`` when nothing has been recorded. Runs whose result
            lacks ``metric`` are skipped.
        """
        if not self.results:
            return None, None

        best_result = None
        best_metric_value = self._initial_best(metric)

        for entry in self.results:
            current_metric = entry['result'].get(metric)
            if current_metric is None:
                continue
            if self._is_better(metric, current_metric, best_metric_value):
                best_metric_value = current_metric
                best_result = entry['params']

        return best_result, best_metric_value

    def export_results(self, filename: str = 'optimization_results.csv',
                       strategy_name: Optional[str] = None) -> None:
        """
        Export the recorded results to a CSV file with Chinese column titles.

        Args:
            filename: File name.
            strategy_name: Strategy name; when given, the file is written to
                ``optimization_results/<strategy_name>/``, which is created if
                it does not exist.
        """
        if not self.results:
            logger.warning("没有结果可导出")
            return

        if strategy_name:
            optimization_dir = os.path.join('optimization_results', strategy_name)
            # exist_ok avoids the check-then-create race.
            os.makedirs(optimization_dir, exist_ok=True)
            filepath = os.path.join(optimization_dir, filename)
        else:
            filepath = filename

        df = self._results_to_frame()

        # Replace the internal column names with their Chinese labels.
        df.rename(columns=METRIC_NAME_MAP, inplace=True)

        df.to_csv(filepath, index=False, encoding='utf-8-sig')
        logger.info(f"优化结果已导出到 {filepath}")

    def analyze_results(self) -> None:
        """
        Log summary statistics of the recorded results.

        Logs descriptive statistics of the key metrics and the correlation
        matrix between the numeric parameters and those metrics. Key metrics
        absent from the results and non-numeric columns are left out.
        """
        if not self.results:
            logger.warning("没有结果可分析")
            return

        df = self._results_to_frame()

        logger.info("优化结果统计:")
        logger.info(f"总测试参数组合数: {len(df)}")

        key_metrics = [col for col in self._KEY_METRICS if col in df.columns]
        logger.info("\n主要指标统计:")
        if key_metrics:
            logger.info(df[key_metrics].describe().to_string())

        logger.info("\n参数与指标相关性:")
        # Correlation is only defined for numeric columns.
        numeric_df = df.select_dtypes(include=['number', 'bool'])
        param_cols = [col for col in numeric_df.columns if col not in self._METRIC_COLUMNS]
        metric_cols = [col for col in key_metrics if col in numeric_df.columns]
        corr_matrix = numeric_df[param_cols + metric_cols].corr()
        logger.info(corr_matrix.to_string())


def _run_example():
    """Example usage: grid-search the default strategy on one stock, then export and analyse."""
    from config import Config
    from data.tushare_data import TushareDataFetcher

    # Load the configuration and build the data fetcher from it.
    settings = Config()
    fetcher = TushareDataFetcher(
        token=settings.TUSHARE_TOKEN,
        call_interval=settings.TUSHARE_CALL_INTERVAL
    )

    optimizer = ParameterOptimizer(fetcher)

    # Parameter search space.
    search_space = {
        'min_pct_change': [2.0, 2.5, 3.0],
        'min_volume_ratio': [1.0, 1.5],
        'max_ma_gap_pct': [4.0, 5.0],
        'min_entity_ratio': [0.6, 0.7],
        'confirm_days': [2, 3]
    }

    # Run the grid search, optimising the Sharpe ratio.
    best_params, best_sharpe = optimizer.grid_search(
        parameter_ranges=search_space,
        symbol='600252.SH',
        start_date='20200101',
        end_date='20231231',
        initial_capital=100000,
        metric='sharpe_ratio'
    )

    # Report the outcome.
    logger.info(f"最佳参数组合: {best_params}")
    logger.info(f"最佳夏普比率: {best_sharpe}")

    # Export, then analyse, the collected results.
    optimizer.export_results('optimization_results.csv')
    optimizer.analyze_results()


if __name__ == "__main__":
    _run_example()