201 lines
7.3 KiB
Python
201 lines
7.3 KiB
Python
"""止盈止损模块。
|
||
|
||
支持多种止损策略:
|
||
- fixed_pct: 固定百分比止损/止盈
|
||
- atr: ATR倍数止损
|
||
- trailing: 跟踪止盈(最高价回撤)
|
||
|
||
使用方法:
|
||
stop_loss = StopLoss(method="fixed_pct", stop_pct=0.05)
|
||
take_profit = StopLoss(method="fixed_pct", stop_pct=0.15)
|
||
|
||
# 在策略的 on_bar 中检查
|
||
if stop_loss.should_exit(ts_code, high, low, close, cost_price):
|
||
# 执行卖出
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from typing import Dict, Optional
|
||
|
||
import pandas as pd
|
||
|
||
from utils.logger import setup_logger
|
||
|
||
logger = setup_logger(__name__)
|
||
|
||
|
||
class StopLoss:
|
||
"""止盈止损管理器。
|
||
|
||
支持多种方法:
|
||
- fixed_pct: 固定百分比
|
||
- atr: ATR倍数
|
||
- trailing: 跟踪止盈
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
method: str = "fixed_pct",
|
||
stop_pct: float = 0.05,
|
||
atr_multiplier: float = 2.0,
|
||
atr_period: int = 14,
|
||
trailing_pct: float = 0.10,
|
||
):
|
||
"""初始化止损管理器。
|
||
|
||
参数:
|
||
method: 止损方法 ("fixed_pct", "atr", "trailing")
|
||
stop_pct: 固定百分比止损/止盈比例(0.05 表示 5%)
|
||
atr_multiplier: ATR倍数
|
||
atr_period: ATR周期
|
||
trailing_pct: 跟踪止盈回撤比例
|
||
"""
|
||
self.method = method
|
||
self.stop_pct = stop_pct
|
||
self.atr_multiplier = atr_multiplier
|
||
self.atr_period = atr_period
|
||
self.trailing_pct = trailing_pct
|
||
|
||
# 跟踪止盈需要记录每只股票的最高价
|
||
self.highest_price: Dict[str, float] = {}
|
||
|
||
def should_exit(
|
||
self,
|
||
ts_code: str,
|
||
current_price: float,
|
||
cost_price: float,
|
||
high: Optional[float] = None,
|
||
low: Optional[float] = None,
|
||
atr: Optional[float] = None,
|
||
) -> tuple[bool, str]:
|
||
"""判断是否应该止损/止盈。
|
||
|
||
参数:
|
||
ts_code: 股票代码
|
||
current_price: 当前价格(通常是收盘价)
|
||
cost_price: 成本价
|
||
high: 当日最高价(用于跟踪止盈)
|
||
low: 当日最低价(用于止损)
|
||
atr: ATR指标值(用于ATR止损)
|
||
|
||
返回:
|
||
(bool, str): (是否应该退出, 原因)
|
||
"""
|
||
if self.method == "fixed_pct":
|
||
return self._fixed_pct_exit(ts_code, current_price, cost_price)
|
||
elif self.method == "atr":
|
||
return self._atr_exit(ts_code, current_price, cost_price, atr)
|
||
elif self.method == "trailing":
|
||
return self._trailing_exit(ts_code, current_price, cost_price, high)
|
||
else:
|
||
logger.warning(f"未知的止损方法: {self.method}")
|
||
return False, ""
|
||
|
||
def _fixed_pct_exit(self, ts_code: str, current_price: float, cost_price: float) -> tuple[bool, str]:
|
||
"""固定百分比止损/止盈。
|
||
|
||
止损:当前价格 < 成本价 * (1 - stop_pct)
|
||
止盈:当前价格 > 成本价 * (1 + stop_pct)
|
||
"""
|
||
profit_pct = (current_price - cost_price) / cost_price
|
||
|
||
if profit_pct <= -self.stop_pct:
|
||
reason = f"固定止损 {profit_pct*100:.2f}%"
|
||
# logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除:止损日志
|
||
return True, reason
|
||
|
||
# 注意:这里的 stop_pct 实际用作止盈比例
|
||
# 如果需要区分止损和止盈,可以增加一个 take_profit_pct 参数
|
||
if profit_pct >= self.stop_pct:
|
||
reason = f"固定止盈 {profit_pct*100:.2f}%"
|
||
# logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除:止盈日志
|
||
return True, reason
|
||
|
||
return False, ""
|
||
|
||
def _atr_exit(self, ts_code: str, current_price: float, cost_price: float, atr: Optional[float]) -> tuple[bool, str]:
|
||
"""ATR倍数止损。
|
||
|
||
止损:当前价格 < 成本价 - ATR * multiplier
|
||
"""
|
||
if atr is None or atr <= 0:
|
||
logger.warning(f"{ts_code} ATR值无效,跳过ATR止损")
|
||
return False, ""
|
||
|
||
stop_price = cost_price - atr * self.atr_multiplier
|
||
|
||
if current_price < stop_price:
|
||
profit_pct = (current_price - cost_price) / cost_price
|
||
reason = f"ATR止损 {profit_pct*100:.2f}% (ATR={atr:.2f})"
|
||
# logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 止损价={stop_price:.2f}, 当前={current_price:.2f}") # 已移除
|
||
return True, reason
|
||
|
||
return False, ""
|
||
|
||
def _trailing_exit(self, ts_code: str, current_price: float, cost_price: float, high: Optional[float]) -> tuple[bool, str]:
|
||
"""跟踪止盈。
|
||
|
||
记录持仓期间的最高价,当价格从最高价回撤超过 trailing_pct 时卖出。
|
||
"""
|
||
# 使用当日最高价或当前价更新历史最高价
|
||
if high is not None:
|
||
price_to_track = max(current_price, high)
|
||
else:
|
||
price_to_track = current_price
|
||
|
||
if ts_code not in self.highest_price:
|
||
self.highest_price[ts_code] = price_to_track
|
||
else:
|
||
self.highest_price[ts_code] = max(self.highest_price[ts_code], price_to_track)
|
||
|
||
highest = self.highest_price[ts_code]
|
||
drawdown_from_high = (highest - current_price) / highest
|
||
|
||
# 只有在盈利的情况下才触发跟踪止盈
|
||
if current_price > cost_price and drawdown_from_high >= self.trailing_pct:
|
||
profit_pct = (current_price - cost_price) / cost_price
|
||
reason = f"跟踪止盈 从最高点{highest:.2f}回撤{drawdown_from_high*100:.2f}%,盈利{profit_pct*100:.2f}%"
|
||
# logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除
|
||
|
||
# 清除最高价记录
|
||
del self.highest_price[ts_code]
|
||
return True, reason
|
||
|
||
return False, ""
|
||
|
||
def reset_tracking(self, ts_code: str) -> None:
|
||
"""重置某只股票的跟踪状态(例如清仓后)。"""
|
||
if ts_code in self.highest_price:
|
||
del self.highest_price[ts_code]
|
||
|
||
def clear_all(self) -> None:
|
||
"""清空所有跟踪状态。"""
|
||
self.highest_price.clear()
|
||
|
||
|
||
def calculate_atr(df: pd.DataFrame, period: int = 14) -> pd.Series:
|
||
"""计算ATR指标(Average True Range)。
|
||
|
||
参数:
|
||
df: 包含 ['high', 'low', 'close'] 的DataFrame
|
||
period: ATR周期
|
||
|
||
返回:
|
||
pd.Series: ATR值
|
||
"""
|
||
high = df["high"]
|
||
low = df["low"]
|
||
close = df["close"]
|
||
|
||
# True Range = max(high-low, abs(high-prev_close), abs(low-prev_close))
|
||
tr1 = high - low
|
||
tr2 = (high - close.shift(1)).abs()
|
||
tr3 = (low - close.shift(1)).abs()
|
||
|
||
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
|
||
|
||
# ATR = MA(TR, period)
|
||
atr = tr.rolling(window=period).mean()
|
||
|
||
return atr
|