# Source listing metadata: 649 lines, 26 KiB, Python
import sys
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

# Silence every warning for the whole process (this also affects any code
# that imports this module).
warnings.filterwarnings(action='ignore')

# Project-local logging helper. It is imported after the filter is installed,
# so warnings emitted while importing it are suppressed as well.
from utils.logger import setup_logger

logger = setup_logger()  # module-level logger instance


class BacktestEngine:
    """Core backtest engine.

    Simulates a long-only portfolio bar by bar. Positions are opened from
    strategy buy signals (``signal == 1``). They are closed by a maximum
    holding period, a stop-loss, a take-profit or a trailing stop. Account
    equity (cash + marked-to-market positions) is recorded once per
    processed date. Annualised figures assume 252 bars per year, i.e. daily
    bars.
    """

    # Minimum number of bars (up to and including the current date) required
    # before the engine trades or generates signals for a stock.
    MIN_HISTORY_BARS = 60

    def __init__(self, config: Dict):
        """Create an engine from a configuration dict.

        Args:
            config: Engine settings. Recognised keys [defaults]:
                initial_capital [1000000], commission_rate [0.0003],
                stamp_tax_rate [0.001] (charged on sells only),
                slippage_rate [0.001], position_ratio [0.2] (fraction of the
                remaining cash put into each new position), stop_loss [0.08],
                take_profit [0.20], max_positions [5] and
                max_holding_days [20] (calendar days). The optional key
                'strategy_name' is used as the report title.
        """
        self.config = config
        self.initial_capital = config.get('initial_capital', 1000000)
        self.commission_rate = config.get('commission_rate', 0.0003)
        self.stamp_tax_rate = config.get('stamp_tax_rate', 0.001)
        self.slippage_rate = config.get('slippage_rate', 0.001)
        self.position_ratio = config.get('position_ratio', 0.2)
        self.stop_loss = config.get('stop_loss', 0.08)
        self.take_profit = config.get('take_profit', 0.20)
        self.max_positions = config.get('max_positions', 5)  # max simultaneous holdings
        self.max_holding_days = config.get('max_holding_days', 20)  # max calendar days held

        self.capital = self.initial_capital  # cash balance
        self.positions = {}  # open positions keyed by ts_code
        self.trades = []  # executed trade records
        self.equity_curve = []  # one equity snapshot per processed date

        # Signal cache filled by _precompute_signals: {ts_code: signal_df}
        self.signal_cache = {}

    def _initialize(self):
        """Reset all per-run state: cash, positions, trades, equity and signal cache."""
        self.capital = self.initial_capital
        self.positions = {}
        self.trades = []
        self.equity_curve = []
        self.signal_cache = {}

    def _get_all_dates(self, data_dict):
        """Return the set of every index label (date) present in any DataFrame of data_dict."""
        return set().union(*(data.index for data in data_dict.values()))

    def run_backtest_single_stock(self, ts_code: str, data: pd.DataFrame,
                                  strategy, use_multithread: bool = False,
                                  n_workers: int = 4,
                                  progress_callback=None) -> Dict:
        """Backtest a single stock, optionally split into parallel time segments.

        Args:
            ts_code: Stock code.
            data: Price data indexed by date; must contain a 'close' column.
            strategy: Strategy instance exposing ``generate_signals(df)``.
            use_multithread: If True (and n_workers > 1), use the time-segment
                backtest. Segments do NOT share cash or positions.
            n_workers: Number of threads / time segments.
            progress_callback: Optional ``callback(current, total)``. It is
                called every 10 dates and once at the end (sequential mode only).

        Returns:
            Performance dict from _calculate_performance ({} if no date was processed).
        """
        if use_multithread and n_workers > 1:
            return self._run_backtest_time_segments(ts_code, data, strategy, n_workers)

        data_dict = {ts_code: data}
        self._initialize()

        sorted_dates = sorted(data.index)
        total_dates = len(sorted_dates)

        for i, current_date in enumerate(sorted_dates):
            if progress_callback and i % 10 == 0:  # report every 10 dates
                progress_callback(i + 1, total_dates)

            # Warm-up: only the row count matters here, so count with a mask
            # instead of materialising the history slice.
            if (data.index <= current_date).sum() < self.MIN_HISTORY_BARS:
                continue

            # 1. Mark positions to market and apply exit rules.
            self._check_positions(current_date, data_dict)

            # 2. Record today's equity.
            self._update_equity(current_date, data_dict)

            # Stop as soon as the account is wiped out.
            if self.equity_curve:
                current_equity = self.equity_curve[-1]['equity']
                if current_equity <= 0:
                    print(f"\n警告: 账户资金已亏光! 日期: {current_date}, 剩余权益: {current_equity:.2f}")
                    break

            # 3. Generate signals from history up to today (no look-ahead).
            signals = self._generate_signals(current_date, data_dict, strategy)

            # 4. Open new positions.
            self._execute_trades(current_date, signals, data_dict)

        if progress_callback:
            progress_callback(total_dates, total_dates)

        return self._calculate_performance()

    def run_backtest(self, data_dict: Dict[str, pd.DataFrame],
                     strategy) -> Dict:
        """Run a portfolio backtest over several stocks.

        Signals are computed once per stock by _precompute_signals, using the
        stock's full history. This is free of look-ahead only if
        ``strategy.generate_signals`` derives each row from past data
        (TODO confirm for each strategy). Progress and a per-step timing
        breakdown are written to stdout.

        Args:
            data_dict: {ts_code: DataFrame indexed by date with a 'close' column}.
            strategy: Strategy instance exposing ``generate_signals(df)``.

        Returns:
            Performance dict from _calculate_performance.
        """
        start_time = time.perf_counter()
        self._initialize()

        print("预计算所有股票信号...")
        precompute_start = time.perf_counter()
        self._precompute_signals(data_dict, strategy)
        precompute_time = time.perf_counter() - precompute_start
        print(f"信号预计算完成,耗时: {precompute_time:.2f}秒\n")

        sorted_dates = sorted(self._get_all_dates(data_dict))
        total_dates = len(sorted_dates)

        step_times = {
            'check_positions': 0,
            'generate_signals': 0,
            'execute_trades': 0,
            'update_equity': 0,
        }

        for i, current_date in enumerate(sorted_dates):
            elapsed_time = time.perf_counter() - start_time
            remaining_time = (elapsed_time / (i + 1)) * (total_dates - (i + 1)) if i > 0 else 0
            sys.stdout.write(f"\r回测进度: {i+1}/{total_dates} ({(i+1)/total_dates*100:.1f}%) | 已耗时: {elapsed_time:.2f}s | 预计剩余: {remaining_time:.2f}s")
            sys.stdout.flush()

            # 1. Mark positions to market and apply exit rules.
            step_start = time.perf_counter()
            self._check_positions(current_date, data_dict)
            step_times['check_positions'] += time.perf_counter() - step_start

            # 2. Record equity before any new trades.
            step_start = time.perf_counter()
            self._update_equity(current_date, data_dict)
            step_times['update_equity'] += time.perf_counter() - step_start

            # Stop the backtest once total equity is <= 0.
            if self.equity_curve:
                current_equity = self.equity_curve[-1]['equity']
                if current_equity <= 0:
                    print(f"\n\n警告: 账户资金已亏光! 日期: {current_date}, 剩余权益: {current_equity:.2f}")
                    print("回测终止,不再执行后续交易。")
                    break

            # 3. Look up today's signals in the precomputed cache.
            step_start = time.perf_counter()
            new_signals = self._generate_signals_from_cache(current_date)
            step_times['generate_signals'] += time.perf_counter() - step_start

            # 4. Open new positions.
            step_start = time.perf_counter()
            self._execute_trades(current_date, new_signals, data_dict)
            step_times['execute_trades'] += time.perf_counter() - step_start

        step_start = time.perf_counter()
        results = self._calculate_performance()
        calculate_time = time.perf_counter() - step_start

        total_time = time.perf_counter() - start_time

        def share(seconds):
            # Percentage of the total run time; guards a zero-length run.
            return seconds / total_time * 100 if total_time > 0 else 0.0

        print("\n\n" + "=" * 60)
        print("回测完成!")
        print(f"总耗时: {total_time:.2f}秒")
        print("各步骤耗时:")
        print(f" 预计算信号: {precompute_time:.2f}秒 ({share(precompute_time):.1f}%)")
        for step, duration in step_times.items():
            print(f" {step}: {duration:.2f}秒 ({share(duration):.1f}%)")
        print(f" 计算绩效: {calculate_time:.2f}秒 ({share(calculate_time):.1f}%)")
        print("=" * 60)

        return results

    def _precompute_signals(self, data_dict: Dict[str, pd.DataFrame], strategy):
        """Compute each stock's signals once over its whole history and cache them.

        Stocks with fewer than MIN_HISTORY_BARS rows are skipped entirely.
        """
        for ts_code, data in data_dict.items():
            if len(data) >= self.MIN_HISTORY_BARS:
                self.signal_cache[ts_code] = strategy.generate_signals(data)

    def _generate_signals_from_cache(self, current_date):
        """Return today's buy signals from the cache.

        Returns:
            {ts_code: {'strength': ..., 'reason': ...}} for every cached stock
            whose row at current_date has ``signal == 1``.
        """
        signals = {}
        for ts_code, signal_df in self.signal_cache.items():
            if signal_df.empty or current_date not in signal_df.index:
                continue
            signal_row = signal_df.loc[current_date]
            if signal_row['signal'] == 1:
                signals[ts_code] = {
                    'strength': signal_row['strength'],
                    'reason': signal_row['reason'],
                }
        return signals

    def _generate_signals(self, current_date, data_dict, strategy):
        """Generate today's buy signals from history up to current_date (inclusive).

        The strategy is re-run on the growing history each day, so no future
        bars are visible to it.

        Returns:
            {ts_code: {'strength': ..., 'reason': ...}} for stocks with a bar
            today, at least MIN_HISTORY_BARS of history and ``signal == 1``.
        """
        signals = {}
        for ts_code, data in data_dict.items():
            if current_date not in data.index:
                continue

            hist_data = data[data.index <= current_date]
            if len(hist_data) < self.MIN_HISTORY_BARS:
                continue

            signal_df = strategy.generate_signals(hist_data)
            if not signal_df.empty and current_date in signal_df.index:
                last_signal = signal_df.loc[current_date]
                if last_signal['signal'] == 1:
                    signals[ts_code] = {
                        'strength': last_signal['strength'],
                        'reason': last_signal['reason'],
                    }
        return signals

    def _calculate_avg_holding_period(self, trades_df):
        """Average calendar days between each BUY and the next SELL of the same code.

        Within each ts_code, trades are ordered by date with a *stable* sort,
        so a SELL and a re-BUY recorded on the same date keep their recorded
        (chronological) order. An unstable sort could swap them and break the
        pairing. Extra BUYs while a position is open and SELLs with no open
        BUY are ignored. Round trips of 0 days are not counted.

        Returns:
            The mean holding period in days, or 0 when there is no pair.
        """
        if trades_df is None or len(trades_df) < 2:
            return 0

        holding_periods = []
        for _, group in trades_df.groupby('ts_code'):
            if len(group) < 2:
                continue
            group = group.sort_values('date', kind='stable')

            open_date = None
            for action, trade_date in zip(group['action'], group['date']):
                if action == 'BUY' and open_date is None:
                    open_date = trade_date
                elif action == 'SELL' and open_date is not None:
                    holding_days = (trade_date - open_date).days
                    if holding_days > 0:
                        holding_periods.append(holding_days)
                    open_date = None

        if holding_periods:
            return sum(holding_periods) / len(holding_periods)
        return 0

    def generate_report(self):
        """Print a summary of the current performance metrics to stdout."""
        performance = self._calculate_performance()

        if not performance:
            print("没有回测数据可生成报告")
            return

        # strategy_name is not assigned by this class (callers may set it).
        # Fall back to the config or a generic title instead of raising
        # AttributeError.
        strategy_name = getattr(self, 'strategy_name',
                                self.config.get('strategy_name', '策略'))

        print("\n" + "=" * 60)
        print(f"{strategy_name} 回测结果摘要")
        print("=" * 60)
        print(f"总收益率: {performance['total_return']:.2f}%")
        print(f"年化收益率: {performance['annual_return']:.2f}%")
        print(f"最大回撤: {performance['max_drawdown']:.2f}%")
        print(f"夏普比率: {performance['sharpe_ratio']:.2f}")
        print(f"胜率: {performance['win_rate']:.2f}%")
        print(f"平均盈利: {performance['avg_win_pct']:.2f}%")
        print(f"平均亏损: {performance['avg_loss_pct']:.2f}%")
        print(f"盈亏比: {performance['profit_factor']:.2f}")
        print(f"平均持仓时间: {performance['avg_holding_period']:.2f} 天")
        print(f"总交易次数: {performance['total_trades']}")

    def _execute_trades(self, current_date, new_signals, data_dict):
        """Open new positions from today's buy signals.

        Signals are tried in descending ``strength`` order. A signal is
        skipped, without using up one of the free ``max_positions`` slots,
        when the code is already held, has no bar on current_date, rounds
        down to zero shares, or costs more than the remaining cash. Each
        order is sized as ``position_ratio`` of the cash still available,
        rounded down to a multiple of 100 shares. It is filled at
        close * (1 + slippage_rate), plus commission.
        """
        open_slots = self.max_positions - len(self.positions)
        if open_slots <= 0:
            return  # position cap reached, no new buys

        # self.capital is the cash balance. It is debited below when a
        # position is opened and credited in _sell_position, so it is already
        # net of every open position. Subtracting the positions' cost_value
        # again would double-count the invested money.
        available_capital = self.capital
        if available_capital <= 0:
            return

        ranked_signals = sorted(new_signals.items(),
                                key=lambda item: item[1]['strength'],
                                reverse=True)

        opened = 0
        for ts_code, signal_info in ranked_signals:
            if opened >= open_slots or available_capital <= 0:
                break
            if ts_code in self.positions:
                continue  # already held; does not consume a slot

            stock_data = data_dict.get(ts_code)
            if stock_data is None or current_date not in stock_data.index:
                continue

            price = stock_data.loc[current_date, 'close']
            buy_price = price * (1 + self.slippage_rate)

            # Size the order and round down to a multiple of 100 shares.
            position_value = available_capital * self.position_ratio
            shares = int(position_value / (buy_price * 100)) * 100
            if shares <= 0:
                continue

            commission = buy_price * shares * self.commission_rate
            cost = buy_price * shares + commission
            if cost > available_capital:
                continue

            self.positions[ts_code] = {
                'entry_date': current_date,
                'entry_price': buy_price,
                'shares': shares,
                'cost_value': cost,
                'stop_loss_price': buy_price * (1 - self.stop_loss),
                'take_profit_price': buy_price * (1 + self.take_profit),
                'last_price': buy_price,
            }

            self.capital -= cost

            self.trades.append({
                'date': current_date,
                'ts_code': ts_code,
                'action': 'BUY',
                'price': buy_price,
                'shares': shares,
                'commission': commission,
                'reason': signal_info['reason'],
                'balance': self.capital,  # cash balance after the trade
            })

            available_capital -= cost
            opened += 1

    def _check_positions(self, current_date, data_dict):
        """Mark open positions to today's close and close those that hit an exit rule.

        Rules, in priority order:
          1. calendar days since entry >= max_holding_days;
          2. close <= stop-loss price;
          3. close >= take-profit price.
        Otherwise, once the close is more than 10% above entry, the stop is
        raised to close * (1 - 0.8 * stop_loss) and is never lowered.
        Positions without a bar on current_date are left untouched.
        """
        exits = []

        for ts_code, position in self.positions.items():
            if ts_code not in data_dict or current_date not in data_dict[ts_code].index:
                continue

            current_price = data_dict[ts_code].loc[current_date, 'close']
            position['last_price'] = current_price

            holding_days = (current_date - position['entry_date']).days
            if holding_days >= self.max_holding_days:
                exits.append((ts_code, f'MAX_HOLDING_DAYS_{holding_days}', current_price))
            elif current_price <= position['stop_loss_price']:
                exits.append((ts_code, 'STOP_LOSS', current_price))
            elif current_price >= position['take_profit_price']:
                exits.append((ts_code, 'TAKE_PROFIT', current_price))
            elif current_price > position['entry_price'] * 1.1:
                # Trailing stop: tighten once the position is > 10% in profit.
                new_stop = current_price * (1 - self.stop_loss * 0.8)
                position['stop_loss_price'] = max(position['stop_loss_price'], new_stop)

        # Sell after the loop; self.positions cannot shrink while iterating it.
        for ts_code, reason, price in exits:
            self._sell_position(current_date, ts_code, price, reason)

    def _update_equity(self, current_date, data_dict):
        """Append today's cash, position value and total equity to equity_curve.

        Positions without a bar on current_date are valued at their last_price.
        """
        positions_value = 0
        for ts_code, position in self.positions.items():
            if ts_code in data_dict and current_date in data_dict[ts_code].index:
                current_price = data_dict[ts_code].loc[current_date, 'close']
                positions_value += current_price * position['shares']
            else:
                positions_value += position['last_price'] * position['shares']

        total_equity = self.capital + positions_value

        self.equity_curve.append({
            'date': current_date,
            'cash': self.capital,
            'positions_value': positions_value,
            'equity': total_equity,
        })

    def _sell_position(self, current_date, ts_code, price, reason):
        """Close the whole position in ts_code at price * (1 - slippage_rate).

        Commission and stamp tax are deducted from the proceeds. The net
        proceeds go back to cash and a SELL trade record with pnl / pnl_pct
        (relative to the position's cost_value) is appended. Unknown codes
        are ignored.
        """
        if ts_code not in self.positions:
            return
        position = self.positions[ts_code]

        sell_price = price * (1 - self.slippage_rate)
        sell_value = sell_price * position['shares']

        commission = sell_value * self.commission_rate
        stamp_tax = sell_value * self.stamp_tax_rate
        total_fee = commission + stamp_tax
        net_proceeds = sell_value - total_fee

        pnl = net_proceeds - position['cost_value']
        pnl_pct = pnl / position['cost_value']

        self.capital += net_proceeds

        self.trades.append({
            'date': current_date,
            'ts_code': ts_code,
            'action': 'SELL',
            'price': sell_price,
            'shares': position['shares'],
            'commission': total_fee,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason,
            'balance': self.capital,  # cash balance after the trade
        })

        del self.positions[ts_code]

    def _calculate_performance(self) -> Dict:
        """Compute performance metrics from equity_curve and trades.

        Returns:
            {} if no equity was recorded. Otherwise a dict with total_return,
            annual_return, max_drawdown and win/avg percentages (all in %),
            sharpe_ratio, profit_factor, avg_holding_period (days),
            total_trades (BUY + SELL records), equity_curve (DataFrame
            indexed by date) and trades (DataFrame). Win-rate statistics are
            computed over closed trades only, i.e. records that carry a pnl.
        """
        if not self.equity_curve:
            return {}

        equity_df = pd.DataFrame(self.equity_curve)
        equity_df.set_index('date', inplace=True)
        equity_df['returns'] = equity_df['equity'].pct_change()

        total_return = (equity_df['equity'].iloc[-1] / self.initial_capital - 1) * 100
        # Compound annualisation, treating each equity row as one of 252 bars per year.
        annual_return = ((1 + total_return / 100) ** (252 / len(equity_df)) - 1) * 100

        equity_df['cummax'] = equity_df['equity'].cummax()
        equity_df['drawdown'] = (equity_df['equity'] - equity_df['cummax']) / equity_df['cummax'] * 100
        max_drawdown = equity_df['drawdown'].min()

        # Simplified Sharpe ratio (zero risk-free rate).
        returns_std = equity_df['returns'].std()
        if returns_std > 0:
            sharpe_ratio = equity_df['returns'].mean() / returns_std * np.sqrt(252)
        else:
            sharpe_ratio = 0

        trades_df = pd.DataFrame(self.trades)
        if not trades_df.empty:
            # BUY records have no pnl (NaN). Only closed trades count toward
            # win/loss statistics, including the win-rate denominator.
            if 'pnl' in trades_df.columns:
                closed_trades = trades_df[trades_df['pnl'].notna()]
                win_trades = closed_trades[closed_trades['pnl'] > 0]
                loss_trades = closed_trades[closed_trades['pnl'] <= 0]
            else:
                closed_trades = win_trades = loss_trades = pd.DataFrame()

            win_rate = len(win_trades) / len(closed_trades) * 100 if len(closed_trades) > 0 else 0
            avg_win = win_trades['pnl_pct'].mean() * 100 if not win_trades.empty else 0
            avg_loss = loss_trades['pnl_pct'].mean() * 100 if not loss_trades.empty else 0
            profit_factor = abs(win_trades['pnl'].sum() / loss_trades['pnl'].sum()) if not loss_trades.empty and loss_trades['pnl'].sum() != 0 else 0

            avg_holding_period = self._calculate_avg_holding_period(trades_df)
        else:
            win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0

        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe_ratio,
            'win_rate': win_rate,
            'avg_win_pct': avg_win,
            'avg_loss_pct': avg_loss,
            'profit_factor': profit_factor,
            'avg_holding_period': avg_holding_period,
            'total_trades': len(trades_df),
            'equity_curve': equity_df,
            'trades': trades_df,
        }

    def _run_backtest_time_segments(self, ts_code: str, data: pd.DataFrame,
                                    strategy, n_workers: int) -> Dict:
        """Time-segmented multithreaded backtest of a single stock.

        The sorted dates are cut into up to ``n_workers`` contiguous
        segments. The segments are backtested concurrently and merged in
        chronological order. Each segment runs on a fresh engine, so cash and
        positions are NOT carried across segment boundaries. Each segment
        does receive the full price history before it, so only the first
        segment spends its opening bars on warm-up.
        """
        self._initialize()

        sorted_dates = sorted(data.index)
        total_days = len(sorted_dates)
        # Never create more segments than dates (avoids empty segments), and
        # keep at least one worker for ThreadPoolExecutor.
        n_workers = max(1, min(n_workers, total_days))
        print(f"\n单股多线程回测: {ts_code}, 时间分段数: {n_workers}")
        segment_size = total_days // n_workers

        # Each segment is (segment id, dates to trade, history up to the segment end).
        segments = []
        for i in range(n_workers):
            start_idx = i * segment_size
            end_idx = (i + 1) * segment_size if i < n_workers - 1 else total_days
            segment_dates = sorted_dates[start_idx:end_idx]
            if not segment_dates:
                continue
            segment_history = data[data.index <= segment_dates[-1]]
            segments.append((i, segment_dates, segment_history))

        segment_results = []
        start_time = time.time()

        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            future_to_segment = {
                executor.submit(self._run_segment_backtest,
                                ts_code, seg_data, strategy, seg_id, seg_dates): seg_id
                for seg_id, seg_dates, seg_data in segments
            }

            for future in as_completed(future_to_segment):
                seg_id = future_to_segment[future]
                try:
                    segment_results.append((seg_id, future.result()))
                    print(f"\r段 {seg_id+1}/{n_workers} 完成", end="")
                except Exception as e:
                    # A failed segment is reported and left out of the merge.
                    print(f"\n段 {seg_id} 失败: {e}")

        # Futures complete in arbitrary order; restore chronological order.
        segment_results.sort(key=lambda item: item[0])
        merged_result = self._merge_segment_results([result for _, result in segment_results])

        elapsed = time.time() - start_time
        print(f"\n时间分段回测完成! 总耗时: {elapsed:.2f}s")

        return merged_result

    def _run_segment_backtest(self, ts_code: str, seg_data: pd.DataFrame,
                              strategy, seg_id: int, trade_dates=None) -> Dict:
        """Backtest one time segment on an independent engine instance.

        Args:
            ts_code: Stock code.
            seg_data: Price data for the segment. It may also contain earlier
                rows that serve only as warm-up / signal history.
            strategy: Strategy instance exposing ``generate_signals(df)``.
            seg_id: Segment index (bookkeeping for the caller; unused here).
            trade_dates: Dates to simulate, in order. Defaults to every date
                in seg_data.

        Returns:
            {'trades': [...], 'equity_curve': [...], 'final_capital': cash}
        """
        # A separate engine keeps threads from sharing mutable state.
        engine = BacktestEngine(self.config)
        data_dict = {ts_code: seg_data}
        dates = sorted(seg_data.index) if trade_dates is None else trade_dates

        for current_date in dates:
            if (seg_data.index <= current_date).sum() < self.MIN_HISTORY_BARS:
                continue

            engine._check_positions(current_date, data_dict)
            engine._update_equity(current_date, data_dict)

            # Stop this segment once its account is wiped out.
            if engine.equity_curve and engine.equity_curve[-1]['equity'] <= 0:
                break

            signals = engine._generate_signals(current_date, data_dict, strategy)
            engine._execute_trades(current_date, signals, data_dict)

        return {
            'trades': engine.trades,
            'equity_curve': engine.equity_curve,
            'final_capital': engine.capital,
        }

    def _merge_segment_results(self, segment_results: List[Dict]) -> Dict:
        """Concatenate segment trades and equity curves (in the given order) and score them.

        NOTE: every segment starts from a fresh engine with initial_capital.
        The concatenated equity curve therefore resets at each segment
        boundary: total_return reflects only the last segment, and drawdown
        and Sharpe include the boundary jumps.
        """
        all_trades = []
        all_equity = []
        for result in segment_results:
            all_trades.extend(result['trades'])
            all_equity.extend(result['equity_curve'])

        self.trades = all_trades
        self.equity_curve = all_equity

        return self._calculate_performance()