Files
strategy_backtest/risk/stop_loss.py

201 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""止盈止损模块。
支持多种止损策略:
- fixed_pct: 固定百分比止损/止盈
- atr: ATR倍数止损
- trailing: 跟踪止盈(最高价回撤)
使用方法:
stop_loss = StopLoss(method="fixed_pct", stop_pct=0.05)
take_profit = StopLoss(method="fixed_pct", stop_pct=0.15)
# 在策略的 on_bar 中检查
if stop_loss.should_exit(ts_code, high, low, close, cost_price):
# 执行卖出
"""
from __future__ import annotations
from typing import Dict, Optional
import pandas as pd
from utils.logger import setup_logger
logger = setup_logger(__name__)
class StopLoss:
"""止盈止损管理器。
支持多种方法:
- fixed_pct: 固定百分比
- atr: ATR倍数
- trailing: 跟踪止盈
"""
def __init__(
self,
method: str = "fixed_pct",
stop_pct: float = 0.05,
atr_multiplier: float = 2.0,
atr_period: int = 14,
trailing_pct: float = 0.10,
):
"""初始化止损管理器。
参数:
method: 止损方法 ("fixed_pct", "atr", "trailing")
stop_pct: 固定百分比止损/止盈比例0.05 表示 5%
atr_multiplier: ATR倍数
atr_period: ATR周期
trailing_pct: 跟踪止盈回撤比例
"""
self.method = method
self.stop_pct = stop_pct
self.atr_multiplier = atr_multiplier
self.atr_period = atr_period
self.trailing_pct = trailing_pct
# 跟踪止盈需要记录每只股票的最高价
self.highest_price: Dict[str, float] = {}
def should_exit(
self,
ts_code: str,
current_price: float,
cost_price: float,
high: Optional[float] = None,
low: Optional[float] = None,
atr: Optional[float] = None,
) -> tuple[bool, str]:
"""判断是否应该止损/止盈。
参数:
ts_code: 股票代码
current_price: 当前价格(通常是收盘价)
cost_price: 成本价
high: 当日最高价(用于跟踪止盈)
low: 当日最低价(用于止损)
atr: ATR指标值用于ATR止损
返回:
(bool, str): (是否应该退出, 原因)
"""
if self.method == "fixed_pct":
return self._fixed_pct_exit(ts_code, current_price, cost_price)
elif self.method == "atr":
return self._atr_exit(ts_code, current_price, cost_price, atr)
elif self.method == "trailing":
return self._trailing_exit(ts_code, current_price, cost_price, high)
else:
logger.warning(f"未知的止损方法: {self.method}")
return False, ""
def _fixed_pct_exit(self, ts_code: str, current_price: float, cost_price: float) -> tuple[bool, str]:
"""固定百分比止损/止盈。
止损:当前价格 < 成本价 * (1 - stop_pct)
止盈:当前价格 > 成本价 * (1 + stop_pct)
"""
profit_pct = (current_price - cost_price) / cost_price
if profit_pct <= -self.stop_pct:
reason = f"固定止损 {profit_pct*100:.2f}%"
# logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除:止损日志
return True, reason
# 注意:这里的 stop_pct 实际用作止盈比例
# 如果需要区分止损和止盈,可以增加一个 take_profit_pct 参数
if profit_pct >= self.stop_pct:
reason = f"固定止盈 {profit_pct*100:.2f}%"
# logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除:止盈日志
return True, reason
return False, ""
def _atr_exit(self, ts_code: str, current_price: float, cost_price: float, atr: Optional[float]) -> tuple[bool, str]:
"""ATR倍数止损。
止损:当前价格 < 成本价 - ATR * multiplier
"""
if atr is None or atr <= 0:
logger.warning(f"{ts_code} ATR值无效跳过ATR止损")
return False, ""
stop_price = cost_price - atr * self.atr_multiplier
if current_price < stop_price:
profit_pct = (current_price - cost_price) / cost_price
reason = f"ATR止损 {profit_pct*100:.2f}% (ATR={atr:.2f})"
# logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 止损价={stop_price:.2f}, 当前={current_price:.2f}") # 已移除
return True, reason
return False, ""
def _trailing_exit(self, ts_code: str, current_price: float, cost_price: float, high: Optional[float]) -> tuple[bool, str]:
"""跟踪止盈。
记录持仓期间的最高价,当价格从最高价回撤超过 trailing_pct 时卖出。
"""
# 使用当日最高价或当前价更新历史最高价
if high is not None:
price_to_track = max(current_price, high)
else:
price_to_track = current_price
if ts_code not in self.highest_price:
self.highest_price[ts_code] = price_to_track
else:
self.highest_price[ts_code] = max(self.highest_price[ts_code], price_to_track)
highest = self.highest_price[ts_code]
drawdown_from_high = (highest - current_price) / highest
# 只有在盈利的情况下才触发跟踪止盈
if current_price > cost_price and drawdown_from_high >= self.trailing_pct:
profit_pct = (current_price - cost_price) / cost_price
reason = f"跟踪止盈 从最高点{highest:.2f}回撤{drawdown_from_high*100:.2f}%,盈利{profit_pct*100:.2f}%"
# logger.info(f"[STOP] {ts_code} 触发{reason},成本={cost_price:.2f}, 当前={current_price:.2f}") # 已移除
# 清除最高价记录
del self.highest_price[ts_code]
return True, reason
return False, ""
def reset_tracking(self, ts_code: str) -> None:
"""重置某只股票的跟踪状态(例如清仓后)。"""
if ts_code in self.highest_price:
del self.highest_price[ts_code]
def clear_all(self) -> None:
"""清空所有跟踪状态。"""
self.highest_price.clear()
def calculate_atr(df: pd.DataFrame, period: int = 14) -> pd.Series:
"""计算ATR指标Average True Range
参数:
df: 包含 ['high', 'low', 'close'] 的DataFrame
period: ATR周期
返回:
pd.Series: ATR值
"""
high = df["high"]
low = df["low"]
close = df["close"]
# True Range = max(high-low, abs(high-prev_close), abs(low-prev_close))
tr1 = high - low
tr2 = (high - close.shift(1)).abs()
tr3 = (low - close.shift(1)).abs()
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
# ATR = MA(TR, period)
atr = tr.rolling(window=period).mean()
return atr