293 lines
11 KiB
Python
293 lines
11 KiB
Python
"""策略基类与回测主循环实现。
|
||
|
||
BaseStrategy 定义:
|
||
- 账户与持仓管理;
|
||
- 买卖接口;
|
||
- 回测主循环 run_backtest;
|
||
- 集成止盈止损和仓位管理钩子;
|
||
子类只需实现 on_bar,在每个交易日根据行情数据决策买卖。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass
|
||
from typing import Dict, List, Optional
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
from utils.logger import setup_logger
|
||
|
||
logger = setup_logger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class Position:
|
||
"""单只股票的持仓信息。"""
|
||
|
||
ts_code: str
|
||
quantity: int
|
||
cost: float # 成本价
|
||
days_held: int = 0
|
||
|
||
|
||
class BaseStrategy(ABC):
|
||
"""回测策略抽象基类。
|
||
|
||
- 管理资金与持仓;
|
||
- 提供买卖接口;
|
||
- 实现回测主循环;
|
||
- 支持止盈止损和仓位管理;
|
||
- 子类只需实现 on_bar 方法。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
initial_cash: float,
|
||
max_positions: int = 2,
|
||
stop_loss: Optional[object] = None,
|
||
take_profit: Optional[object] = None,
|
||
position_sizer: Optional[object] = None,
|
||
date_index_dict: Optional[Dict[str, Dict[str, int]]] = None,
|
||
buy_signal_index: Optional[Dict[str, List[str]]] = None,
|
||
):
|
||
"""初始化策略。
|
||
|
||
参数:
|
||
initial_cash: 初始资金
|
||
max_positions: 最大持仓数
|
||
stop_loss: 止损管理器(StopLoss实例)
|
||
take_profit: 止盈管理器(StopLoss实例)
|
||
position_sizer: 仓位管理器(PositionSizing实例)
|
||
date_index_dict: 日期索引字典 {ts_code: {date: idx}}
|
||
buy_signal_index: 买入信号索引 {date: [ts_code1, ts_code2, ...]}
|
||
"""
|
||
self.initial_cash = float(initial_cash)
|
||
self.cash: float = float(initial_cash)
|
||
self.max_positions: int = int(max_positions)
|
||
self.positions: Dict[str, Position] = {}
|
||
self.equity_curve: List[Dict] = []
|
||
self.trade_count: int = 0 # 交易次数统计(买入+卖出)
|
||
|
||
# 交易历史记录(用于计算胜率和盈亏比)
|
||
self.trade_history: List[Dict] = [] # 记录每笔完整交易(买入->卖出)
|
||
|
||
# 风险管理模块(可选)
|
||
self.stop_loss = stop_loss
|
||
self.take_profit = take_profit
|
||
self.position_sizer = position_sizer
|
||
|
||
# 性能优化:日期索引字典(避免deepcopypandas.DataFrame.attrs)
|
||
self.date_index_dict = date_index_dict if date_index_dict is not None else {}
|
||
self.buy_signal_index = buy_signal_index if buy_signal_index is not None else {}
|
||
|
||
# ------------ 需要子类实现的接口 ------------
|
||
@abstractmethod
|
||
def on_bar(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
|
||
"""每个交易日调用一次,由子类实现交易逻辑。"""
|
||
|
||
# ------------ 买卖接口 ------------
|
||
def buy(self, ts_code: str, price: float, quantity: int) -> None:
|
||
"""以指定价格和数量买入股票。
|
||
|
||
A股交易规则:
|
||
- 最小交易单位为 1 手(100 股);
|
||
- 买入数量必须为 100 的整数倍。
|
||
"""
|
||
# 向下取整到 100 的整数倍
|
||
quantity = (quantity // 100) * 100
|
||
if quantity <= 0:
|
||
return
|
||
cost_amount = float(price) * int(quantity)
|
||
if cost_amount > self.cash:
|
||
return
|
||
|
||
self.cash -= cost_amount
|
||
if ts_code in self.positions:
|
||
pos = self.positions[ts_code]
|
||
total_cost = pos.cost * pos.quantity + cost_amount
|
||
total_qty = pos.quantity + quantity
|
||
pos.cost = total_cost / total_qty
|
||
pos.quantity = total_qty
|
||
# 继续累积 days_held
|
||
else:
|
||
self.positions[ts_code] = Position(ts_code=ts_code, quantity=quantity, cost=float(price))
|
||
|
||
self.trade_count += 1 # 记录买入交易
|
||
|
||
def sell(self, ts_code: str, price: float, quantity: int) -> None:
|
||
"""以指定价格和数量卖出股票。"""
|
||
if ts_code not in self.positions:
|
||
return
|
||
pos = self.positions[ts_code]
|
||
sell_qty = min(quantity, pos.quantity)
|
||
if sell_qty <= 0:
|
||
return
|
||
|
||
# 计算盈亏(用于更新仓位管理器统计)
|
||
profit_pct = (price - pos.cost) / pos.cost
|
||
profit_amount = (price - pos.cost) * sell_qty
|
||
|
||
# 记录交易历史
|
||
self.trade_history.append({
|
||
"ts_code": ts_code,
|
||
"buy_price": pos.cost,
|
||
"sell_price": price,
|
||
"quantity": sell_qty,
|
||
"profit_pct": profit_pct,
|
||
"profit_amount": profit_amount,
|
||
"is_win": profit_pct > 0,
|
||
})
|
||
|
||
self.cash += float(price) * sell_qty
|
||
pos.quantity -= sell_qty
|
||
if pos.quantity == 0:
|
||
del self.positions[ts_code]
|
||
# 清理跟踪止盈状态
|
||
if self.stop_loss is not None:
|
||
self.stop_loss.reset_tracking(ts_code)
|
||
if self.take_profit is not None:
|
||
self.take_profit.reset_tracking(ts_code)
|
||
|
||
# 更新仓位管理器的交易统计(用于Kelly公式)
|
||
if self.position_sizer is not None and hasattr(self.position_sizer, 'update_trade_stats'):
|
||
self.position_sizer.update_trade_stats(ts_code, profit_pct)
|
||
|
||
self.trade_count += 1 # 记录卖出交易
|
||
|
||
# ------------ 回测主循环 ------------
|
||
def _calc_market_value(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> float:
|
||
"""计算当前持仓市值。"""
|
||
total = 0.0
|
||
for ts_code, pos in self.positions.items():
|
||
df = data_dict.get(ts_code)
|
||
if df is None or df.empty:
|
||
continue
|
||
|
||
# 性能优化:使用预计算的日期索引
|
||
date_index = self.date_index_dict.get(ts_code)
|
||
if date_index is not None and current_date in date_index:
|
||
idx = date_index[current_date]
|
||
price = float(df.iloc[idx]["close"])
|
||
else:
|
||
# 备用方案:若当日无行情,则以最近一条收盘价估值
|
||
price = float(df["close"].iloc[-1])
|
||
|
||
total += price * pos.quantity
|
||
return total
|
||
|
||
def _update_days_held(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
|
||
"""默认每天将所有持仓的持有天数 +1。
|
||
|
||
若子类需要更复杂逻辑,可以覆盖或在 on_bar 中自行管理。
|
||
"""
|
||
for pos in self.positions.values():
|
||
pos.days_held += 1
|
||
|
||
def _check_stop_loss_take_profit(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
|
||
"""检查所有持仓的止盈止损条件。
|
||
|
||
在每个交易日开盘前或收盘后调用,自动触发止损/止盈卖出。
|
||
子类可以覆盖此方法以实现自定义逻辑。
|
||
"""
|
||
for ts_code in list(self.positions.keys()):
|
||
df = data_dict.get(ts_code)
|
||
if df is None or df.empty:
|
||
continue
|
||
|
||
# 性能优化:使用预计算的日期索引
|
||
date_index = self.date_index_dict.get(ts_code)
|
||
if date_index is None or current_date not in date_index:
|
||
continue
|
||
|
||
idx = date_index[current_date]
|
||
row = df.iloc[idx]
|
||
|
||
pos = self.positions[ts_code]
|
||
close = float(row["close"])
|
||
high = float(row["high"]) if "high" in df.columns else None
|
||
low = float(row["low"]) if "low" in df.columns else None
|
||
|
||
# 计算ATR(如果需要)
|
||
atr = None
|
||
if self.stop_loss is not None and self.stop_loss.method == "atr":
|
||
from risk.stop_loss import calculate_atr
|
||
atr_series = calculate_atr(df, self.stop_loss.atr_period)
|
||
if not atr_series.empty:
|
||
atr = float(atr_series.iloc[-1])
|
||
|
||
# 检查止损
|
||
if self.stop_loss is not None:
|
||
should_exit, reason = self.stop_loss.should_exit(
|
||
ts_code=ts_code,
|
||
current_price=close,
|
||
cost_price=pos.cost,
|
||
high=high,
|
||
low=low,
|
||
atr=atr,
|
||
)
|
||
if should_exit:
|
||
quantity = pos.quantity
|
||
self.sell(ts_code, close, quantity)
|
||
# logger.info(f"{current_date} 止损卖出 {ts_code} 数量 {quantity} 价格 {close:.2f} 原因: {reason}") # 已移除:止损日志改用进度条
|
||
continue # 已卖出,不再检查止盈
|
||
|
||
# 检查止盈
|
||
if self.take_profit is not None:
|
||
should_exit, reason = self.take_profit.should_exit(
|
||
ts_code=ts_code,
|
||
current_price=close,
|
||
cost_price=pos.cost,
|
||
high=high,
|
||
low=low,
|
||
atr=atr,
|
||
)
|
||
if should_exit:
|
||
quantity = pos.quantity
|
||
self.sell(ts_code, close, quantity)
|
||
# logger.info(f"{current_date} 止盈卖出 {ts_code} 数量 {quantity} 价格 {close:.2f} 原因: {reason}") # 已移除:止盈日志改用进度条
|
||
|
||
def run_backtest(self, data_dict: Dict[str, pd.DataFrame], calendar: List[str]) -> pd.DataFrame:
|
||
"""主回测入口。
|
||
|
||
参数:
|
||
- data_dict: {ts_code: DataFrame},每个 df 至少含 ['trade_date', 'open', 'high', 'low', 'close'];
|
||
- calendar: 统一交易日列表(升序,字符串 YYYYMMDD)。
|
||
|
||
返回:
|
||
- 资金曲线 DataFrame,列为 ['trade_date', 'total_asset', 'cash', 'market_value']。
|
||
"""
|
||
from tqdm import tqdm
|
||
|
||
logger.info(f"回测开始,共 {len(calendar)} 个交易日")
|
||
|
||
# 使用进度条显示回测进度
|
||
for current_date in tqdm(calendar, desc="回测进度", unit="日"):
|
||
# 更新持有天数
|
||
self._update_days_held(current_date, data_dict)
|
||
|
||
# 检查止盈止损(在策略决策前)
|
||
self._check_stop_loss_take_profit(current_date, data_dict)
|
||
|
||
# 策略决策
|
||
try:
|
||
self.on_bar(current_date, data_dict)
|
||
except Exception as e: # noqa: BLE001
|
||
logger.error(f"on_bar 异常: {e}")
|
||
|
||
# 计算当日资产
|
||
market_value = self._calc_market_value(current_date, data_dict)
|
||
total_asset = self.cash + market_value
|
||
|
||
self.equity_curve.append(
|
||
{
|
||
"trade_date": current_date,
|
||
"total_asset": total_asset,
|
||
"cash": self.cash,
|
||
"market_value": market_value,
|
||
}
|
||
)
|
||
|
||
logger.info("回测结束")
|
||
return pd.DataFrame(self.equity_curve)
|