新建回测系统,并提交

This commit is contained in:
2026-01-17 21:37:42 +08:00
commit fe50ea935a
68 changed files with 108208 additions and 0 deletions

292
strategies/base_strategy.py Normal file
View File

@@ -0,0 +1,292 @@
"""策略基类与回测主循环实现。
BaseStrategy 定义:
- 账户与持仓管理;
- 买卖接口;
- 回测主循环 run_backtest
- 集成止盈止损和仓位管理钩子;
子类只需实现 on_bar在每个交易日根据行情数据决策买卖。
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
from utils.logger import setup_logger
logger = setup_logger(__name__)
@dataclass
class Position:
"""单只股票的持仓信息。"""
ts_code: str
quantity: int
cost: float # 成本价
days_held: int = 0
class BaseStrategy(ABC):
"""回测策略抽象基类。
- 管理资金与持仓;
- 提供买卖接口;
- 实现回测主循环;
- 支持止盈止损和仓位管理;
- 子类只需实现 on_bar 方法。
"""
def __init__(
self,
initial_cash: float,
max_positions: int = 2,
stop_loss: Optional[object] = None,
take_profit: Optional[object] = None,
position_sizer: Optional[object] = None,
date_index_dict: Optional[Dict[str, Dict[str, int]]] = None,
buy_signal_index: Optional[Dict[str, List[str]]] = None,
):
"""初始化策略。
参数:
initial_cash: 初始资金
max_positions: 最大持仓数
stop_loss: 止损管理器StopLoss实例
take_profit: 止盈管理器StopLoss实例
position_sizer: 仓位管理器PositionSizing实例
date_index_dict: 日期索引字典 {ts_code: {date: idx}}
buy_signal_index: 买入信号索引 {date: [ts_code1, ts_code2, ...]}
"""
self.initial_cash = float(initial_cash)
self.cash: float = float(initial_cash)
self.max_positions: int = int(max_positions)
self.positions: Dict[str, Position] = {}
self.equity_curve: List[Dict] = []
self.trade_count: int = 0 # 交易次数统计(买入+卖出)
# 交易历史记录(用于计算胜率和盈亏比)
self.trade_history: List[Dict] = [] # 记录每笔完整交易(买入->卖出)
# 风险管理模块(可选)
self.stop_loss = stop_loss
self.take_profit = take_profit
self.position_sizer = position_sizer
# 性能优化日期索引字典避免deepcopypandas.DataFrame.attrs
self.date_index_dict = date_index_dict if date_index_dict is not None else {}
self.buy_signal_index = buy_signal_index if buy_signal_index is not None else {}
# ------------ 需要子类实现的接口 ------------
@abstractmethod
def on_bar(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
"""每个交易日调用一次,由子类实现交易逻辑。"""
# ------------ 买卖接口 ------------
def buy(self, ts_code: str, price: float, quantity: int) -> None:
"""以指定价格和数量买入股票。
A股交易规则
- 最小交易单位为 1 手100 股);
- 买入数量必须为 100 的整数倍。
"""
# 向下取整到 100 的整数倍
quantity = (quantity // 100) * 100
if quantity <= 0:
return
cost_amount = float(price) * int(quantity)
if cost_amount > self.cash:
return
self.cash -= cost_amount
if ts_code in self.positions:
pos = self.positions[ts_code]
total_cost = pos.cost * pos.quantity + cost_amount
total_qty = pos.quantity + quantity
pos.cost = total_cost / total_qty
pos.quantity = total_qty
# 继续累积 days_held
else:
self.positions[ts_code] = Position(ts_code=ts_code, quantity=quantity, cost=float(price))
self.trade_count += 1 # 记录买入交易
def sell(self, ts_code: str, price: float, quantity: int) -> None:
"""以指定价格和数量卖出股票。"""
if ts_code not in self.positions:
return
pos = self.positions[ts_code]
sell_qty = min(quantity, pos.quantity)
if sell_qty <= 0:
return
# 计算盈亏(用于更新仓位管理器统计)
profit_pct = (price - pos.cost) / pos.cost
profit_amount = (price - pos.cost) * sell_qty
# 记录交易历史
self.trade_history.append({
"ts_code": ts_code,
"buy_price": pos.cost,
"sell_price": price,
"quantity": sell_qty,
"profit_pct": profit_pct,
"profit_amount": profit_amount,
"is_win": profit_pct > 0,
})
self.cash += float(price) * sell_qty
pos.quantity -= sell_qty
if pos.quantity == 0:
del self.positions[ts_code]
# 清理跟踪止盈状态
if self.stop_loss is not None:
self.stop_loss.reset_tracking(ts_code)
if self.take_profit is not None:
self.take_profit.reset_tracking(ts_code)
# 更新仓位管理器的交易统计用于Kelly公式
if self.position_sizer is not None and hasattr(self.position_sizer, 'update_trade_stats'):
self.position_sizer.update_trade_stats(ts_code, profit_pct)
self.trade_count += 1 # 记录卖出交易
# ------------ 回测主循环 ------------
def _calc_market_value(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> float:
"""计算当前持仓市值。"""
total = 0.0
for ts_code, pos in self.positions.items():
df = data_dict.get(ts_code)
if df is None or df.empty:
continue
# 性能优化:使用预计算的日期索引
date_index = self.date_index_dict.get(ts_code)
if date_index is not None and current_date in date_index:
idx = date_index[current_date]
price = float(df.iloc[idx]["close"])
else:
# 备用方案:若当日无行情,则以最近一条收盘价估值
price = float(df["close"].iloc[-1])
total += price * pos.quantity
return total
def _update_days_held(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
"""默认每天将所有持仓的持有天数 +1。
若子类需要更复杂逻辑,可以覆盖或在 on_bar 中自行管理。
"""
for pos in self.positions.values():
pos.days_held += 1
def _check_stop_loss_take_profit(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
"""检查所有持仓的止盈止损条件。
在每个交易日开盘前或收盘后调用,自动触发止损/止盈卖出。
子类可以覆盖此方法以实现自定义逻辑。
"""
for ts_code in list(self.positions.keys()):
df = data_dict.get(ts_code)
if df is None or df.empty:
continue
# 性能优化:使用预计算的日期索引
date_index = self.date_index_dict.get(ts_code)
if date_index is None or current_date not in date_index:
continue
idx = date_index[current_date]
row = df.iloc[idx]
pos = self.positions[ts_code]
close = float(row["close"])
high = float(row["high"]) if "high" in df.columns else None
low = float(row["low"]) if "low" in df.columns else None
# 计算ATR如果需要
atr = None
if self.stop_loss is not None and self.stop_loss.method == "atr":
from risk.stop_loss import calculate_atr
atr_series = calculate_atr(df, self.stop_loss.atr_period)
if not atr_series.empty:
atr = float(atr_series.iloc[-1])
# 检查止损
if self.stop_loss is not None:
should_exit, reason = self.stop_loss.should_exit(
ts_code=ts_code,
current_price=close,
cost_price=pos.cost,
high=high,
low=low,
atr=atr,
)
if should_exit:
quantity = pos.quantity
self.sell(ts_code, close, quantity)
# logger.info(f"{current_date} 止损卖出 {ts_code} 数量 {quantity} 价格 {close:.2f} 原因: {reason}") # 已移除:止损日志改用进度条
continue # 已卖出,不再检查止盈
# 检查止盈
if self.take_profit is not None:
should_exit, reason = self.take_profit.should_exit(
ts_code=ts_code,
current_price=close,
cost_price=pos.cost,
high=high,
low=low,
atr=atr,
)
if should_exit:
quantity = pos.quantity
self.sell(ts_code, close, quantity)
# logger.info(f"{current_date} 止盈卖出 {ts_code} 数量 {quantity} 价格 {close:.2f} 原因: {reason}") # 已移除:止盈日志改用进度条
def run_backtest(self, data_dict: Dict[str, pd.DataFrame], calendar: List[str]) -> pd.DataFrame:
"""主回测入口。
参数:
- data_dict: {ts_code: DataFrame},每个 df 至少含 ['trade_date', 'open', 'high', 'low', 'close']
- calendar: 统一交易日列表(升序,字符串 YYYYMMDD
返回:
- 资金曲线 DataFrame列为 ['trade_date', 'total_asset', 'cash', 'market_value']。
"""
from tqdm import tqdm
logger.info(f"回测开始,共 {len(calendar)} 个交易日")
# 使用进度条显示回测进度
for current_date in tqdm(calendar, desc="回测进度", unit=""):
# 更新持有天数
self._update_days_held(current_date, data_dict)
# 检查止盈止损(在策略决策前)
self._check_stop_loss_take_profit(current_date, data_dict)
# 策略决策
try:
self.on_bar(current_date, data_dict)
except Exception as e: # noqa: BLE001
logger.error(f"on_bar 异常: {e}")
# 计算当日资产
market_value = self._calc_market_value(current_date, data_dict)
total_asset = self.cash + market_value
self.equity_curve.append(
{
"trade_date": current_date,
"total_asset": total_asset,
"cash": self.cash,
"market_value": market_value,
}
)
logger.info("回测结束")
return pd.DataFrame(self.equity_curve)