Files
backtrader/backtest/backtest_engine.py
2026-01-17 21:21:30 +08:00

649 lines
26 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
import warnings
import time
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
# Suppress every Python warning process-wide. This is a side effect of merely
# importing this module and also hides warnings raised from any other module
# (e.g. pandas/numpy) for the rest of the process.
warnings.filterwarnings('ignore')
from utils.logger import setup_logger
# Module-level logger from the project's logging helper.
# NOTE(review): BacktestEngine below reports via print() and does not use it in
# the code visible here.
logger = setup_logger()
class BacktestEngine:
    """Daily-bar, long-only backtest engine.

    Each processed date the engine (1) applies exit rules to open positions,
    (2) records a mark-to-market equity snapshot, and (3) opens new positions
    from the strategy's buy signals, strongest first.

    Recognised ``config`` keys (all optional, defaults in parentheses):
    ``initial_capital`` (1000000), ``commission_rate`` (0.0003, charged on
    buys and sells), ``stamp_tax_rate`` (0.001, sells only), ``slippage_rate``
    (0.001), ``position_ratio`` (0.2 of remaining cash per new position),
    ``stop_loss`` (0.08), ``take_profit`` (0.20), ``max_positions`` (5),
    ``max_holding_days`` (20 calendar days), ``trailing_stop_trigger`` (0.10),
    ``trailing_stop_factor`` (0.8), ``strategy_name`` (report title).

    Price frames must be indexed by date and contain a ``close`` column. The
    strategy must expose ``generate_signals(df)`` returning a frame indexed
    like ``df`` with ``signal`` (1 = buy), ``strength`` and ``reason`` columns.
    """

    def __init__(self, config: Dict):
        """Read trading parameters from ``config`` and reset run state.

        Args:
            config: Parameter dict; every key is optional (see class docstring).
        """
        self.config = config
        self.initial_capital = config.get('initial_capital', 1000000)
        self.commission_rate = config.get('commission_rate', 0.0003)
        self.stamp_tax_rate = config.get('stamp_tax_rate', 0.001)
        self.slippage_rate = config.get('slippage_rate', 0.001)
        self.position_ratio = config.get('position_ratio', 0.2)
        self.stop_loss = config.get('stop_loss', 0.08)
        self.take_profit = config.get('take_profit', 0.20)
        self.max_positions = config.get('max_positions', 5)  # max simultaneous holdings
        self.max_holding_days = config.get('max_holding_days', 20)  # max calendar days per holding
        # Trailing stop: once the close exceeds entry * (1 + trigger), the stop
        # is ratcheted up to close * (1 - stop_loss * factor). The defaults
        # reproduce a 10% trigger and a 0.8 tightening factor.
        self.trailing_stop_trigger = config.get('trailing_stop_trigger', 0.10)
        self.trailing_stop_factor = config.get('trailing_stop_factor', 0.8)
        # Title printed by generate_report().
        self.strategy_name = config.get('strategy_name', 'Strategy')
        self._initialize()

    def _initialize(self):
        """Reset all mutable run state so the engine can be reused."""
        self.capital = self.initial_capital  # free cash
        self.positions = {}  # ts_code -> position dict
        self.trades = []  # chronological BUY/SELL records
        self.equity_curve = []  # one snapshot per processed date
        self.signal_cache = {}  # ts_code -> precomputed signal DataFrame

    def _get_all_dates(self, data_dict):
        """Return the set of unique dates appearing in any stock's index."""
        all_dates = set()
        for data in data_dict.values():
            all_dates.update(data.index)
        return all_dates

    def run_backtest_single_stock(self, ts_code: str, data: pd.DataFrame,
                                  strategy, use_multithread: bool = False,
                                  n_workers: int = 4,
                                  progress_callback=None) -> Dict:
        """Backtest one stock, optionally split into parallel time segments.

        Args:
            ts_code: Stock code.
            data: Daily bars for the stock, indexed by date, with ``close``.
            strategy: Object exposing ``generate_signals(df)``.
            use_multithread: If True and ``n_workers > 1``, run independent
                time segments in a thread pool (see
                ``_run_backtest_time_segments`` for the caveats).
            n_workers: Number of threads / time segments.
            progress_callback: Optional ``callback(current, total)``; called
                every 10 dates and once at the end. Ignored in multithreaded
                mode.

        Returns:
            Performance dict from ``_calculate_performance`` (``{}`` if no
            date had enough history to trade).
        """
        if use_multithread and n_workers > 1:
            return self._run_backtest_time_segments(ts_code, data, strategy, n_workers)

        data_dict = {ts_code: data}
        self._initialize()

        sorted_dates = sorted(data.index)
        total_dates = len(sorted_dates)
        for i, current_date in enumerate(sorted_dates):
            if progress_callback and i % 10 == 0:
                progress_callback(i + 1, total_dates)

            # Require 60 bars of history before trading (strategy warm-up).
            hist_data = data[data.index <= current_date]
            if len(hist_data) < 60:
                continue

            # 1. Exit checks; 2. mark to market.
            self._check_positions(current_date, data_dict)
            self._update_equity(current_date, data_dict)

            # Stop once the account is wiped out.
            current_equity = self.equity_curve[-1]['equity']
            if current_equity <= 0:
                print(f"\n警告: 账户资金已亏光! 日期: {current_date}, 剩余权益: {current_equity:.2f}")
                break

            # 3. Signals from history up to today only; 4. entries.
            signals = self._generate_signals(current_date, data_dict, strategy)
            self._execute_trades(current_date, signals, data_dict)

        if progress_callback:
            progress_callback(total_dates, total_dates)

        return self._calculate_performance()

    def run_backtest(self, data_dict: Dict[str, pd.DataFrame],
                     strategy) -> Dict:
        """Run a multi-stock portfolio backtest over the union of all dates.

        Signals are computed once per stock on its full history (see
        ``_precompute_signals``) and looked up per date, instead of being
        recomputed on an expanding window every day. Progress and a per-step
        timing breakdown are written to stdout.

        NOTE(review): computing signals on the full series is free of
        look-ahead bias only if ``strategy.generate_signals`` is causal (each
        row depends only on past/current bars). Confirm this for each strategy.

        Args:
            data_dict: ``{ts_code: daily-bar DataFrame}``.
            strategy: Object exposing ``generate_signals(df)``.

        Returns:
            Performance dict from ``_calculate_performance``.
        """
        start_time = time.time()
        self._initialize()

        print("预计算所有股票信号...")
        precompute_start = time.time()
        self._precompute_signals(data_dict, strategy)
        precompute_time = time.time() - precompute_start
        print(f"信号预计算完成,耗时: {precompute_time:.2f}\n")

        sorted_dates = sorted(self._get_all_dates(data_dict))
        total_dates = len(sorted_dates)

        step_times = {
            'check_positions': 0,
            'generate_signals': 0,
            'execute_trades': 0,
            'update_equity': 0
        }

        for i, current_date in enumerate(sorted_dates):
            elapsed_time = time.time() - start_time
            remaining_time = (elapsed_time / (i + 1)) * (total_dates - (i + 1)) if i > 0 else 0
            sys.stdout.write(f"\r回测进度: {i+1}/{total_dates} ({(i+1)/total_dates*100:.1f}%) | 已耗时: {elapsed_time:.2f}s | 预计剩余: {remaining_time:.2f}s")
            sys.stdout.flush()

            # 1. Exit checks.
            step_start = time.time()
            self._check_positions(current_date, data_dict)
            step_times['check_positions'] += time.time() - step_start

            # 2. Mark to market (before any new entries today).
            step_start = time.time()
            self._update_equity(current_date, data_dict)
            step_times['update_equity'] += time.time() - step_start

            # Stop the whole run once total equity is non-positive.
            current_equity = self.equity_curve[-1]['equity']
            if current_equity <= 0:
                print(f"\n\n警告: 账户资金已亏光! 日期: {current_date}, 剩余权益: {current_equity:.2f}")
                print("回测终止,不再执行后续交易。")
                break

            # 3. Today's buy signals from the precomputed cache.
            step_start = time.time()
            new_signals = self._generate_signals_from_cache(current_date)
            step_times['generate_signals'] += time.time() - step_start

            # 4. Entries.
            step_start = time.time()
            self._execute_trades(current_date, new_signals, data_dict)
            step_times['execute_trades'] += time.time() - step_start

        step_start = time.time()
        results = self._calculate_performance()
        calculate_time = time.time() - step_start

        total_time = time.time() - start_time
        # A coarse clock can read 0 on a near-empty run; avoid dividing by it.
        pct_base = total_time if total_time > 0 else 1.0
        print(f"\n\n" + "=" * 60)
        print("回测完成!")
        print(f"总耗时: {total_time:.2f}")
        print("各步骤耗时:")
        print(f" 预计算信号: {precompute_time:.2f}秒 ({precompute_time/pct_base*100:.1f}%)")
        for step, duration in step_times.items():
            print(f" {step}: {duration:.2f}秒 ({duration/pct_base*100:.1f}%)")
        print(f" 计算绩效: {calculate_time:.2f}秒 ({calculate_time/pct_base*100:.1f}%)")
        print("=" * 60)
        return results

    def _precompute_signals(self, data_dict: Dict[str, pd.DataFrame], strategy):
        """Compute each stock's signals once over its whole history.

        Stocks with fewer than 60 rows are skipped and therefore never signal.
        """
        for ts_code, data in data_dict.items():
            if len(data) >= 60:
                self.signal_cache[ts_code] = strategy.generate_signals(data)

    def _generate_signals_from_cache(self, current_date):
        """Return ``{ts_code: {'strength', 'reason'}}`` for cached buys on a date."""
        signals = {}
        for ts_code, signal_df in self.signal_cache.items():
            if not signal_df.empty and current_date in signal_df.index:
                signal_row = signal_df.loc[current_date]
                if signal_row['signal'] == 1:
                    signals[ts_code] = {
                        'strength': signal_row['strength'],
                        'reason': signal_row['reason']
                    }
        return signals

    def _generate_signals(self, current_date, data_dict, strategy):
        """Compute buy signals for ``current_date`` from history up to that date.

        Unlike the cached path, the strategy is re-run on an expanding window
        for every date. This is slower but cannot see bars after
        ``current_date``. Stocks without a bar today or with fewer than 60
        bars of history are skipped.
        """
        signals = {}
        for ts_code, data in data_dict.items():
            if current_date not in data.index:
                continue
            hist_data = data[data.index <= current_date]
            if len(hist_data) < 60:
                continue
            signal_df = strategy.generate_signals(hist_data)
            if not signal_df.empty and current_date in signal_df.index:
                last_signal = signal_df.loc[current_date]
                if last_signal['signal'] == 1:
                    signals[ts_code] = {
                        'strength': last_signal['strength'],
                        'reason': last_signal['reason']
                    }
        return signals

    def _calculate_avg_holding_period(self, trades_df):
        """Average calendar days between each BUY and the next SELL, per stock.

        For each stock, trades are scanned in date order. The first BUY opens
        a round trip, and further BUYs before a SELL are ignored. The next
        SELL closes the round trip. A SELL with no open BUY is ignored. Only
        round trips lasting at least one day are averaged.

        Args:
            trades_df: Trade log with ``ts_code``, ``action`` and ``date``
                columns, or None.

        Returns:
            Mean holding period in days, or 0 when no round trip qualifies.
        """
        if trades_df is None or len(trades_df) < 2:
            return 0

        holding_periods = []
        for _, group in trades_df.groupby('ts_code'):
            # A stable sort keeps a same-day SELL-then-BUY (exit and re-entry
            # on one date) in the order it was logged.
            group = group.sort_values('date', kind='mergesort')
            open_date = None
            for action, trade_date in zip(group['action'], group['date']):
                if action == 'BUY':
                    if open_date is None:
                        open_date = trade_date
                elif action == 'SELL' and open_date is not None:
                    holding_days = (trade_date - open_date).days
                    if holding_days > 0:
                        holding_periods.append(holding_days)
                    open_date = None

        if not holding_periods:
            return 0
        return sum(holding_periods) / len(holding_periods)

    def generate_report(self):
        """Print a summary of the latest run's performance metrics to stdout.

        Prints a notice and returns early if no equity snapshot exists.
        """
        performance = self._calculate_performance()
        if not performance:
            print("没有回测数据可生成报告")
            return

        print("\n" + "=" * 60)
        print(f"{self.strategy_name} 回测结果摘要")
        print("=" * 60)
        print(f"总收益率: {performance['total_return']:.2f}%")
        print(f"年化收益率: {performance['annual_return']:.2f}%")
        print(f"最大回撤: {performance['max_drawdown']:.2f}%")
        print(f"夏普比率: {performance['sharpe_ratio']:.2f}")
        print(f"胜率: {performance['win_rate']:.2f}%")
        print(f"平均盈利: {performance['avg_win_pct']:.2f}%")
        print(f"平均亏损: {performance['avg_loss_pct']:.2f}%")
        print(f"盈亏比: {performance['profit_factor']:.2f}")
        print(f"平均持仓时间: {performance['avg_holding_period']:.2f}")
        print(f"总交易次数: {performance['total_trades']}")

    def _execute_trades(self, current_date, new_signals, data_dict):
        """Open new positions for today's buy signals within position limits.

        Candidates not already held are tried in descending ``strength``
        order until ``max_positions`` holdings exist or cash runs out. Each
        buy is sized as follows:
        - spend ``position_ratio`` of the cash still free;
        - pay the close plus slippage;
        - round down to whole 100-share lots;
        - pay commission on top.

        Args:
            current_date: Trade date (a label in each stock's index).
            new_signals: ``{ts_code: {'strength': ..., 'reason': ...}}``.
            data_dict: ``{ts_code: daily-bar DataFrame}`` with ``close``.
        """
        # self.capital is free cash: buys debit their full cost and sells
        # credit net proceeds. Open positions must therefore not be subtracted
        # again here; doing so would double-count them and starve new entries.
        if self.capital <= 0:
            return

        max_new_positions = self.max_positions - len(self.positions)
        if max_new_positions <= 0:
            return  # position limit reached

        # Drop names already held *before* ranking, so they cannot use up one
        # of the free position slots.
        candidates = sorted(
            ((code, info) for code, info in new_signals.items()
             if code not in self.positions),
            key=lambda item: item[1]['strength'],
            reverse=True,
        )

        opened = 0
        for ts_code, signal_info in candidates:
            if opened >= max_new_positions or self.capital <= 0:
                break
            stock_data = data_dict.get(ts_code)
            if stock_data is None or current_date not in stock_data.index:
                continue

            price = stock_data.loc[current_date, 'close']
            buy_price = price * (1 + self.slippage_rate)
            position_value = self.capital * self.position_ratio
            shares = int(position_value / (buy_price * 100)) * 100  # whole 100-share lots
            if shares <= 0:
                continue
            commission = buy_price * shares * self.commission_rate
            cost = buy_price * shares + commission
            if cost > self.capital:
                continue

            self.positions[ts_code] = {
                'entry_date': current_date,
                'entry_price': buy_price,
                'shares': shares,
                'cost_value': cost,
                'stop_loss_price': buy_price * (1 - self.stop_loss),
                'take_profit_price': buy_price * (1 + self.take_profit),
                'last_price': buy_price
            }
            self.capital -= cost
            self.trades.append({
                'date': current_date,
                'ts_code': ts_code,
                'action': 'BUY',
                'price': buy_price,
                'shares': shares,
                'commission': commission,
                'reason': signal_info['reason'],
                'balance': self.capital  # cash after the trade
            })
            opened += 1

    def _check_positions(self, current_date, data_dict):
        """Apply exit rules to every position that has a bar on ``current_date``.

        Exits are checked in this priority order:
        1. holding period >= ``max_holding_days``;
        2. close <= stop-loss price;
        3. close >= take-profit price.
        If none applies and the close exceeds
        entry * (1 + ``trailing_stop_trigger``), the stop is ratcheted up
        (never down) to close * (1 - stop_loss * ``trailing_stop_factor``).
        Sells run after the scan so the positions dict is not resized while
        it is being iterated.
        """
        positions_to_remove = []
        for ts_code, position in self.positions.items():
            if ts_code not in data_dict or current_date not in data_dict[ts_code].index:
                continue
            current_price = data_dict[ts_code].loc[current_date, 'close']
            position['last_price'] = current_price

            holding_days = (current_date - position['entry_date']).days
            if holding_days >= self.max_holding_days:
                positions_to_remove.append((ts_code, f'MAX_HOLDING_DAYS_{holding_days}', current_price))
            elif current_price <= position['stop_loss_price']:
                positions_to_remove.append((ts_code, 'STOP_LOSS', current_price))
            elif current_price >= position['take_profit_price']:
                positions_to_remove.append((ts_code, 'TAKE_PROFIT', current_price))
            elif current_price > position['entry_price'] * (1 + self.trailing_stop_trigger):
                new_stop = current_price * (1 - self.stop_loss * self.trailing_stop_factor)
                position['stop_loss_price'] = max(position['stop_loss_price'], new_stop)

        for ts_code, reason, price in positions_to_remove:
            self._sell_position(current_date, ts_code, price, reason)

    def _update_equity(self, current_date, data_dict):
        """Append a cash / positions / total-equity snapshot for ``current_date``.

        Positions without a bar on this date are valued at their
        ``last_price``.
        """
        positions_value = 0
        for ts_code, position in self.positions.items():
            if ts_code in data_dict and current_date in data_dict[ts_code].index:
                current_price = data_dict[ts_code].loc[current_date, 'close']
                positions_value += current_price * position['shares']
            else:
                positions_value += position['last_price'] * position['shares']

        total_equity = self.capital + positions_value
        self.equity_curve.append({
            'date': current_date,
            'cash': self.capital,
            'positions_value': positions_value,
            'equity': total_equity
        })

    def _sell_position(self, current_date, ts_code, price, reason):
        """Close ``ts_code`` at ``price`` less slippage and log the SELL.

        Commission and stamp tax are charged on the sale value. They are
        logged together under the ``commission`` key. ``pnl`` is net proceeds
        minus the position's total entry cost, which includes the buy
        commission. Does nothing if ``ts_code`` is not held.
        """
        if ts_code not in self.positions:
            return
        position = self.positions[ts_code]
        sell_price = price * (1 - self.slippage_rate)
        sell_value = sell_price * position['shares']

        commission = sell_value * self.commission_rate
        stamp_tax = sell_value * self.stamp_tax_rate
        total_fee = commission + stamp_tax
        net_proceeds = sell_value - total_fee

        pnl = net_proceeds - position['cost_value']
        pnl_pct = pnl / position['cost_value']

        self.capital += net_proceeds
        self.trades.append({
            'date': current_date,
            'ts_code': ts_code,
            'action': 'SELL',
            'price': sell_price,
            'shares': position['shares'],
            'commission': total_fee,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason,
            'balance': self.capital  # cash after the trade
        })
        del self.positions[ts_code]

    def _calculate_performance(self) -> Dict:
        """Compute summary statistics from the equity curve and trade log.

        Returns:
            ``{}`` if no equity snapshot was recorded. Otherwise a dict with:
            - percentage metrics: ``total_return``, ``annual_return``,
              ``max_drawdown``, ``win_rate`` (over closed trades),
              ``avg_win_pct``, ``avg_loss_pct``;
            - ``sharpe_ratio``: mean/std of per-snapshot returns times
              sqrt(252), zero risk-free rate;
            - ``profit_factor``;
            - ``avg_holding_period`` in calendar days;
            - ``total_trades``: BUY plus SELL records;
            - the ``equity_curve`` and ``trades`` DataFrames.
        """
        if not self.equity_curve:
            return {}

        equity_df = pd.DataFrame(self.equity_curve)
        equity_df.set_index('date', inplace=True)
        equity_df['returns'] = equity_df['equity'].pct_change()

        total_return = (equity_df['equity'].iloc[-1] / self.initial_capital - 1) * 100
        # Compound annualisation, assuming 252 snapshots per year. A
        # non-positive growth factor cannot take a fractional power (NaN), and
        # it means the capital is gone, so report -100%.
        growth = 1 + total_return / 100
        if growth > 0:
            annual_return = (growth ** (252 / len(equity_df)) - 1) * 100
        else:
            annual_return = -100.0

        equity_df['cummax'] = equity_df['equity'].cummax()
        equity_df['drawdown'] = (equity_df['equity'] - equity_df['cummax']) / equity_df['cummax'] * 100
        max_drawdown = equity_df['drawdown'].min()

        returns_std = equity_df['returns'].std()
        if returns_std > 0:
            sharpe_ratio = equity_df['returns'].mean() / returns_std * np.sqrt(252)
        else:
            sharpe_ratio = 0

        trades_df = pd.DataFrame(self.trades)
        if not trades_df.empty:
            if 'pnl' in trades_df.columns:
                # Only SELL records carry pnl; BUY rows are NaN there.
                closed_trades = trades_df[trades_df['pnl'].notna()]
                win_trades = closed_trades[closed_trades['pnl'] > 0]
                loss_trades = closed_trades[closed_trades['pnl'] <= 0]
            else:
                closed_trades = win_trades = loss_trades = pd.DataFrame()
            win_rate = len(win_trades) / len(closed_trades) * 100 if len(closed_trades) > 0 else 0
            avg_win = win_trades['pnl_pct'].mean() * 100 if not win_trades.empty else 0
            avg_loss = loss_trades['pnl_pct'].mean() * 100 if not loss_trades.empty else 0
            loss_sum = loss_trades['pnl'].sum() if not loss_trades.empty else 0
            profit_factor = abs(win_trades['pnl'].sum() / loss_sum) if loss_sum != 0 else 0
            avg_holding_period = self._calculate_avg_holding_period(trades_df)
        else:
            win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0

        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe_ratio,
            'win_rate': win_rate,
            'avg_win_pct': avg_win,
            'avg_loss_pct': avg_loss,
            'profit_factor': profit_factor,
            'avg_holding_period': avg_holding_period,
            'total_trades': len(trades_df),
            'equity_curve': equity_df,
            'trades': trades_df
        }

    def _run_backtest_time_segments(self, ts_code: str, data: pd.DataFrame,
                                    strategy, n_workers: int) -> Dict:
        """Split the timeline into segments, backtest them in parallel, merge.

        Each segment runs on a fresh engine that sees only that segment's
        bars. No positions or cash carry across boundaries, and each
        segment's first 60 bars serve only as warm-up.
        ``_merge_segment_results`` chains the segments' equity by compounding
        their returns.

        NOTE(review): ``strategy`` is shared by all worker threads, so it must
        not keep mutable per-call state. Confirm this for each strategy.
        """
        sorted_data = data.sort_index(kind='mergesort')
        total_days = len(sorted_data)
        # Never create more segments than bars, and keep at least one worker
        # so the thread pool can be created.
        n_workers = max(1, min(n_workers, total_days))
        print(f"\n单股多线程回测: {ts_code}, 时间分段数: {n_workers}")

        segment_size = total_days // n_workers
        segments = []
        for seg_id in range(n_workers):
            start_idx = seg_id * segment_size
            end_idx = (seg_id + 1) * segment_size if seg_id < n_workers - 1 else total_days
            # Positional slicing: .loc with a list of duplicated date labels
            # would return every matching row once per occurrence.
            segments.append((seg_id, sorted_data.iloc[start_idx:end_idx]))

        segment_results = []
        start_time = time.time()
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            future_to_segment = {
                executor.submit(self._run_segment_backtest,
                                ts_code, seg_data, strategy, seg_id): seg_id
                for seg_id, seg_data in segments
            }
            for future in as_completed(future_to_segment):
                seg_id = future_to_segment[future]
                try:
                    result = future.result()
                    segment_results.append((seg_id, result))
                    print(f"\r{seg_id+1}/{n_workers} 完成", end="")
                except Exception as e:
                    # Best effort: a failed segment is reported and left out.
                    print(f"\n{seg_id} 失败: {e}")

        segment_results.sort(key=lambda x: x[0])
        merged_result = self._merge_segment_results([r[1] for r in segment_results])

        elapsed = time.time() - start_time
        print(f"\n时间分段回测完成! 总耗时: {elapsed:.2f}s")
        return merged_result

    def _run_segment_backtest(self, ts_code: str, seg_data: pd.DataFrame,
                              strategy, seg_id: int) -> Dict:
        """Backtest one time segment on a fresh engine (no shared state).

        Returns:
            ``{'trades', 'equity_curve', 'final_capital'}`` for the segment.
        """
        engine = BacktestEngine(self.config)
        data_dict = {ts_code: seg_data}

        for current_date in sorted(seg_data.index):
            hist_data = seg_data[seg_data.index <= current_date]
            if len(hist_data) < 60:
                continue

            engine._check_positions(current_date, data_dict)
            engine._update_equity(current_date, data_dict)

            if len(engine.equity_curve) > 0 and engine.equity_curve[-1]['equity'] <= 0:
                break

            signals = engine._generate_signals(current_date, data_dict, strategy)
            engine._execute_trades(current_date, signals, data_dict)

        return {
            'trades': engine.trades,
            'equity_curve': engine.equity_curve,
            'final_capital': engine.capital
        }

    def _merge_segment_results(self, segment_results: List[Dict]) -> Dict:
        """Concatenate segment results in order and compute performance.

        Every segment starts from ``initial_capital``. To make the combined
        curve continuous, each segment's cash, position value and equity are
        scaled by the cumulative growth of all earlier segments. Returns
        therefore compound, and no artificial jump appears at a boundary.
        Trade records are kept as-is, so their absolute pnl is in each
        segment's own units.
        """
        all_trades = []
        all_equity = []
        scale = 1.0
        for result in segment_results:
            all_trades.extend(result['trades'])
            curve = result['equity_curve']
            for point in curve:
                all_equity.append({
                    'date': point['date'],
                    'cash': point['cash'] * scale,
                    'positions_value': point['positions_value'] * scale,
                    'equity': point['equity'] * scale,
                })
            if curve and self.initial_capital > 0:
                # A wiped-out segment leaves nothing to compound (clamp at 0).
                scale *= max(curve[-1]['equity'], 0.0) / self.initial_capital

        self.trades = all_trades
        self.equity_curve = all_equity
        return self._calculate_performance()