460 lines
19 KiB
Python
460 lines
19 KiB
Python
from strategy.yiyang_strategy import YiYangStrategy
|
||
from strategy.dual_moving_average_strategy import DualMovingAverageStrategy
|
||
import sys
|
||
import os
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from datetime import datetime
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
from config import Config
|
||
from data.tushare_data import TushareDataFetcher
|
||
from data.local_data import LocalDataFetcher
|
||
from data.stock_filter import StockFilter
|
||
from backtest.backtest_engine import BacktestEngine
|
||
from utils.logger import setup_logger
|
||
from analysis.chart_generator import ChartGenerator
|
||
from parameter_optimizer import ParameterOptimizer
|
||
|
||
logger = setup_logger()
|
||
|
||
def run_parameter_optimization(config, data_fetcher, strategy_name="YiYangStrategy"):
|
||
"""
|
||
运行参数优化
|
||
|
||
Args:
|
||
config: 配置对象
|
||
data_fetcher: 数据获取器
|
||
strategy_name: 策略名称
|
||
"""
|
||
print("\n" + "=" * 60)
|
||
print("参数优化模块")
|
||
print("=" * 60)
|
||
|
||
# 创建参数优化器
|
||
optimizer = ParameterOptimizer(data_fetcher)
|
||
|
||
# 定义参数搜索空间
|
||
param_space = {
|
||
'min_pct_change': [2.0, 2.5, 3.0, 3.5],
|
||
'min_volume_ratio': [1.0, 1.5, 2.0, 2.5],
|
||
'max_ma_gap_pct': [4.0, 5.0, 6.0, 7.0],
|
||
'min_entity_ratio': [0.6, 0.7, 0.8, 0.9],
|
||
'confirm_days': [2, 3, 4, 5]
|
||
}
|
||
|
||
print("\n[1/3] 开始参数优化...")
|
||
print(f"参数空间: {param_space}")
|
||
|
||
# 选择一只股票进行优化
|
||
if config.STOCK_POOL:
|
||
symbol = config.STOCK_POOL[0]
|
||
else:
|
||
symbol = "600252.SH" # 默认股票
|
||
|
||
# 运行网格搜索
|
||
best_params, best_metric = optimizer.grid_search(
|
||
parameter_ranges=param_space,
|
||
symbol=symbol,
|
||
start_date=config.START_DATE,
|
||
end_date=config.END_DATE,
|
||
initial_capital=config.BACKTEST.get('initial_capital', 100000),
|
||
metric='sharpe_ratio'
|
||
)
|
||
|
||
print("\n[2/3] 参数优化完成!")
|
||
print(f"最佳参数组合: {best_params}")
|
||
print(f"最佳夏普比率: {best_metric:.2f}")
|
||
|
||
# 导出结果
|
||
print("\n[3/3] 导出优化结果...")
|
||
optimizer.export_results('optimization_results.csv', strategy_name=strategy_name)
|
||
optimizer.analyze_results()
|
||
|
||
print("\n参数优化模块执行完成!")
|
||
return best_params
|
||
|
||
def run_backtest_multithreaded(config, stock_data, strategy, thread_pool_size):
|
||
"""
|
||
多线程回测
|
||
Args:
|
||
config: 配置信息
|
||
stock_data: 股票数据
|
||
strategy: 策略实例
|
||
thread_pool_size: 线程池大小
|
||
Returns:
|
||
合并后的回测结果
|
||
"""
|
||
# 为每个线程创建独立的回测引擎
|
||
def backtest_single_stock(ts_code, data, strategy):
|
||
engine = BacktestEngine(config.BACKTEST)
|
||
return engine.run_backtest_single_stock(ts_code, data, strategy)
|
||
|
||
# 收集每个线程的回测结果
|
||
stock_results = {}
|
||
|
||
# 使用tqdm显示进度
|
||
from tqdm import tqdm
|
||
|
||
with ThreadPoolExecutor(max_workers=thread_pool_size) as executor:
|
||
# 提交所有任务
|
||
future_to_stock = {executor.submit(backtest_single_stock, ts_code, data, strategy): ts_code
|
||
for ts_code, data in stock_data.items()}
|
||
|
||
# 使用tqdm创建进度条
|
||
with tqdm(total=len(future_to_stock), desc="多线程回测进度", unit="股票") as pbar:
|
||
# 处理完成的任务
|
||
for future in as_completed(future_to_stock):
|
||
ts_code = future_to_stock[future]
|
||
try:
|
||
result = future.result()
|
||
stock_results[ts_code] = result
|
||
pbar.set_postfix_str(f"当前完成: {ts_code}")
|
||
except Exception as e:
|
||
pbar.set_postfix_str(f"{ts_code} 出错: {e}")
|
||
finally:
|
||
# 更新进度条
|
||
pbar.update(1)
|
||
|
||
# 将所有股票的回测结果合并
|
||
# 合并所有交易记录
|
||
all_trades = []
|
||
# 收集所有有效的权益曲线
|
||
all_equity_curves = []
|
||
|
||
for ts_code, result in stock_results.items():
|
||
if 'trades' in result and not result['trades'].empty:
|
||
all_trades.append(result['trades'])
|
||
if 'equity_curve' in result and not result['equity_curve'].empty:
|
||
# 确保权益曲线包含'equity'列
|
||
if 'equity' in result['equity_curve'].columns:
|
||
all_equity_curves.append(result['equity_curve'])
|
||
else:
|
||
print(f"警告: 股票 {ts_code} 的权益曲线缺少'equity'列")
|
||
|
||
if all_trades:
|
||
merged_trades = pd.concat(all_trades, ignore_index=True)
|
||
else:
|
||
merged_trades = pd.DataFrame()
|
||
|
||
# 合并所有权益曲线 - 修复:应该使用平均收益率而不是直接累加权益
|
||
if all_equity_curves:
|
||
# 获取所有日期的集合
|
||
all_dates = pd.DatetimeIndex([])
|
||
for equity_df in all_equity_curves:
|
||
all_dates = all_dates.union(pd.DatetimeIndex(equity_df.index))
|
||
all_dates = all_dates.sort_values()
|
||
|
||
# 修复:计算每日平均收益率,而不是累加权益
|
||
# 使用一个初始资金作为基准
|
||
initial_capital = config.BACKTEST.get('initial_capital', 100000)
|
||
|
||
# 计算每只股票的收益率曲线
|
||
all_returns = []
|
||
for equity_df in all_equity_curves:
|
||
try:
|
||
# 计算该股票的收益率
|
||
stock_returns = equity_df['returns'].reindex(all_dates).fillna(0)
|
||
all_returns.append(stock_returns)
|
||
except KeyError:
|
||
print(f"错误: 权益曲线缺少'returns'列")
|
||
continue
|
||
|
||
if all_returns:
|
||
# 计算平均收益率(等权重组合)
|
||
avg_returns = pd.concat(all_returns, axis=1).mean(axis=1)
|
||
|
||
# 根据平均收益率计算组合权益
|
||
total_equity = pd.Series(initial_capital, index=all_dates)
|
||
for date in all_dates[1:]:
|
||
prev_equity = total_equity.iloc[total_equity.index.get_loc(date) - 1]
|
||
total_equity[date] = prev_equity * (1 + avg_returns[date])
|
||
|
||
# 创建整体权益曲线
|
||
equity_df = pd.DataFrame({
|
||
'equity': total_equity,
|
||
'returns': avg_returns
|
||
})
|
||
else:
|
||
equity_df = pd.DataFrame()
|
||
total_return = annual_return = max_drawdown = sharpe_ratio = 0
|
||
win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0
|
||
|
||
# 计算整体绩效指标
|
||
if not equity_df.empty:
|
||
final_capital = equity_df['equity'].iloc[-1]
|
||
total_return = (final_capital / initial_capital - 1) * 100
|
||
# 使用复利计算年化收益率
|
||
annual_return = ((1 + total_return / 100) ** (252 / len(equity_df)) - 1) * 100
|
||
|
||
# 最大回撤
|
||
equity_df['cummax'] = equity_df['equity'].cummax()
|
||
equity_df['drawdown'] = (equity_df['equity'] - equity_df['cummax']) / equity_df['cummax'] * 100
|
||
max_drawdown = equity_df['drawdown'].min()
|
||
|
||
# 夏普比率 (简化)
|
||
if equity_df['returns'].std() > 0:
|
||
sharpe_ratio = equity_df['returns'].mean() / equity_df['returns'].std() * np.sqrt(252)
|
||
else:
|
||
sharpe_ratio = 0
|
||
|
||
# 交易统计
|
||
if not merged_trades.empty:
|
||
win_trades = merged_trades[merged_trades['pnl'] > 0] if 'pnl' in merged_trades.columns else pd.DataFrame()
|
||
loss_trades = merged_trades[merged_trades['pnl'] <= 0] if 'pnl' in merged_trades.columns else pd.DataFrame()
|
||
|
||
win_rate = len(win_trades) / len(merged_trades) * 100 if len(merged_trades) > 0 else 0
|
||
avg_win = win_trades['pnl_pct'].mean() * 100 if not win_trades.empty else 0
|
||
avg_loss = loss_trades['pnl_pct'].mean() * 100 if not loss_trades.empty else 0
|
||
profit_factor = abs(win_trades['pnl'].sum() / loss_trades['pnl'].sum()) if not loss_trades.empty and loss_trades['pnl'].sum() != 0 else 0
|
||
avg_holding_period = 0 # 这个指标在多线程情况下暂时无法计算
|
||
else:
|
||
win_rate = avg_win = avg_loss = profit_factor = 0
|
||
avg_holding_period = 0
|
||
else:
|
||
total_return = annual_return = max_drawdown = sharpe_ratio = 0
|
||
win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0
|
||
else:
|
||
equity_df = pd.DataFrame()
|
||
total_return = annual_return = max_drawdown = sharpe_ratio = 0
|
||
win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0
|
||
|
||
# 返回完整的回测结果
|
||
final_results = {
|
||
'total_return': total_return,
|
||
'annual_return': annual_return,
|
||
'max_drawdown': max_drawdown,
|
||
'sharpe_ratio': sharpe_ratio,
|
||
'win_rate': win_rate,
|
||
'avg_win_pct': avg_win,
|
||
'avg_loss_pct': avg_loss,
|
||
'profit_factor': profit_factor,
|
||
'avg_holding_period': avg_holding_period,
|
||
'total_trades': len(merged_trades),
|
||
'equity_curve': equity_df,
|
||
'trades': merged_trades,
|
||
'stock_results': stock_results # 保留各股票的原始结果
|
||
}
|
||
|
||
return final_results
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("=" * 60)
|
||
print("多策略回测系统")
|
||
print("=" * 60)
|
||
|
||
# 1. 加载配置
|
||
config = Config()
|
||
logger.info(f"回测期间: {config.START_DATE} 至 {config.END_DATE}")
|
||
|
||
# 2. 初始化数据获取器
|
||
print("\n[1/4] 初始化数据连接...")
|
||
if config.LOCAL_HQ:
|
||
data_fetcher = LocalDataFetcher(config.DATA_DIR)
|
||
logger.info(f"使用本地行情数据,数据目录: {config.DATA_DIR}")
|
||
else:
|
||
data_fetcher = TushareDataFetcher(
|
||
token=config.TUSHARE_TOKEN,
|
||
call_interval=config.TUSHARE_CALL_INTERVAL
|
||
)
|
||
logger.info("使用在线行情数据")
|
||
|
||
# 3. 运行参数优化(可选)
|
||
if getattr(config, 'RUN_PARAMETER_OPTIMIZATION', False):
|
||
optimized_params = run_parameter_optimization(config, data_fetcher)
|
||
logger.info(f"优化后的参数: {optimized_params}")
|
||
# 可以将优化后的参数保存到配置中
|
||
config.YIYANG_CONDITIONS.update(optimized_params)
|
||
|
||
# 4. 使用股票筛选器获取股票池
|
||
print("\n[2/4] 筛选股票池...")
|
||
stock_filter = StockFilter(data_fetcher)
|
||
|
||
# 获取筛选配置
|
||
filter_config = getattr(config, 'STOCK_FILTER_CONFIG', {
|
||
'include_st': False,
|
||
'include_bse': False,
|
||
'include_sse_star': False,
|
||
'include_gem': True,
|
||
'include_main': True,
|
||
'custom_stocks': config.STOCK_POOL
|
||
})
|
||
|
||
# 获取筛选后的股票池
|
||
selected_stocks = stock_filter.get_stock_pool(filter_config)
|
||
|
||
if not selected_stocks:
|
||
logger.error("股票筛选后无可用股票,程序退出")
|
||
return
|
||
|
||
# 5. 获取数据
|
||
print("\n[3/4] 获取股票数据...")
|
||
stock_data = {}
|
||
|
||
for ts_code in selected_stocks[:10]: # 限制前10只测试
|
||
logger.info(f"获取 {ts_code} 数据...")
|
||
df = data_fetcher.get_daily_data(
|
||
ts_code=ts_code,
|
||
start_date=config.START_DATE,
|
||
end_date=config.END_DATE
|
||
)
|
||
|
||
if not df.empty:
|
||
stock_data[ts_code] = df
|
||
logger.info(f" {ts_code}: {len(df)} 条记录")
|
||
else:
|
||
logger.warning(f" {ts_code}: 无数据")
|
||
|
||
if not stock_data:
|
||
logger.error("未获取到有效数据,程序退出")
|
||
return
|
||
|
||
# 获取上证指数(000001.SH)数据作为基准
|
||
print("\n[4/4] 获取上证指数(000001.SH)数据作为基准...")
|
||
benchmark_data = data_fetcher.get_index_data(
|
||
index_code="000001.SH",
|
||
start_date=config.START_DATE,
|
||
end_date=config.END_DATE
|
||
)
|
||
|
||
if benchmark_data.empty:
|
||
logger.warning("未获取到上证指数数据,将不显示基准对比")
|
||
else:
|
||
logger.info(f" 上证指数: {len(benchmark_data)} 条记录")
|
||
|
||
# 5. 运行回测
|
||
print("\n[5/5] 初始化回测引擎...")
|
||
backtest_config = config.BACKTEST
|
||
engine = BacktestEngine(backtest_config)
|
||
|
||
# 6. 根据配置的策略列表进行回测
|
||
print("\n[6/6] 运行回测...")
|
||
|
||
# 策略类映射,用于动态加载策略
|
||
strategy_class_map = {
|
||
'YiYangStrategy': YiYangStrategy,
|
||
'DualMovingAverageStrategy': DualMovingAverageStrategy,
|
||
# 可以在这里添加更多策略类
|
||
}
|
||
|
||
for strategy_name in config.STRATEGY_LIST:
|
||
if strategy_name not in strategy_class_map:
|
||
logger.warning(f"策略 {strategy_name} 未在系统中实现,跳过回测")
|
||
continue
|
||
|
||
print(f"\n--- 开始回测策略: {strategy_name} ---")
|
||
|
||
try:
|
||
# 初始化策略,根据策略名称加载不同的配置
|
||
if strategy_name == 'YiYangStrategy':
|
||
strategy_config = {
|
||
**config.YIYANG_CONDITIONS,
|
||
'ma_params': config.MA_PARAMS
|
||
}
|
||
elif strategy_name == 'DualMovingAverageStrategy':
|
||
strategy_config = {
|
||
**config.DUAL_MA_CONDITIONS,
|
||
'ma_params': config.MA_PARAMS
|
||
}
|
||
else:
|
||
# 默认配置
|
||
strategy_config = {
|
||
'ma_params': config.MA_PARAMS
|
||
}
|
||
|
||
strategy = strategy_class_map[strategy_name](strategy_config)
|
||
|
||
# 运行回测
|
||
if hasattr(config, 'THREAD_POOL_SIZE') and config.THREAD_POOL_SIZE > 1:
|
||
# 使用多线程回测
|
||
print(f"使用多线程回测,线程数: {config.THREAD_POOL_SIZE}")
|
||
results = run_backtest_multithreaded(config, stock_data, strategy,
|
||
config.THREAD_POOL_SIZE)
|
||
else:
|
||
# 使用单线程回测
|
||
results = engine.run_backtest(stock_data, strategy)
|
||
|
||
# 显示结果
|
||
print("\n" + "=" * 60)
|
||
print(f"{strategy_name} 回测结果摘要")
|
||
print("=" * 60)
|
||
|
||
if results:
|
||
if 'stock_results' in results: # 多线程回测结果
|
||
print(f"\n注意:多股票组合的收益率为等权重平均值,不是简单累加")
|
||
print(f"总收益率: {results['total_return']:.2f}% (基于平均日收益率)")
|
||
print(f"年化收益率: {results['annual_return']:.2f}%")
|
||
print(f"最大回撤: {results['max_drawdown']:.2f}%")
|
||
print(f"夏普比率: {results['sharpe_ratio']:.2f}")
|
||
print(f"总交易次数: {results['total_trades']}")
|
||
print(f"测试股票数量: {len(results['stock_results'])}")
|
||
|
||
# 简单统计各股票的回测结果
|
||
print(f"\n各股票独立回测结果:")
|
||
total_returns = []
|
||
for ts_code, stock_result in results['stock_results'].items():
|
||
if 'total_return' in stock_result:
|
||
total_returns.append(stock_result['total_return'])
|
||
print(f" {ts_code}: {stock_result['total_return']:.2f}%")
|
||
|
||
if total_returns:
|
||
print(f"\n简单算术平均: {sum(total_returns)/len(total_returns):.2f}% (仅供参考,实际为日收益率平均)")
|
||
print(f"最高收益股票: {max(total_returns):.2f}%")
|
||
print(f"最低收益股票: {min(total_returns):.2f}%")
|
||
else: # 单线程回测结果
|
||
print(f"总收益率: {results['total_return']:.2f}%")
|
||
print(f"年化收益率: {results['annual_return']:.2f}%")
|
||
print(f"最大回撤: {results['max_drawdown']:.2f}%")
|
||
print(f"夏普比率: {results['sharpe_ratio']:.2f}")
|
||
print(f"胜率: {results['win_rate']:.2f}%")
|
||
print(f"平均盈利: {results['avg_win_pct']:.2f}%")
|
||
print(f"平均亏损: {results['avg_loss_pct']:.2f}%")
|
||
print(f"盈亏比: {results['profit_factor']:.2f}")
|
||
print(f"平均持仓时间: {results['avg_holding_period']:.2f} 天")
|
||
print(f"总交易次数: {results['total_trades']}")
|
||
|
||
# 生成所有图表
|
||
if results:
|
||
print("\n生成图表...")
|
||
chart_generator = ChartGenerator(strategy_name=strategy_name)
|
||
|
||
# 生成基础图表
|
||
charts = chart_generator.generate_all_charts(results, strategy_name, benchmark_data)
|
||
|
||
# 打印生成的图表
|
||
if charts:
|
||
print("\n生成的图表:")
|
||
for chart_name, chart_path in charts:
|
||
if chart_path:
|
||
print(f" - {chart_name}: {chart_path}")
|
||
|
||
# 保存交易记录为CSV文件
|
||
if 'trades' in results and not results['trades'].empty:
|
||
print("\n保存交易记录...")
|
||
# 创建交易记录目录
|
||
trades_dir = 'trades_records'
|
||
os.makedirs(trades_dir, exist_ok=True)
|
||
|
||
# 保存完整交易记录
|
||
trades_file = os.path.join(trades_dir, f'{strategy_name}_trades.csv')
|
||
results['trades'].to_csv(trades_file, index=False, encoding='utf-8-sig')
|
||
print(f" 完整交易记录已保存至: {trades_file}")
|
||
|
||
# 显示最近交易
|
||
if 'trades' in results and not results['trades'].empty:
|
||
print("\n最近10笔交易:")
|
||
trades_df = results['trades'].tail(10)
|
||
print(trades_df[['date', 'ts_code', 'action', 'price', 'pnl_pct', 'reason']].to_string())
|
||
|
||
except Exception as e:
|
||
logger.error(f"策略 {strategy_name} 回测出错: {e}")
|
||
print(f"策略 {strategy_name} 回测出错: {e}")
|
||
|
||
print("\n所有策略回测完成!")
|
||
|
||
if __name__ == "__main__":
|
||
main() |