"""止盈止损模块。 支持多种止损策略: - fixed_pct: 固定百分比止损/止盈 - atr: ATR倍数止损 - trailing: 跟踪止盈(最高价回撤) 使用方法: stop_loss = StopLoss(method="fixed_pct", stop_pct=0.05) take_profit = StopLoss(method="fixed_pct", stop_pct=0.15) # 在策略的 on_bar 中检查 if stop_loss.should_exit(ts_code, high, low, close, cost_price): # 执行卖出 """ from __future__ import annotations from typing import Dict, Optional import pandas as pd from utils.logger import setup_logger logger = setup_logger(__name__) class StopLoss: """止盈止损管理器。 支持多种方法: - fixed_pct: 固定百分比 - atr: ATR倍数 - trailing: 跟踪止盈 """ def __init__( self, method: str = "fixed_pct", stop_pct: float = 0.05, atr_multiplier: float = 2.0, atr_period: int = 14, trailing_pct: float = 0.10, ): """初始化止损管理器。 参数: method: 止损方法 ("fixed_pct", "atr", "trailing") stop_pct: 固定百分比止损/止盈比例(0.05 表示 5%) atr_multiplier: ATR倍数 atr_period: ATR周期 trailing_pct: 跟踪止盈回撤比例 """ self.method = method self.stop_pct = stop_pct self.atr_multiplier = atr_multiplier self.atr_period = atr_period self.trailing_pct = trailing_pct # 跟踪止盈需要记录每只股票的最高价 self.highest_price: Dict[str, float] = {} def should_exit( self, ts_code: str, current_price: float, cost_price: float, high: Optional[float] = None, low: Optional[float] = None, atr: Optional[float] = None, ) -> tuple[bool, str]: """判断是否应该止损/止盈。 参数: ts_code: 股票代码 current_price: 当前价格(通常是收盘价) cost_price: 成本价 high: 当日最高价(用于跟踪止盈) low: 当日最低价(用于止损) atr: ATR指标值(用于ATR止损) 返回: (bool, str): (是否应该退出, 原因) """ if self.method == "fixed_pct": return self._fixed_pct_exit(ts_code, current_price, cost_price) elif self.method == "atr": return self._atr_exit(ts_code, current_price, cost_price, atr) elif self.method == "trailing": return self._trailing_exit(ts_code, current_price, cost_price, high) else: logger.warning(f"未知的止损方法: {self.method}") return False, "" def _fixed_pct_exit(self, ts_code: str, current_price: float, cost_price: float) -> tuple[bool, str]: """固定百分比止损/止盈。 止损:当前价格 < 成本价 * (1 - stop_pct) 止盈:当前价格 > 成本价 * (1 + stop_pct) """ profit_pct = (current_price - cost_price) / cost_price if profit_pct <= -self.stop_pct: reason = f"固定止损 {profit_pct*100:.2f}%" # logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除:止损日志 return True, reason # 注意:这里的 stop_pct 实际用作止盈比例 # 如果需要区分止损和止盈,可以增加一个 take_profit_pct 参数 if profit_pct >= self.stop_pct: reason = f"固定止盈 {profit_pct*100:.2f}%" # logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除:止盈日志 return True, reason return False, "" def _atr_exit(self, ts_code: str, current_price: float, cost_price: float, atr: Optional[float]) -> tuple[bool, str]: """ATR倍数止损。 止损:当前价格 < 成本价 - ATR * multiplier """ if atr is None or atr <= 0: logger.warning(f"{ts_code} ATR值无效,跳过ATR止损") return False, "" stop_price = cost_price - atr * self.atr_multiplier if current_price < stop_price: profit_pct = (current_price - cost_price) / cost_price reason = f"ATR止损 {profit_pct*100:.2f}% (ATR={atr:.2f})" # logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 止损价={stop_price:.2f}, 当前={current_price:.2f}") # 已移除 return True, reason return False, "" def _trailing_exit(self, ts_code: str, current_price: float, cost_price: float, high: Optional[float]) -> tuple[bool, str]: """跟踪止盈。 记录持仓期间的最高价,当价格从最高价回撤超过 trailing_pct 时卖出。 """ # 使用当日最高价或当前价更新历史最高价 if high is not None: price_to_track = max(current_price, high) else: price_to_track = current_price if ts_code not in self.highest_price: self.highest_price[ts_code] = price_to_track else: self.highest_price[ts_code] = max(self.highest_price[ts_code], price_to_track) highest = self.highest_price[ts_code] drawdown_from_high = (highest - current_price) / highest # 只有在盈利的情况下才触发跟踪止盈 if current_price > cost_price and drawdown_from_high >= self.trailing_pct: profit_pct = (current_price - cost_price) / cost_price reason = f"跟踪止盈 从最高点{highest:.2f}回撤{drawdown_from_high*100:.2f}%,盈利{profit_pct*100:.2f}%" # logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除 # 清除最高价记录 del self.highest_price[ts_code] return True, reason return False, "" def reset_tracking(self, ts_code: str) -> None: """重置某只股票的跟踪状态(例如清仓后)。""" if ts_code in self.highest_price: del self.highest_price[ts_code] def clear_all(self) -> None: """清空所有跟踪状态。""" self.highest_price.clear() def calculate_atr(df: pd.DataFrame, period: int = 14) -> pd.Series: """计算ATR指标(Average True Range)。 参数: df: 包含 ['high', 'low', 'close'] 的DataFrame period: ATR周期 返回: pd.Series: ATR值 """ high = df["high"] low = df["low"] close = df["close"] # True Range = max(high-low, abs(high-prev_close), abs(low-prev_close)) tr1 = high - low tr2 = (high - close.shift(1)).abs() tr3 = (low - close.shift(1)).abs() tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) # ATR = MA(TR, period) atr = tr.rolling(window=period).mean() return atr