Files
strategy_backtest/strategies/ma_cross.py

194 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""示例策略:均线交叉(买入持有 5 天,最多 2 只)。
策略规则(可通过 config.settings 配置参数):
- 使用短期均线 ma_short 和长期均线 ma_long
- 当短期均线向上突破长期均线(金叉)时,若当前仓位未满且无该股持仓,则买入;
- 持有达到 hold_days 天,或出现均线死叉(短期下穿长期)时卖出;
- 最多持有 max_positions 只股票,按可用资金等权分配买入;
- 支持可选的止盈止损和仓位管理(通过 BaseStrategy 集成)。
"""
from __future__ import annotations
from typing import Dict, List, Optional
import pandas as pd
from strategies.base_strategy import BaseStrategy
from utils.logger import setup_logger
logger = setup_logger(__name__)
class MaCrossStrategy(BaseStrategy):
"""均线交叉示例策略。"""
def __init__(
self,
initial_cash: float,
ma_short: int,
ma_long: int,
hold_days: int = 5,
max_positions: int = 2,
position_pct_per_stock: float = 0.2,
stop_loss: Optional[object] = None,
take_profit: Optional[object] = None,
position_sizer: Optional[object] = None,
date_index_dict: Optional[Dict[str, Dict[str, int]]] = None,
buy_signal_index: Optional[Dict[str, List[str]]] = None,
):
"""初始化均线交叉策略。
参数:
initial_cash: 初始资金
ma_short: 短期均线周期
ma_long: 长期均线周期
hold_days: 持有天数
max_positions: 最大持仓数
position_pct_per_stock: 每只个股占总资金的比例0.2 = 20%
stop_loss: 止损管理器(可选)
take_profit: 止盈管理器(可选)
position_sizer: 仓位管理器(可选)
date_index_dict: 日期索引字典 {ts_code: {date: idx}}
buy_signal_index: 买入信号索引 {date: [ts_code1, ts_code2, ...]}
"""
super().__init__(
initial_cash=initial_cash,
max_positions=max_positions,
stop_loss=stop_loss,
take_profit=take_profit,
position_sizer=position_sizer,
date_index_dict=date_index_dict,
buy_signal_index=buy_signal_index,
)
self.ma_short = int(ma_short)
self.ma_long = int(ma_long)
self.hold_days = int(hold_days)
self.position_pct_per_stock = float(position_pct_per_stock)
# 输出策略实例化参数(验证参数读取)
logger.info(f"MaCrossStrategy 初始化参数: ma_short={self.ma_short}, ma_long={self.ma_long}, "
f"hold_days={self.hold_days}, max_positions={max_positions}, "
f"position_pct_per_stock={self.position_pct_per_stock}")
def on_bar(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
"""每个交易日的交易决策逻辑。
性能优化:直接读取预计算的信号列,无需每天重复计算均线和成交量变化。
"""
# 1. 先处理卖出逻辑(持有到期或均线死叉)
for ts_code in list(self.positions.keys()):
df = data_dict.get(ts_code)
if df is None or df.empty:
continue
# 性能优化:使用预计算的日期索引
date_index = self.date_index_dict.get(ts_code)
if date_index is None or current_date not in date_index:
continue
idx = date_index[current_date]
row = df.iloc[idx]
close = float(row["close"])
pos = self.positions[ts_code]
# 性能优化:直接读取预计算的死叉信号
death_cross = bool(row["death_cross"]) if "death_cross" in df.columns else False
# 满足持有天数或死叉则卖出
if pos.days_held >= self.hold_days or death_cross:
quantity = pos.quantity
if quantity > 0:
self.sell(ts_code, close, quantity)
# logger.info(f"{current_date} 卖出 {ts_code} 数量 {quantity} 价格 {close}") # 已移除:交易日志改用进度条
# 2. 再处理买入逻辑(金叉 + 放量按成交量排序选择前N只
if len(self.positions) >= self.max_positions:
return
# 优化如果现金过少不足初始资金1%),直接返回
if self.cash < self.initial_cash * 0.01:
return
# 性能优化核心:直接从买入信号索引获取今天有信号的股票列表
signal_stocks = self.buy_signal_index.get(current_date, [])
if not signal_stocks:
return # 今天没有任何买入信号
candidates = [] # 候选股票列表:[(ts_code, volume, price), ...]
# 只需遍历有买入信号的股票从3776只减少到99%以上)
for ts_code in signal_stocks:
if ts_code in self.positions:
continue
df = data_dict.get(ts_code)
if df is None or df.empty:
continue
# 性能优化:使用预计算的日期索引
date_index = self.date_index_dict.get(ts_code)
if date_index is None or current_date not in date_index:
continue
idx = date_index[current_date]
row = data_dict[ts_code].iloc[idx]
# 满足条件,加入候选列表
price = float(row["close"])
volume = float(row["vol"])
candidates.append((ts_code, volume, price))
# 第二步按成交量倒序排序选择成交量最大的前N只
if not candidates:
return # 没有满足条件的股票
# 按成交量降序排序
candidates.sort(key=lambda x: x[1], reverse=True)
# 计算还能买入几只
remain_slots = self.max_positions - len(self.positions)
selected = candidates[:remain_slots] # 选择前N只
# 第三步:依次买入选中的股票
for ts_code, volume, price in selected:
# 使用仓位管理器计算买入数量(如果有的话)
if self.position_sizer is not None:
df = data_dict.get(ts_code)
quantity = self.position_sizer.calc_shares(
ts_code=ts_code,
cash=self.cash,
price=price,
remain_slots=remain_slots,
df=df,
)
else:
# 默认:按配置的比例分配仓位
# 目标:使用初始资金的 position_pct_per_stock 比例例如50%
target_cash = self.initial_cash * self.position_pct_per_stock
# 实际:如果当前现金不足目标金额,就用剩余的所有现金
actual_cash = min(target_cash, self.cash)
# 计算能买多少股向下取整到100的整数倍
quantity = int(actual_cash // price)
quantity = (quantity // 100) * 100
if quantity <= 0:
continue
# 买入前记录当前现金和仓位
before_cash = self.cash
before_positions = len(self.positions)
self.buy(ts_code, price, quantity)
# 交易日志已移除,改用进度条显示
# if self.cash < before_cash and len(self.positions) > before_positions:
# actual_quantity = self.positions[ts_code].quantity
# actual_cost = before_cash - self.cash
# logger.info(
# f"{current_date} 买入 {ts_code} 数量 {actual_quantity} "
# f"价格 {price:.2f} 成本 {actual_cost:.2f} 成交量 {volume:.0f} "
# f"剩余现金 {self.cash:.2f} 当前持仓数 {len(self.positions)}/{self.max_positions}"
# )