194 lines
8.2 KiB
Python
194 lines
8.2 KiB
Python
"""示例策略:均线交叉(买入持有 5 天,最多 2 只)。
|
||
|
||
策略规则(可通过 config.settings 配置参数):
|
||
- 使用短期均线 ma_short 和长期均线 ma_long;
|
||
- 当短期均线向上突破长期均线(金叉)时,若当前仓位未满且无该股持仓,则买入;
|
||
- 持有达到 hold_days 天,或出现均线死叉(短期下穿长期)时卖出;
|
||
- 最多持有 max_positions 只股票,按可用资金等权分配买入;
|
||
- 支持可选的止盈止损和仓位管理(通过 BaseStrategy 集成)。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from typing import Dict, List, Optional
|
||
|
||
import pandas as pd
|
||
|
||
from strategies.base_strategy import BaseStrategy
|
||
from utils.logger import setup_logger
|
||
|
||
logger = setup_logger(__name__)
|
||
|
||
|
||
class MaCrossStrategy(BaseStrategy):
|
||
"""均线交叉示例策略。"""
|
||
|
||
def __init__(
|
||
self,
|
||
initial_cash: float,
|
||
ma_short: int,
|
||
ma_long: int,
|
||
hold_days: int = 5,
|
||
max_positions: int = 2,
|
||
position_pct_per_stock: float = 0.2,
|
||
stop_loss: Optional[object] = None,
|
||
take_profit: Optional[object] = None,
|
||
position_sizer: Optional[object] = None,
|
||
date_index_dict: Optional[Dict[str, Dict[str, int]]] = None,
|
||
buy_signal_index: Optional[Dict[str, List[str]]] = None,
|
||
):
|
||
"""初始化均线交叉策略。
|
||
|
||
参数:
|
||
initial_cash: 初始资金
|
||
ma_short: 短期均线周期
|
||
ma_long: 长期均线周期
|
||
hold_days: 持有天数
|
||
max_positions: 最大持仓数
|
||
position_pct_per_stock: 每只个股占总资金的比例(0.2 = 20%)
|
||
stop_loss: 止损管理器(可选)
|
||
take_profit: 止盈管理器(可选)
|
||
position_sizer: 仓位管理器(可选)
|
||
date_index_dict: 日期索引字典 {ts_code: {date: idx}}
|
||
buy_signal_index: 买入信号索引 {date: [ts_code1, ts_code2, ...]}
|
||
"""
|
||
super().__init__(
|
||
initial_cash=initial_cash,
|
||
max_positions=max_positions,
|
||
stop_loss=stop_loss,
|
||
take_profit=take_profit,
|
||
position_sizer=position_sizer,
|
||
date_index_dict=date_index_dict,
|
||
buy_signal_index=buy_signal_index,
|
||
)
|
||
self.ma_short = int(ma_short)
|
||
self.ma_long = int(ma_long)
|
||
self.hold_days = int(hold_days)
|
||
self.position_pct_per_stock = float(position_pct_per_stock)
|
||
|
||
# 输出策略实例化参数(验证参数读取)
|
||
logger.info(f"MaCrossStrategy 初始化参数: ma_short={self.ma_short}, ma_long={self.ma_long}, "
|
||
f"hold_days={self.hold_days}, max_positions={max_positions}, "
|
||
f"position_pct_per_stock={self.position_pct_per_stock}")
|
||
|
||
def on_bar(self, current_date: str, data_dict: Dict[str, pd.DataFrame]) -> None:
|
||
"""每个交易日的交易决策逻辑。
|
||
|
||
性能优化:直接读取预计算的信号列,无需每天重复计算均线和成交量变化。
|
||
"""
|
||
# 1. 先处理卖出逻辑(持有到期或均线死叉)
|
||
for ts_code in list(self.positions.keys()):
|
||
df = data_dict.get(ts_code)
|
||
if df is None or df.empty:
|
||
continue
|
||
|
||
# 性能优化:使用预计算的日期索引
|
||
date_index = self.date_index_dict.get(ts_code)
|
||
if date_index is None or current_date not in date_index:
|
||
continue
|
||
|
||
idx = date_index[current_date]
|
||
row = df.iloc[idx]
|
||
close = float(row["close"])
|
||
|
||
pos = self.positions[ts_code]
|
||
|
||
# 性能优化:直接读取预计算的死叉信号
|
||
death_cross = bool(row["death_cross"]) if "death_cross" in df.columns else False
|
||
|
||
# 满足持有天数或死叉则卖出
|
||
if pos.days_held >= self.hold_days or death_cross:
|
||
quantity = pos.quantity
|
||
if quantity > 0:
|
||
self.sell(ts_code, close, quantity)
|
||
# logger.info(f"{current_date} 卖出 {ts_code} 数量 {quantity} 价格 {close}") # 已移除:交易日志改用进度条
|
||
|
||
# 2. 再处理买入逻辑(金叉 + 放量,按成交量排序选择前N只)
|
||
if len(self.positions) >= self.max_positions:
|
||
return
|
||
|
||
# 优化:如果现金过少(不足初始资金1%),直接返回
|
||
if self.cash < self.initial_cash * 0.01:
|
||
return
|
||
|
||
# 性能优化核心:直接从买入信号索引获取今天有信号的股票列表
|
||
signal_stocks = self.buy_signal_index.get(current_date, [])
|
||
if not signal_stocks:
|
||
return # 今天没有任何买入信号
|
||
|
||
candidates = [] # 候选股票列表:[(ts_code, volume, price), ...]
|
||
|
||
# 只需遍历有买入信号的股票(从3776只减少到99%以上)
|
||
for ts_code in signal_stocks:
|
||
if ts_code in self.positions:
|
||
continue
|
||
|
||
df = data_dict.get(ts_code)
|
||
if df is None or df.empty:
|
||
continue
|
||
|
||
# 性能优化:使用预计算的日期索引
|
||
date_index = self.date_index_dict.get(ts_code)
|
||
if date_index is None or current_date not in date_index:
|
||
continue
|
||
|
||
idx = date_index[current_date]
|
||
row = data_dict[ts_code].iloc[idx]
|
||
|
||
# 满足条件,加入候选列表
|
||
price = float(row["close"])
|
||
volume = float(row["vol"])
|
||
candidates.append((ts_code, volume, price))
|
||
|
||
# 第二步:按成交量倒序排序,选择成交量最大的前N只
|
||
if not candidates:
|
||
return # 没有满足条件的股票
|
||
|
||
# 按成交量降序排序
|
||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||
|
||
# 计算还能买入几只
|
||
remain_slots = self.max_positions - len(self.positions)
|
||
selected = candidates[:remain_slots] # 选择前N只
|
||
|
||
# 第三步:依次买入选中的股票
|
||
for ts_code, volume, price in selected:
|
||
# 使用仓位管理器计算买入数量(如果有的话)
|
||
if self.position_sizer is not None:
|
||
df = data_dict.get(ts_code)
|
||
quantity = self.position_sizer.calc_shares(
|
||
ts_code=ts_code,
|
||
cash=self.cash,
|
||
price=price,
|
||
remain_slots=remain_slots,
|
||
df=df,
|
||
)
|
||
else:
|
||
# 默认:按配置的比例分配仓位
|
||
# 目标:使用初始资金的 position_pct_per_stock 比例(例如50%)
|
||
target_cash = self.initial_cash * self.position_pct_per_stock
|
||
# 实际:如果当前现金不足目标金额,就用剩余的所有现金
|
||
actual_cash = min(target_cash, self.cash)
|
||
|
||
# 计算能买多少股(向下取整到100的整数倍)
|
||
quantity = int(actual_cash // price)
|
||
quantity = (quantity // 100) * 100
|
||
|
||
if quantity <= 0:
|
||
continue
|
||
|
||
# 买入前记录当前现金和仓位
|
||
before_cash = self.cash
|
||
before_positions = len(self.positions)
|
||
|
||
self.buy(ts_code, price, quantity)
|
||
|
||
# 交易日志已移除,改用进度条显示
|
||
# if self.cash < before_cash and len(self.positions) > before_positions:
|
||
# actual_quantity = self.positions[ts_code].quantity
|
||
# actual_cost = before_cash - self.cash
|
||
# logger.info(
|
||
# f"{current_date} 买入 {ts_code} 数量 {actual_quantity} "
|
||
# f"价格 {price:.2f} 成本 {actual_cost:.2f} 成交量 {volume:.0f} "
|
||
# f"剩余现金 {self.cash:.2f} 当前持仓数 {len(self.positions)}/{self.max_positions}"
|
||
# )
|