import sys
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

warnings.filterwarnings('ignore')

from utils.logger import setup_logger

logger = setup_logger()


class BacktestEngine:
    """Core backtest engine.

    Walks the date index of the supplied price data in ascending order. On
    each date it checks open positions for exits, marks the portfolio to
    market, and stops if equity has fallen to zero or below. It then opens
    new positions from the day's buy signals. Buys are sized in lots of 100
    shares and filled at ``close * (1 + slippage_rate)`` plus commission.
    Sells are filled at ``close * (1 - slippage_rate)`` and pay commission
    plus stamp tax.

    Config keys read (defaults in parentheses): ``initial_capital``
    (1000000), ``commission_rate`` (0.0003), ``stamp_tax_rate`` (0.001),
    ``slippage_rate`` (0.001), ``position_ratio`` (0.2), ``stop_loss``
    (0.08), ``take_profit`` (0.20), ``max_positions`` (5),
    ``max_holding_days`` (20) and ``min_history_bars`` (60).
    ``strategy_name`` ('策略') is used only by :meth:`generate_report`.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.initial_capital = config.get('initial_capital', 1000000)
        self.commission_rate = config.get('commission_rate', 0.0003)
        self.stamp_tax_rate = config.get('stamp_tax_rate', 0.001)
        self.slippage_rate = config.get('slippage_rate', 0.001)
        self.position_ratio = config.get('position_ratio', 0.2)
        self.stop_loss = config.get('stop_loss', 0.08)
        self.take_profit = config.get('take_profit', 0.20)
        self.max_positions = config.get('max_positions', 5)  # max number of stocks held at once
        self.max_holding_days = config.get('max_holding_days', 20)  # max calendar days per position
        # Bars of history a stock needs before it is traded (previously hard-coded 60).
        self.min_history_bars = config.get('min_history_bars', 60)
        # Previously read by generate_report() but never set (AttributeError).
        self.strategy_name = config.get('strategy_name', '策略')

        # Mutable run state: cash, open positions, trade log, equity curve,
        # and the per-stock signal cache used by run_backtest().
        self._initialize()

    def _initialize(self):
        """Reset all per-run state to a fresh account holding only cash."""
        self.capital = self.initial_capital  # cash; debited/credited on every fill
        self.positions = {}      # {ts_code: position dict}
        self.trades = []         # list of trade record dicts
        self.equity_curve = []   # list of {'date', 'cash', 'positions_value', 'equity'}
        self.signal_cache = {}   # {ts_code: signal_df}

    def _get_all_dates(self, data_dict):
        """Return the *set* of unique index dates across every DataFrame in ``data_dict``."""
        all_dates = set()
        for data in data_dict.values():
            all_dates.update(data.index)
        return all_dates

    def run_backtest_single_stock(self, ts_code: str, data: pd.DataFrame, strategy,
                                  use_multithread: bool = False, n_workers: int = 4,
                                  progress_callback=None) -> Dict:
        """Backtest a single stock, optionally split into parallel time segments.

        Signals are regenerated every day from the history up to that day
        (``strategy.generate_signals(hist_data)``). Each day's signal
        therefore sees only past data, at the cost of re-running the
        strategy on a growing window.

        Args:
            ts_code: stock code.
            data: price data for the stock; must contain a ``close`` column.
            strategy: object exposing ``generate_signals(df)``.
            use_multithread: if True and ``n_workers > 1``, delegate to
                :meth:`_run_backtest_time_segments` (``progress_callback`` is
                then not used).
            n_workers: number of threads / time segments.
            progress_callback: optional ``callback(current, total)``. It is
                called every 10 dates and once at the end.

        Returns:
            The performance dict from :meth:`_calculate_performance`.
        """
        if use_multithread and n_workers > 1:
            return self._run_backtest_time_segments(ts_code, data, strategy, n_workers)

        data_dict = {ts_code: data}
        self._initialize()

        sorted_dates = sorted(data.index)
        total_dates = len(sorted_dates)

        for i, current_date in enumerate(sorted_dates):
            if progress_callback and i % 10 == 0:  # throttle callback to every 10 dates
                progress_callback(i + 1, total_dates)

            # Skip the warm-up period until enough history exists.
            hist_data = data[data.index <= current_date]
            if len(hist_data) < self.min_history_bars:
                continue

            # 1. Exits (max holding days / stop loss / take profit / trailing stop).
            self._check_positions(current_date, data_dict)

            # 2. Mark to market.
            self._update_equity(current_date, data_dict)

            # Stop if the account has been wiped out.
            if self.equity_curve:
                current_equity = self.equity_curve[-1]['equity']
                if current_equity <= 0:
                    print(f"\n警告: 账户资金已亏光! 日期: {current_date}, 剩余权益: {current_equity:.2f}")
                    break

            # 3. New signals from history up to today.
            signals = self._generate_signals(current_date, data_dict, strategy)

            # 4. Entries.
            self._execute_trades(current_date, signals, data_dict)

        if progress_callback:
            progress_callback(total_dates, total_dates)

        return self._calculate_performance()

    def run_backtest(self, data_dict: Dict[str, pd.DataFrame], strategy) -> Dict:
        """Backtest a universe of stocks over the union of their dates.

        Signals are precomputed once per stock over its FULL history (see
        :meth:`_precompute_signals`) and then read from the cache per date.
        NOTE(review): this is free of look-ahead only if
        ``strategy.generate_signals`` computes each row from that row's past
        data alone (e.g. trailing rolling windows). Confirm this for each
        strategy.

        Progress and per-step timings are written to stdout.

        Args:
            data_dict: ``{ts_code: DataFrame}``; each frame needs a ``close`` column.
            strategy: object exposing ``generate_signals(df)``.

        Returns:
            The performance dict from :meth:`_calculate_performance`.
        """
        start_time = time.time()
        self._initialize()

        print("预计算所有股票信号...")
        precompute_start = time.time()
        self._precompute_signals(data_dict, strategy)
        precompute_time = time.time() - precompute_start
        print(f"信号预计算完成,耗时: {precompute_time:.2f}秒\n")

        sorted_dates = sorted(self._get_all_dates(data_dict))
        total_dates = len(sorted_dates)
        step_times = {
            'check_positions': 0,
            'generate_signals': 0,
            'execute_trades': 0,
            'update_equity': 0,
        }

        for i, current_date in enumerate(sorted_dates):
            elapsed_time = time.time() - start_time
            remaining_time = (elapsed_time / (i + 1)) * (total_dates - (i + 1)) if i > 0 else 0
            sys.stdout.write(f"\r回测进度: {i+1}/{total_dates} ({(i+1)/total_dates*100:.1f}%) | 已耗时: {elapsed_time:.2f}s | 预计剩余: {remaining_time:.2f}s")
            sys.stdout.flush()

            # 1. Exits.
            step_start = time.time()
            self._check_positions(current_date, data_dict)
            step_times['check_positions'] += time.time() - step_start

            # 2. Mark to market before any new entries.
            step_start = time.time()
            self._update_equity(current_date, data_dict)
            step_times['update_equity'] += time.time() - step_start

            # Stop the whole run once total equity is <= 0.
            if self.equity_curve:
                current_equity = self.equity_curve[-1]['equity']
                if current_equity <= 0:
                    print(f"\n\n警告: 账户资金已亏光! 日期: {current_date}, 剩余权益: {current_equity:.2f}")
                    print("回测终止,不再执行后续交易。")
                    break

            # 3. Today's signals from the precomputed cache.
            step_start = time.time()
            new_signals = self._generate_signals_from_cache(current_date)
            step_times['generate_signals'] += time.time() - step_start

            # 4. Entries.
            step_start = time.time()
            self._execute_trades(current_date, new_signals, data_dict)
            step_times['execute_trades'] += time.time() - step_start

        step_start = time.time()
        results = self._calculate_performance()
        calculate_time = time.time() - step_start

        total_time = time.time() - start_time
        # Guard the percentage breakdown against a zero-length run.
        denom = total_time if total_time > 0 else 1e-9
        print(f"\n\n" + "=" * 60)
        print("回测完成!")
        print(f"总耗时: {total_time:.2f}秒")
        print("各步骤耗时:")
        print(f" 预计算信号: {precompute_time:.2f}秒 ({precompute_time/denom*100:.1f}%)")
        for step, duration in step_times.items():
            print(f" {step}: {duration:.2f}秒 ({duration/denom*100:.1f}%)")
        print(f" 计算绩效: {calculate_time:.2f}秒 ({calculate_time/denom*100:.1f}%)")
        print("=" * 60)

        return results

    def _precompute_signals(self, data_dict: Dict[str, pd.DataFrame], strategy):
        """Run the strategy once per stock over its whole history and cache the result.

        Stocks with fewer than ``min_history_bars`` rows are skipped.
        """
        for ts_code, data in data_dict.items():
            if len(data) >= self.min_history_bars:
                self.signal_cache[ts_code] = strategy.generate_signals(data)

    def _generate_signals_from_cache(self, current_date):
        """Return ``{ts_code: {'strength', 'reason'}}`` for cached buy signals on ``current_date``.

        Only rows with ``signal == 1`` are included.
        """
        signals = {}
        for ts_code, signal_df in self.signal_cache.items():
            if not signal_df.empty and current_date in signal_df.index:
                signal_row = signal_df.loc[current_date]
                if signal_row['signal'] == 1:
                    signals[ts_code] = {
                        'strength': signal_row['strength'],
                        'reason': signal_row['reason'],
                    }
        return signals

    def _generate_signals(self, current_date, data_dict, strategy):
        """Return today's buy signals, computed from history up to ``current_date``.

        The strategy is re-run on each stock's data truncated at
        ``current_date``. This avoids look-ahead but costs O(n) work per
        date. Stocks without a bar on ``current_date``, or with fewer than
        ``min_history_bars`` bars of history, are skipped.

        Returns:
            ``{ts_code: {'strength', 'reason'}}`` for rows with ``signal == 1``.
        """
        signals = {}
        for ts_code, data in data_dict.items():
            if current_date not in data.index:
                continue

            hist_data = data[data.index <= current_date]
            if len(hist_data) < self.min_history_bars:
                continue

            signal_df = strategy.generate_signals(hist_data)
            if not signal_df.empty and current_date in signal_df.index:
                last_signal = signal_df.loc[current_date]
                if last_signal['signal'] == 1:
                    signals[ts_code] = {
                        'strength': last_signal['strength'],
                        'reason': last_signal['reason'],
                    }
        return signals

    def _calculate_avg_holding_period(self, trades_df):
        """Return the mean holding period in calendar days over matched BUY→SELL pairs.

        For each stock, trades are sorted by date and scanned once. The
        first BUY opens a holding; later BUYs are ignored while it is open.
        The next SELL closes it, and SELLs with no open BUY are ignored.
        Pairs with a holding period of 0 days are not counted. Returns 0
        when nothing can be matched.
        """
        if trades_df is None or len(trades_df) < 2:
            return 0

        holding_periods = []
        for _, group in trades_df.groupby('ts_code'):
            if len(group) < 2:
                continue
            group = group.sort_values('date')

            open_date = None
            for action, trade_date in zip(group['action'], group['date']):
                if action == 'BUY' and open_date is None:
                    open_date = trade_date
                elif action == 'SELL' and open_date is not None:
                    holding_days = (trade_date - open_date).days
                    if holding_days > 0:
                        holding_periods.append(holding_days)
                    open_date = None

        if holding_periods:
            return sum(holding_periods) / len(holding_periods)
        return 0

    def generate_report(self):
        """Print a summary of the current run's performance to stdout."""
        performance = self._calculate_performance()

        if not performance:
            print("没有回测数据可生成报告")
            return

        print("\n" + "=" * 60)
        print(f"{self.strategy_name} 回测结果摘要")
        print("=" * 60)
        print(f"总收益率: {performance['total_return']:.2f}%")
        print(f"年化收益率: {performance['annual_return']:.2f}%")
        print(f"最大回撤: {performance['max_drawdown']:.2f}%")
        print(f"夏普比率: {performance['sharpe_ratio']:.2f}")
        print(f"胜率: {performance['win_rate']:.2f}%")
        print(f"平均盈利: {performance['avg_win_pct']:.2f}%")
        print(f"平均亏损: {performance['avg_loss_pct']:.2f}%")
        print(f"盈亏比: {performance['profit_factor']:.2f}")
        print(f"平均持仓时间: {performance['avg_holding_period']:.2f} 天")
        print(f"总交易次数: {performance['total_trades']}")

    def _execute_trades(self, current_date, new_signals, data_dict):
        """Open new positions from ``new_signals``, strongest first.

        At most ``max_positions - len(self.positions)`` positions are opened.
        Signals for stocks already held, with no bar today, or that cannot
        be afforded (0 lots or cost above available cash) are skipped and
        do not use up a slot. Each buy is sized at ``position_ratio`` of the
        cash remaining at that moment, rounded down to whole 100-share lots.
        """
        sorted_signals = sorted(new_signals.items(), key=lambda x: x[1]['strength'], reverse=True)

        # self.capital is cash that has already been debited by each open
        # position's full cost, so it is exactly the cash available to spend.
        # (Subtracting cost_value again double-counted invested capital.)
        available_capital = self.capital
        if available_capital <= 0:
            return

        max_new_positions = self.max_positions - len(self.positions)
        if max_new_positions <= 0:
            return  # position cap already reached

        opened = 0
        for ts_code, signal_info in sorted_signals:
            if opened >= max_new_positions or available_capital <= 0:
                break
            if ts_code in self.positions:
                continue  # already held

            data = data_dict[ts_code]
            if current_date not in data.index:
                continue

            price = data.loc[current_date, 'close']
            buy_price = price * (1 + self.slippage_rate)

            # Size in whole lots of 100 shares.
            position_value = available_capital * self.position_ratio
            shares = int(position_value / (buy_price * 100)) * 100
            if shares <= 0:
                continue

            commission = buy_price * shares * self.commission_rate
            cost = buy_price * shares + commission
            if cost > available_capital:
                continue

            self.positions[ts_code] = {
                'entry_date': current_date,
                'entry_price': buy_price,
                'shares': shares,
                'cost_value': cost,
                'stop_loss_price': buy_price * (1 - self.stop_loss),
                'take_profit_price': buy_price * (1 + self.take_profit),
                'last_price': buy_price,
            }

            self.capital -= cost
            self.trades.append({
                'date': current_date,
                'ts_code': ts_code,
                'action': 'BUY',
                'price': buy_price,
                'shares': shares,
                'commission': commission,
                'reason': signal_info['reason'],
                'balance': self.capital,  # cash after this trade
            })

            available_capital -= cost
            opened += 1

    def _check_positions(self, current_date, data_dict):
        """Apply exit rules to every position that has a bar on ``current_date``.

        Rules are checked in order, and the first match sells at today's close:
        1. held >= ``max_holding_days`` calendar days,
        2. close <= stop-loss price,
        3. close >= take-profit price.
        Otherwise, if the close is more than 10% above entry, the stop is
        ratcheted up to ``close * (1 - 0.8 * stop_loss)`` (never lowered).
        Each checked position's ``last_price`` is updated to today's close.
        """
        positions_to_remove = []

        for ts_code, position in self.positions.items():
            if ts_code not in data_dict or current_date not in data_dict[ts_code].index:
                continue

            current_price = data_dict[ts_code].loc[current_date, 'close']
            position['last_price'] = current_price

            holding_days = (current_date - position['entry_date']).days
            if holding_days >= self.max_holding_days:
                positions_to_remove.append((ts_code, f'MAX_HOLDING_DAYS_{holding_days}', current_price))
                continue

            if current_price <= position['stop_loss_price']:
                positions_to_remove.append((ts_code, 'STOP_LOSS', current_price))
            elif current_price >= position['take_profit_price']:
                positions_to_remove.append((ts_code, 'TAKE_PROFIT', current_price))
            elif current_price > position['entry_price'] * 1.1:
                # Trailing stop once the position is more than 10% in profit.
                new_stop = current_price * (1 - self.stop_loss * 0.8)
                position['stop_loss_price'] = max(position['stop_loss_price'], new_stop)

        # Sell after iterating so self.positions is not mutated mid-loop.
        for ts_code, reason, price in positions_to_remove:
            self._sell_position(current_date, ts_code, price, reason)

    def _update_equity(self, current_date, data_dict):
        """Append today's cash, position market value and total equity to the curve.

        Positions without a bar today are valued at their ``last_price``.
        """
        positions_value = 0
        for ts_code, position in self.positions.items():
            if ts_code in data_dict and current_date in data_dict[ts_code].index:
                current_price = data_dict[ts_code].loc[current_date, 'close']
                positions_value += current_price * position['shares']
            else:
                positions_value += position['last_price'] * position['shares']

        total_equity = self.capital + positions_value

        self.equity_curve.append({
            'date': current_date,
            'cash': self.capital,
            'positions_value': positions_value,
            'equity': total_equity,
        })

    def _sell_position(self, current_date, ts_code, price, reason):
        """Close the whole position in ``ts_code`` at ``price`` less slippage.

        Pays commission plus stamp tax and credits the net proceeds to cash.
        A SELL record with ``pnl`` and ``pnl_pct`` (relative to the buy cost
        including commission) is logged. Does nothing if ``ts_code`` is not held.
        """
        if ts_code not in self.positions:
            return

        position = self.positions[ts_code]
        sell_price = price * (1 - self.slippage_rate)
        sell_value = sell_price * position['shares']

        commission = sell_value * self.commission_rate
        stamp_tax = sell_value * self.stamp_tax_rate
        total_fee = commission + stamp_tax
        net_proceeds = sell_value - total_fee

        pnl = net_proceeds - position['cost_value']
        pnl_pct = pnl / position['cost_value']

        self.capital += net_proceeds

        self.trades.append({
            'date': current_date,
            'ts_code': ts_code,
            'action': 'SELL',
            'price': sell_price,
            'shares': position['shares'],
            'commission': total_fee,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason,
            'balance': self.capital,  # cash after this trade
        })

        del self.positions[ts_code]

    def _calculate_performance(self) -> Dict:
        """Compute summary statistics from the equity curve and trade log.

        Returns ``{}`` if the equity curve is empty. Otherwise the return,
        drawdown, win-rate and avg-win/loss values are percentages.
        ``sharpe_ratio`` is the mean/std of per-point returns scaled by
        sqrt(252), with no risk-free rate. ``win_rate`` is computed over
        closed (SELL) trades only. ``total_trades`` counts all trade
        records, BUY and SELL. The raw ``equity_curve`` DataFrame and
        ``trades`` DataFrame are included.
        """
        if not self.equity_curve:
            return {}

        equity_df = pd.DataFrame(self.equity_curve)
        equity_df.set_index('date', inplace=True)
        equity_df['returns'] = equity_df['equity'].pct_change()

        total_return = (equity_df['equity'].iloc[-1] / self.initial_capital - 1) * 100
        # Compound annualization assuming 252 curve points per year.
        annual_return = ((1 + total_return / 100) ** (252 / len(equity_df)) - 1) * 100

        equity_df['cummax'] = equity_df['equity'].cummax()
        equity_df['drawdown'] = (equity_df['equity'] - equity_df['cummax']) / equity_df['cummax'] * 100
        max_drawdown = equity_df['drawdown'].min()

        returns_std = equity_df['returns'].std()
        if returns_std > 0:
            sharpe_ratio = equity_df['returns'].mean() / returns_std * np.sqrt(252)
        else:
            sharpe_ratio = 0

        trades_df = pd.DataFrame(self.trades)
        if not trades_df.empty:
            if 'pnl' in trades_df.columns:
                # Only SELL records carry a pnl; BUY rows have NaN there.
                closed_trades = trades_df[trades_df['pnl'].notna()]
                win_trades = closed_trades[closed_trades['pnl'] > 0]
                loss_trades = closed_trades[closed_trades['pnl'] <= 0]
            else:
                closed_trades = win_trades = loss_trades = pd.DataFrame()

            # Win rate over closed trades (previously BUY rows inflated the denominator).
            win_rate = len(win_trades) / len(closed_trades) * 100 if len(closed_trades) > 0 else 0
            avg_win = win_trades['pnl_pct'].mean() * 100 if not win_trades.empty else 0
            avg_loss = loss_trades['pnl_pct'].mean() * 100 if not loss_trades.empty else 0
            if not loss_trades.empty and loss_trades['pnl'].sum() != 0:
                profit_factor = abs(win_trades['pnl'].sum() / loss_trades['pnl'].sum())
            else:
                profit_factor = 0

            avg_holding_period = self._calculate_avg_holding_period(trades_df)
        else:
            win_rate = avg_win = avg_loss = profit_factor = avg_holding_period = 0

        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe_ratio,
            'win_rate': win_rate,
            'avg_win_pct': avg_win,
            'avg_loss_pct': avg_loss,
            'profit_factor': profit_factor,
            'avg_holding_period': avg_holding_period,
            'total_trades': len(trades_df),
            'equity_curve': equity_df,
            'trades': trades_df,
        }

    def _run_backtest_time_segments(self, ts_code: str, data: pd.DataFrame, strategy,
                                    n_workers: int) -> Dict:
        """Backtest one stock by splitting its timeline into parallel segments.

        The sorted dates are cut into ``n_workers`` contiguous segments, with
        the last one absorbing the remainder. Each segment is backtested by
        an independent engine in a thread pool, and the results are merged
        in chronological order by :meth:`_merge_segment_results`. Segments
        that raise are reported and dropped.

        Caveats: segments do NOT share cash or positions. Each segment needs
        ``min_history_bars`` bars of its own before it trades, and positions
        still open at a segment's end are never sold. The result is
        therefore an approximation of a sequential run. NOTE(review): the
        same ``strategy`` object is used from several threads; confirm
        ``generate_signals`` is thread-safe. Speed-up may also be limited by
        the GIL for pure-Python work.
        """
        print(f"\n单股多线程回测: {ts_code}, 时间分段数: {n_workers}")

        sorted_dates = sorted(data.index)
        total_days = len(sorted_dates)
        segment_size = total_days // n_workers

        segments = []
        for i in range(n_workers):
            start_idx = i * segment_size
            end_idx = (i + 1) * segment_size if i < n_workers - 1 else total_days
            segment_dates = sorted_dates[start_idx:end_idx]
            segments.append((i, data.loc[segment_dates]))

        segment_results = []
        start_time = time.time()

        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            future_to_segment = {
                executor.submit(self._run_segment_backtest, ts_code, seg_data, strategy, seg_id): seg_id
                for seg_id, seg_data in segments
            }

            for future in as_completed(future_to_segment):
                seg_id = future_to_segment[future]
                try:
                    result = future.result()
                    segment_results.append((seg_id, result))
                    print(f"\r段 {seg_id+1}/{n_workers} 完成", end="")
                except Exception as e:
                    print(f"\n段 {seg_id} 失败: {e}")

        # Futures complete in arbitrary order; restore chronological order.
        segment_results.sort(key=lambda x: x[0])

        merged_result = self._merge_segment_results([r[1] for r in segment_results])

        elapsed = time.time() - start_time
        print(f"\n时间分段回测完成! 总耗时: {elapsed:.2f}s")

        return merged_result

    def _run_segment_backtest(self, ts_code: str, seg_data: pd.DataFrame, strategy, seg_id: int) -> Dict:
        """Backtest one time segment on a fresh engine built from ``self.config``.

        ``seg_id`` is not used here; the caller uses it for ordering.

        Returns:
            ``{'trades', 'equity_curve', 'final_capital'}`` from the segment engine.
        """
        # Independent engine so concurrent segments never share state.
        engine = BacktestEngine(self.config)
        data_dict = {ts_code: seg_data}

        for current_date in sorted(seg_data.index):
            hist_data = seg_data[seg_data.index <= current_date]
            if len(hist_data) < engine.min_history_bars:
                continue

            engine._check_positions(current_date, data_dict)
            engine._update_equity(current_date, data_dict)

            # Stop the segment once its equity is wiped out.
            if engine.equity_curve and engine.equity_curve[-1]['equity'] <= 0:
                break

            signals = engine._generate_signals(current_date, data_dict, strategy)
            engine._execute_trades(current_date, signals, data_dict)

        return {
            'trades': engine.trades,
            'equity_curve': engine.equity_curve,
            'final_capital': engine.capital,
        }

    def _merge_segment_results(self, segment_results: List[Dict]) -> Dict:
        """Merge chronologically ordered segment results and compute performance.

        Trade records are concatenated unchanged. Every segment starts from
        ``initial_capital``, so each segment's equity curve (cash, position
        value and equity) is rescaled by the compounded growth of all
        earlier segments. This yields one continuous curve: total return
        compounds across segments, and segment boundaries no longer show a
        spurious jump back to the initial capital.
        """
        all_trades = []
        all_equity = []
        growth = 1.0  # compounded equity multiple of the segments merged so far

        for result in segment_results:
            all_trades.extend(result['trades'])

            curve = result['equity_curve']
            for point in curve:
                all_equity.append({
                    'date': point['date'],
                    'cash': point['cash'] * growth,
                    'positions_value': point['positions_value'] * growth,
                    'equity': point['equity'] * growth,
                })

            if curve and self.initial_capital:
                growth *= curve[-1]['equity'] / self.initial_capital

        self.trades = all_trades
        self.equity_curve = all_equity
        return self._calculate_performance()