Files
backtrader/main.py
2026-01-17 21:21:30 +08:00

460 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from strategy.yiyang_strategy import YiYangStrategy
from strategy.dual_moving_average_strategy import DualMovingAverageStrategy
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from config import Config
from data.tushare_data import TushareDataFetcher
from data.local_data import LocalDataFetcher
from data.stock_filter import StockFilter
from backtest.backtest_engine import BacktestEngine
from utils.logger import setup_logger
from analysis.chart_generator import ChartGenerator
from parameter_optimizer import ParameterOptimizer
logger = setup_logger()
def run_parameter_optimization(config, data_fetcher, strategy_name="YiYangStrategy"):
"""
运行参数优化
Args:
config: 配置对象
data_fetcher: 数据获取器
strategy_name: 策略名称
"""
print("\n" + "=" * 60)
print("参数优化模块")
print("=" * 60)
# 创建参数优化器
optimizer = ParameterOptimizer(data_fetcher)
# 定义参数搜索空间
param_space = {
'min_pct_change': [2.0, 2.5, 3.0, 3.5],
'min_volume_ratio': [1.0, 1.5, 2.0, 2.5],
'max_ma_gap_pct': [4.0, 5.0, 6.0, 7.0],
'min_entity_ratio': [0.6, 0.7, 0.8, 0.9],
'confirm_days': [2, 3, 4, 5]
}
print("\n[1/3] 开始参数优化...")
print(f"参数空间: {param_space}")
# 选择一只股票进行优化
if config.STOCK_POOL:
symbol = config.STOCK_POOL[0]
else:
symbol = "600252.SH" # 默认股票
# 运行网格搜索
best_params, best_metric = optimizer.grid_search(
parameter_ranges=param_space,
symbol=symbol,
start_date=config.START_DATE,
end_date=config.END_DATE,
initial_capital=config.BACKTEST.get('initial_capital', 100000),
metric='sharpe_ratio'
)
print("\n[2/3] 参数优化完成!")
print(f"最佳参数组合: {best_params}")
print(f"最佳夏普比率: {best_metric:.2f}")
# 导出结果
print("\n[3/3] 导出优化结果...")
optimizer.export_results('optimization_results.csv', strategy_name=strategy_name)
optimizer.analyze_results()
print("\n参数优化模块执行完成!")
return best_params
def run_backtest_multithreaded(config, stock_data, strategy, thread_pool_size):
"""
多线程回测
Args:
config: 配置信息
stock_data: 股票数据
strategy: 策略实例
thread_pool_size: 线程池大小
Returns:
合并后的回测结果
"""
# 为每个线程创建独立的回测引擎
def backtest_single_stock(ts_code, data, strategy):
engine = BacktestEngine(config.BACKTEST)
return engine.run_backtest_single_stock(ts_code, data, strategy)
# 收集每个线程的回测结果
stock_results = {}
# 使用tqdm显示进度
from tqdm import tqdm
with ThreadPoolExecutor(max_workers=thread_pool_size) as executor:
# 提交所有任务
future_to_stock = {executor.submit(backtest_single_stock, ts_code, data, strategy): ts_code
for ts_code, data in stock_data.items()}
# 使用tqdm创建进度条
with tqdm(total=len(future_to_stock), desc="多线程回测进度", unit="股票") as pbar:
# 处理完成的任务
for future in as_completed(future_to_stock):
ts_code = future_to_stock[future]
try:
result = future.result()
stock_results[ts_code] = result
pbar.set_postfix_str(f"当前完成: {ts_code}")
except Exception as e:
pbar.set_postfix_str(f"{ts_code} 出错: {e}")
finally:
# 更新进度条
pbar.update(1)
# 将所有股票的回测结果合并
# 合并所有交易记录
all_trades = []
# 收集所有有效的权益曲线
all_equity_curves = []
for ts_code, result in stock_results.items():
if 'trades' in result and not result['trades'].empty:
all_trades.append(result['trades'])
if 'equity_curve' in result and not result['equity_curve'].empty:
# 确保权益曲线包含'equity'列
if 'equity' in result['equity_curve'].columns:
all_equity_curves.append(result['equity_curve'])
else:
print(f"警告: 股票 {ts_code} 的权益曲线缺少'equity'")
if all_trades:
merged_trades = pd.concat(all_trades, ignore_index=True)
else:
merged_trades = pd.DataFrame()
# 合并所有权益曲线 - 修复:应该使用平均收益率而不是直接累加权益
if all_equity_curves:
# 获取所有日期的集合
all_dates = pd.DatetimeIndex([])
for equity_df in all_equity_curves:
all_dates = all_dates.union(pd.DatetimeIndex(equity_df.index))
all_dates = all_dates.sort_values()
# 修复:计算每日平均收益率,而不是累加权益
# 使用一个初始资金作为基准
initial_capital = config.BACKTEST.get('initial_capital', 100000)
# 计算每只股票的收益率曲线
all_returns = []
for equity_df in all_equity_curves:
try:
# 计算该股票的收益率
stock_returns = equity_df['returns'].reindex(all_dates).fillna(0)
all_returns.append(stock_returns)
except KeyError:
print(f"错误: 权益曲线缺少'returns'")
continue
if all_returns:
# 计算平均收益率(等权重组合)
avg_returns = pd.concat(all_returns, axis=1).mean(axis=1)
# 根据平均收益率计算组合权益
total_equity = pd.Series(initial_capital, index=all_dates)
for date in all_dates[1:]:
prev_equity = total_equity.iloc[total_equity.index.get_loc(date) - 1]
total_equity[date] = prev_equity * (1 + avg_returns[date])
# 创建整体权益曲线
equity_df = pd.DataFrame({
'equity': total_equity,
'returns': avg_returns
})
else:
equity_df = pd.DataFrame()
total_return = annual_return = max_drawdown = sharpe_ratio = 0
win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0
# 计算整体绩效指标
if not equity_df.empty:
final_capital = equity_df['equity'].iloc[-1]
total_return = (final_capital / initial_capital - 1) * 100
# 使用复利计算年化收益率
annual_return = ((1 + total_return / 100) ** (252 / len(equity_df)) - 1) * 100
# 最大回撤
equity_df['cummax'] = equity_df['equity'].cummax()
equity_df['drawdown'] = (equity_df['equity'] - equity_df['cummax']) / equity_df['cummax'] * 100
max_drawdown = equity_df['drawdown'].min()
# 夏普比率 (简化)
if equity_df['returns'].std() > 0:
sharpe_ratio = equity_df['returns'].mean() / equity_df['returns'].std() * np.sqrt(252)
else:
sharpe_ratio = 0
# 交易统计
if not merged_trades.empty:
win_trades = merged_trades[merged_trades['pnl'] > 0] if 'pnl' in merged_trades.columns else pd.DataFrame()
loss_trades = merged_trades[merged_trades['pnl'] <= 0] if 'pnl' in merged_trades.columns else pd.DataFrame()
win_rate = len(win_trades) / len(merged_trades) * 100 if len(merged_trades) > 0 else 0
avg_win = win_trades['pnl_pct'].mean() * 100 if not win_trades.empty else 0
avg_loss = loss_trades['pnl_pct'].mean() * 100 if not loss_trades.empty else 0
profit_factor = abs(win_trades['pnl'].sum() / loss_trades['pnl'].sum()) if not loss_trades.empty and loss_trades['pnl'].sum() != 0 else 0
avg_holding_period = 0 # 这个指标在多线程情况下暂时无法计算
else:
win_rate = avg_win = avg_loss = profit_factor = 0
avg_holding_period = 0
else:
total_return = annual_return = max_drawdown = sharpe_ratio = 0
win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0
else:
equity_df = pd.DataFrame()
total_return = annual_return = max_drawdown = sharpe_ratio = 0
win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0
# 返回完整的回测结果
final_results = {
'total_return': total_return,
'annual_return': annual_return,
'max_drawdown': max_drawdown,
'sharpe_ratio': sharpe_ratio,
'win_rate': win_rate,
'avg_win_pct': avg_win,
'avg_loss_pct': avg_loss,
'profit_factor': profit_factor,
'avg_holding_period': avg_holding_period,
'total_trades': len(merged_trades),
'equity_curve': equity_df,
'trades': merged_trades,
'stock_results': stock_results # 保留各股票的原始结果
}
return final_results
def main():
"""主函数"""
print("=" * 60)
print("多策略回测系统")
print("=" * 60)
# 1. 加载配置
config = Config()
logger.info(f"回测期间: {config.START_DATE}{config.END_DATE}")
# 2. 初始化数据获取器
print("\n[1/4] 初始化数据连接...")
if config.LOCAL_HQ:
data_fetcher = LocalDataFetcher(config.DATA_DIR)
logger.info(f"使用本地行情数据,数据目录: {config.DATA_DIR}")
else:
data_fetcher = TushareDataFetcher(
token=config.TUSHARE_TOKEN,
call_interval=config.TUSHARE_CALL_INTERVAL
)
logger.info("使用在线行情数据")
# 3. 运行参数优化(可选)
if getattr(config, 'RUN_PARAMETER_OPTIMIZATION', False):
optimized_params = run_parameter_optimization(config, data_fetcher)
logger.info(f"优化后的参数: {optimized_params}")
# 可以将优化后的参数保存到配置中
config.YIYANG_CONDITIONS.update(optimized_params)
# 4. 使用股票筛选器获取股票池
print("\n[2/4] 筛选股票池...")
stock_filter = StockFilter(data_fetcher)
# 获取筛选配置
filter_config = getattr(config, 'STOCK_FILTER_CONFIG', {
'include_st': False,
'include_bse': False,
'include_sse_star': False,
'include_gem': True,
'include_main': True,
'custom_stocks': config.STOCK_POOL
})
# 获取筛选后的股票池
selected_stocks = stock_filter.get_stock_pool(filter_config)
if not selected_stocks:
logger.error("股票筛选后无可用股票,程序退出")
return
# 5. 获取数据
print("\n[3/4] 获取股票数据...")
stock_data = {}
for ts_code in selected_stocks[:10]: # 限制前10只测试
logger.info(f"获取 {ts_code} 数据...")
df = data_fetcher.get_daily_data(
ts_code=ts_code,
start_date=config.START_DATE,
end_date=config.END_DATE
)
if not df.empty:
stock_data[ts_code] = df
logger.info(f" {ts_code}: {len(df)} 条记录")
else:
logger.warning(f" {ts_code}: 无数据")
if not stock_data:
logger.error("未获取到有效数据,程序退出")
return
# 获取上证指数000001.SH数据作为基准
print("\n[4/4] 获取上证指数000001.SH数据作为基准...")
benchmark_data = data_fetcher.get_index_data(
index_code="000001.SH",
start_date=config.START_DATE,
end_date=config.END_DATE
)
if benchmark_data.empty:
logger.warning("未获取到上证指数数据,将不显示基准对比")
else:
logger.info(f" 上证指数: {len(benchmark_data)} 条记录")
# 5. 运行回测
print("\n[5/5] 初始化回测引擎...")
backtest_config = config.BACKTEST
engine = BacktestEngine(backtest_config)
# 6. 根据配置的策略列表进行回测
print("\n[6/6] 运行回测...")
# 策略类映射,用于动态加载策略
strategy_class_map = {
'YiYangStrategy': YiYangStrategy,
'DualMovingAverageStrategy': DualMovingAverageStrategy,
# 可以在这里添加更多策略类
}
for strategy_name in config.STRATEGY_LIST:
if strategy_name not in strategy_class_map:
logger.warning(f"策略 {strategy_name} 未在系统中实现,跳过回测")
continue
print(f"\n--- 开始回测策略: {strategy_name} ---")
try:
# 初始化策略,根据策略名称加载不同的配置
if strategy_name == 'YiYangStrategy':
strategy_config = {
**config.YIYANG_CONDITIONS,
'ma_params': config.MA_PARAMS
}
elif strategy_name == 'DualMovingAverageStrategy':
strategy_config = {
**config.DUAL_MA_CONDITIONS,
'ma_params': config.MA_PARAMS
}
else:
# 默认配置
strategy_config = {
'ma_params': config.MA_PARAMS
}
strategy = strategy_class_map[strategy_name](strategy_config)
# 运行回测
if hasattr(config, 'THREAD_POOL_SIZE') and config.THREAD_POOL_SIZE > 1:
# 使用多线程回测
print(f"使用多线程回测,线程数: {config.THREAD_POOL_SIZE}")
results = run_backtest_multithreaded(config, stock_data, strategy,
config.THREAD_POOL_SIZE)
else:
# 使用单线程回测
results = engine.run_backtest(stock_data, strategy)
# 显示结果
print("\n" + "=" * 60)
print(f"{strategy_name} 回测结果摘要")
print("=" * 60)
if results:
if 'stock_results' in results: # 多线程回测结果
print(f"\n注意:多股票组合的收益率为等权重平均值,不是简单累加")
print(f"总收益率: {results['total_return']:.2f}% (基于平均日收益率)")
print(f"年化收益率: {results['annual_return']:.2f}%")
print(f"最大回撤: {results['max_drawdown']:.2f}%")
print(f"夏普比率: {results['sharpe_ratio']:.2f}")
print(f"总交易次数: {results['total_trades']}")
print(f"测试股票数量: {len(results['stock_results'])}")
# 简单统计各股票的回测结果
print(f"\n各股票独立回测结果:")
total_returns = []
for ts_code, stock_result in results['stock_results'].items():
if 'total_return' in stock_result:
total_returns.append(stock_result['total_return'])
print(f" {ts_code}: {stock_result['total_return']:.2f}%")
if total_returns:
print(f"\n简单算术平均: {sum(total_returns)/len(total_returns):.2f}% (仅供参考,实际为日收益率平均)")
print(f"最高收益股票: {max(total_returns):.2f}%")
print(f"最低收益股票: {min(total_returns):.2f}%")
else: # 单线程回测结果
print(f"总收益率: {results['total_return']:.2f}%")
print(f"年化收益率: {results['annual_return']:.2f}%")
print(f"最大回撤: {results['max_drawdown']:.2f}%")
print(f"夏普比率: {results['sharpe_ratio']:.2f}")
print(f"胜率: {results['win_rate']:.2f}%")
print(f"平均盈利: {results['avg_win_pct']:.2f}%")
print(f"平均亏损: {results['avg_loss_pct']:.2f}%")
print(f"盈亏比: {results['profit_factor']:.2f}")
print(f"平均持仓时间: {results['avg_holding_period']:.2f}")
print(f"总交易次数: {results['total_trades']}")
# 生成所有图表
if results:
print("\n生成图表...")
chart_generator = ChartGenerator(strategy_name=strategy_name)
# 生成基础图表
charts = chart_generator.generate_all_charts(results, strategy_name, benchmark_data)
# 打印生成的图表
if charts:
print("\n生成的图表:")
for chart_name, chart_path in charts:
if chart_path:
print(f" - {chart_name}: {chart_path}")
# 保存交易记录为CSV文件
if 'trades' in results and not results['trades'].empty:
print("\n保存交易记录...")
# 创建交易记录目录
trades_dir = 'trades_records'
os.makedirs(trades_dir, exist_ok=True)
# 保存完整交易记录
trades_file = os.path.join(trades_dir, f'{strategy_name}_trades.csv')
results['trades'].to_csv(trades_file, index=False, encoding='utf-8-sig')
print(f" 完整交易记录已保存至: {trades_file}")
# 显示最近交易
if 'trades' in results and not results['trades'].empty:
print("\n最近10笔交易:")
trades_df = results['trades'].tail(10)
print(trades_df[['date', 'ts_code', 'action', 'price', 'pnl_pct', 'reason']].to_string())
except Exception as e:
logger.error(f"策略 {strategy_name} 回测出错: {e}")
print(f"策略 {strategy_name} 回测出错: {e}")
print("\n所有策略回测完成!")
if __name__ == "__main__":
main()