"""策略基类与回测主循环实现。 BaseStrategy 定义: - 账户与持仓管理; - 买卖接口; - 回测主循环 run_backtest; - 集成止盈止损和仓位管理钩子; 子类只需实现 on_bar,在每个交易日根据行情数据决策买卖。 """ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Dict, List, Optional import numpy as np import pandas as pd from utils.logger import setup_logger logger = setup_logger(__name__) @dataclass class Position: """单只股票的持仓信息。""" ts_code: str quantity: int cost: float # 成本价 days_held: int = 0 class BaseStrategy(ABC): """回测策略抽象基类。 - 管理资金与持仓; - 提供买卖接口; - 实现回测主循环; - 支持止盈止损和仓位管理; - 子类只需实现 on_bar 方法。 """ def __init__( self, initial_cash: float, max_positions: int = 2, stop_loss: Optional[object] = None, take_profit: Optional[object] = None, position_sizer: Optional[object] = None, date_index_dict: Optional[Dict[str, Dict[str, int]]] = None, buy_signal_index: Optional[Dict[str, List[str]]] = None, ): """初始化策略。 参数: initial_cash: 初始资金 max_positions: 最大持仓数 stop_loss: 止损管理器(StopLoss实例) take_profit: 止盈管理器(StopLoss实例) position_sizer: 仓位管理器(PositionSizing实例) date_index_dict: 日期索引字典 {ts_code: {date: idx}} buy_signal_index: 买入信号索引 {date: [ts_code1, ts_code2, ...]} """ self.initial_cash = float(initial_cash) self.cash: float = float(initial_cash) self.max_positions: int = int(max_positions) self.positions: Dict[str, Position] = {} self.equity_curve: List[Dict] = [] self.trade_count: int = 0 # 交易次数统计(买入+卖出) # 交易历史记录(用于计算胜率和盈亏比) self.trade_history: List[Dict] = [] # 记录每笔完整交易(买入->卖出) # 风险管理模块(可选) self.stop_loss = stop_loss self.take_profit = take_profit self.position_sizer = position_sizer # 性能优化:日期索引字典(避免deepcopypandas.DataFrame.attrs) self.date_index_dict = date_index_dict if date_index_dict is not None else {} self.buy_signal_index = buy_signal_index if buy_signal_index is not None else {} # ------------ 需要子类实现的接口 ------------ @abstractmethod def on_bar(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None: """每个交易日调用一次,由子类实现交易逻辑。""" # ------------ 买卖接口 ------------ def buy(self, ts_code: str, price: float, quantity: int) -> None: """以指定价格和数量买入股票。 A股交易规则: - 最小交易单位为 1 手(100 股); - 买入数量必须为 100 的整数倍。 """ # 向下取整到 100 的整数倍 quantity = (quantity // 100) * 100 if quantity <= 0: return cost_amount = float(price) * int(quantity) if cost_amount > self.cash: return self.cash -= cost_amount if ts_code in self.positions: pos = self.positions[ts_code] total_cost = pos.cost * pos.quantity + cost_amount total_qty = pos.quantity + quantity pos.cost = total_cost / total_qty pos.quantity = total_qty # 继续累积 days_held else: self.positions[ts_code] = Position(ts_code=ts_code, quantity=quantity, cost=float(price)) self.trade_count += 1 # 记录买入交易 def sell(self, ts_code: str, price: float, quantity: int) -> None: """以指定价格和数量卖出股票。""" if ts_code not in self.positions: return pos = self.positions[ts_code] sell_qty = min(quantity, pos.quantity) if sell_qty <= 0: return # 计算盈亏(用于更新仓位管理器统计) profit_pct = (price - pos.cost) / pos.cost profit_amount = (price - pos.cost) * sell_qty # 记录交易历史 self.trade_history.append({ "ts_code": ts_code, "buy_price": pos.cost, "sell_price": price, "quantity": sell_qty, "profit_pct": profit_pct, "profit_amount": profit_amount, "is_win": profit_pct > 0, }) self.cash += float(price) * sell_qty pos.quantity -= sell_qty if pos.quantity == 0: del self.positions[ts_code] # 清理跟踪止盈状态 if self.stop_loss is not None: self.stop_loss.reset_tracking(ts_code) if self.take_profit is not None: self.take_profit.reset_tracking(ts_code) # 更新仓位管理器的交易统计(用于Kelly公式) if self.position_sizer is not None and hasattr(self.position_sizer, 'update_trade_stats'): self.position_sizer.update_trade_stats(ts_code, profit_pct) self.trade_count += 1 # 记录卖出交易 # ------------ 回测主循环 ------------ def _calc_market_value(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> float: """计算当前持仓市值。""" total = 0.0 for ts_code, pos in self.positions.items(): df = data_dict.get(ts_code) if df is None or df.empty: continue # 性能优化:使用预计算的日期索引 date_index = self.date_index_dict.get(ts_code) if date_index is not None and current_date in date_index: idx = date_index[current_date] price = float(df.iloc[idx]["close"]) else: # 备用方案:若当日无行情,则以最近一条收盘价估值 price = float(df["close"].iloc[-1]) total += price * pos.quantity return total def _update_days_held(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None: """默认每天将所有持仓的持有天数 +1。 若子类需要更复杂逻辑,可以覆盖或在 on_bar 中自行管理。 """ for pos in self.positions.values(): pos.days_held += 1 def _check_stop_loss_take_profit(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None: """检查所有持仓的止盈止损条件。 在每个交易日开盘前或收盘后调用,自动触发止损/止盈卖出。 子类可以覆盖此方法以实现自定义逻辑。 """ for ts_code in list(self.positions.keys()): df = data_dict.get(ts_code) if df is None or df.empty: continue # 性能优化:使用预计算的日期索引 date_index = self.date_index_dict.get(ts_code) if date_index is None or current_date not in date_index: continue idx = date_index[current_date] row = df.iloc[idx] pos = self.positions[ts_code] close = float(row["close"]) high = float(row["high"]) if "high" in df.columns else None low = float(row["low"]) if "low" in df.columns else None # 计算ATR(如果需要) atr = None if self.stop_loss is not None and self.stop_loss.method == "atr": from risk.stop_loss import calculate_atr atr_series = calculate_atr(df, self.stop_loss.atr_period) if not atr_series.empty: atr = float(atr_series.iloc[-1]) # 检查止损 if self.stop_loss is not None: should_exit, reason = self.stop_loss.should_exit( ts_code=ts_code, current_price=close, cost_price=pos.cost, high=high, low=low, atr=atr, ) if should_exit: quantity = pos.quantity self.sell(ts_code, close, quantity) # logger.info(f"{current_date} 止损卖出 {ts_code} 数量 {quantity} 价格 {close:.2f} 原因: {reason}") # 已移除:止损日志改用进度条 continue # 已卖出,不再检查止盈 # 检查止盈 if self.take_profit is not None: should_exit, reason = self.take_profit.should_exit( ts_code=ts_code, current_price=close, cost_price=pos.cost, high=high, low=low, atr=atr, ) if should_exit: quantity = pos.quantity self.sell(ts_code, close, quantity) # logger.info(f"{current_date} 止盈卖出 {ts_code} 数量 {quantity} 价格 {close:.2f} 原因: {reason}") # 已移除:止盈日志改用进度条 def run_backtest(self, data_dict: Dict[str, pd.DataFrame], calendar: List[str]) -> pd.DataFrame: """主回测入口。 参数: - data_dict: {ts_code: DataFrame},每个 df 至少含 ['trade_date', 'open', 'high', 'low', 'close']; - calendar: 统一交易日列表(升序,字符串 YYYYMMDD)。 返回: - 资金曲线 DataFrame,列为 ['trade_date', 'total_asset', 'cash', 'market_value']。 """ from tqdm import tqdm logger.info(f"回测开始,共 {len(calendar)} 个交易日") # 使用进度条显示回测进度 for current_date in tqdm(calendar, desc="回测进度", unit="日"): # 更新持有天数 self._update_days_held(current_date, data_dict) # 检查止盈止损(在策略决策前) self._check_stop_loss_take_profit(current_date, data_dict) # 策略决策 try: self.on_bar(current_date, data_dict) except Exception as e: # noqa: BLE001 logger.error(f"on_bar 异常: {e}") # 计算当日资产 market_value = self._calc_market_value(current_date, data_dict) total_asset = self.cash + market_value self.equity_curve.append( { "trade_date": current_date, "total_asset": total_asset, "cash": self.cash, "market_value": market_value, } ) logger.info("回测结束") return pd.DataFrame(self.equity_curve)