Files
strategy_backtest/strategies/base_strategy.py

293 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""策略基类与回测主循环实现。
BaseStrategy 定义:
- 账户与持仓管理;
- 买卖接口;
- 回测主循环 run_backtest
- 集成止盈止损和仓位管理钩子;
子类只需实现 on_bar在每个交易日根据行情数据决策买卖。
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
from utils.logger import setup_logger
logger = setup_logger(__name__)
@dataclass
class Position:
"""单只股票的持仓信息。"""
ts_code: str
quantity: int
cost: float # 成本价
days_held: int = 0
class BaseStrategy(ABC):
"""回测策略抽象基类。
- 管理资金与持仓;
- 提供买卖接口;
- 实现回测主循环;
- 支持止盈止损和仓位管理;
- 子类只需实现 on_bar 方法。
"""
def __init__(
self,
initial_cash: float,
max_positions: int = 2,
stop_loss: Optional[object] = None,
take_profit: Optional[object] = None,
position_sizer: Optional[object] = None,
date_index_dict: Optional[Dict[str, Dict[str, int]]] = None,
buy_signal_index: Optional[Dict[str, List[str]]] = None,
):
"""初始化策略。
参数:
initial_cash: 初始资金
max_positions: 最大持仓数
stop_loss: 止损管理器StopLoss实例
take_profit: 止盈管理器StopLoss实例
position_sizer: 仓位管理器PositionSizing实例
date_index_dict: 日期索引字典 {ts_code: {date: idx}}
buy_signal_index: 买入信号索引 {date: [ts_code1, ts_code2, ...]}
"""
self.initial_cash = float(initial_cash)
self.cash: float = float(initial_cash)
self.max_positions: int = int(max_positions)
self.positions: Dict[str, Position] = {}
self.equity_curve: List[Dict] = []
self.trade_count: int = 0 # 交易次数统计(买入+卖出)
# 交易历史记录(用于计算胜率和盈亏比)
self.trade_history: List[Dict] = [] # 记录每笔完整交易(买入->卖出)
# 风险管理模块(可选)
self.stop_loss = stop_loss
self.take_profit = take_profit
self.position_sizer = position_sizer
# 性能优化日期索引字典避免deepcopypandas.DataFrame.attrs
self.date_index_dict = date_index_dict if date_index_dict is not None else {}
self.buy_signal_index = buy_signal_index if buy_signal_index is not None else {}
# ------------ 需要子类实现的接口 ------------
@abstractmethod
def on_bar(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
"""每个交易日调用一次,由子类实现交易逻辑。"""
# ------------ 买卖接口 ------------
def buy(self, ts_code: str, price: float, quantity: int) -> None:
"""以指定价格和数量买入股票。
A股交易规则
- 最小交易单位为 1 手100 股);
- 买入数量必须为 100 的整数倍。
"""
# 向下取整到 100 的整数倍
quantity = (quantity // 100) * 100
if quantity <= 0:
return
cost_amount = float(price) * int(quantity)
if cost_amount > self.cash:
return
self.cash -= cost_amount
if ts_code in self.positions:
pos = self.positions[ts_code]
total_cost = pos.cost * pos.quantity + cost_amount
total_qty = pos.quantity + quantity
pos.cost = total_cost / total_qty
pos.quantity = total_qty
# 继续累积 days_held
else:
self.positions[ts_code] = Position(ts_code=ts_code, quantity=quantity, cost=float(price))
self.trade_count += 1 # 记录买入交易
def sell(self, ts_code: str, price: float, quantity: int) -> None:
"""以指定价格和数量卖出股票。"""
if ts_code not in self.positions:
return
pos = self.positions[ts_code]
sell_qty = min(quantity, pos.quantity)
if sell_qty <= 0:
return
# 计算盈亏(用于更新仓位管理器统计)
profit_pct = (price - pos.cost) / pos.cost
profit_amount = (price - pos.cost) * sell_qty
# 记录交易历史
self.trade_history.append({
"ts_code": ts_code,
"buy_price": pos.cost,
"sell_price": price,
"quantity": sell_qty,
"profit_pct": profit_pct,
"profit_amount": profit_amount,
"is_win": profit_pct > 0,
})
self.cash += float(price) * sell_qty
pos.quantity -= sell_qty
if pos.quantity == 0:
del self.positions[ts_code]
# 清理跟踪止盈状态
if self.stop_loss is not None:
self.stop_loss.reset_tracking(ts_code)
if self.take_profit is not None:
self.take_profit.reset_tracking(ts_code)
# 更新仓位管理器的交易统计用于Kelly公式
if self.position_sizer is not None and hasattr(self.position_sizer, 'update_trade_stats'):
self.position_sizer.update_trade_stats(ts_code, profit_pct)
self.trade_count += 1 # 记录卖出交易
# ------------ 回测主循环 ------------
def _calc_market_value(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> float:
"""计算当前持仓市值。"""
total = 0.0
for ts_code, pos in self.positions.items():
df = data_dict.get(ts_code)
if df is None or df.empty:
continue
# 性能优化:使用预计算的日期索引
date_index = self.date_index_dict.get(ts_code)
if date_index is not None and current_date in date_index:
idx = date_index[current_date]
price = float(df.iloc[idx]["close"])
else:
# 备用方案:若当日无行情,则以最近一条收盘价估值
price = float(df["close"].iloc[-1])
total += price * pos.quantity
return total
def _update_days_held(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
"""默认每天将所有持仓的持有天数 +1。
若子类需要更复杂逻辑,可以覆盖或在 on_bar 中自行管理。
"""
for pos in self.positions.values():
pos.days_held += 1
def _check_stop_loss_take_profit(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
"""检查所有持仓的止盈止损条件。
在每个交易日开盘前或收盘后调用,自动触发止损/止盈卖出。
子类可以覆盖此方法以实现自定义逻辑。
"""
for ts_code in list(self.positions.keys()):
df = data_dict.get(ts_code)
if df is None or df.empty:
continue
# 性能优化:使用预计算的日期索引
date_index = self.date_index_dict.get(ts_code)
if date_index is None or current_date not in date_index:
continue
idx = date_index[current_date]
row = df.iloc[idx]
pos = self.positions[ts_code]
close = float(row["close"])
high = float(row["high"]) if "high" in df.columns else None
low = float(row["low"]) if "low" in df.columns else None
# 计算ATR如果需要
atr = None
if self.stop_loss is not None and self.stop_loss.method == "atr":
from risk.stop_loss import calculate_atr
atr_series = calculate_atr(df, self.stop_loss.atr_period)
if not atr_series.empty:
atr = float(atr_series.iloc[-1])
# 检查止损
if self.stop_loss is not None:
should_exit, reason = self.stop_loss.should_exit(
ts_code=ts_code,
current_price=close,
cost_price=pos.cost,
high=high,
low=low,
atr=atr,
)
if should_exit:
quantity = pos.quantity
self.sell(ts_code, close, quantity)
# logger.info(f"{current_date} 止损卖出 {ts_code} 数量 {quantity} 价格 {close:.2f} 原因: {reason}") # 已移除:止损日志改用进度条
continue # 已卖出,不再检查止盈
# 检查止盈
if self.take_profit is not None:
should_exit, reason = self.take_profit.should_exit(
ts_code=ts_code,
current_price=close,
cost_price=pos.cost,
high=high,
low=low,
atr=atr,
)
if should_exit:
quantity = pos.quantity
self.sell(ts_code, close, quantity)
# logger.info(f"{current_date} 止盈卖出 {ts_code} 数量 {quantity} 价格 {close:.2f} 原因: {reason}") # 已移除:止盈日志改用进度条
def run_backtest(self, data_dict: Dict[str, pd.DataFrame], calendar: List[str]) -> pd.DataFrame:
"""主回测入口。
参数:
- data_dict: {ts_code: DataFrame},每个 df 至少含 ['trade_date', 'open', 'high', 'low', 'close']
- calendar: 统一交易日列表(升序,字符串 YYYYMMDD
返回:
- 资金曲线 DataFrame列为 ['trade_date', 'total_asset', 'cash', 'market_value']。
"""
from tqdm import tqdm
logger.info(f"回测开始,共 {len(calendar)} 个交易日")
# 使用进度条显示回测进度
for current_date in tqdm(calendar, desc="回测进度", unit=""):
# 更新持有天数
self._update_days_held(current_date, data_dict)
# 检查止盈止损(在策略决策前)
self._check_stop_loss_take_profit(current_date, data_dict)
# 策略决策
try:
self.on_bar(current_date, data_dict)
except Exception as e: # noqa: BLE001
logger.error(f"on_bar 异常: {e}")
# 计算当日资产
market_value = self._calc_market_value(current_date, data_dict)
total_asset = self.cash + market_value
self.equity_curve.append(
{
"trade_date": current_date,
"total_asset": total_asset,
"cash": self.cash,
"market_value": market_value,
}
)
logger.info("回测结束")
return pd.DataFrame(self.equity_curve)