Files
strategy_backtest/risk/position_sizing.py

264 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""持仓规模优化模块。
支持多种仓位管理方法:
- equal_weight: 等权分配(默认)
- kelly: Kelly公式需要历史胜率和盈亏比
- volatility_target: 波动率目标(根据股票波动率调整仓位)
使用方法:
position_sizer = PositionSizing(method="equal_weight", max_positions=2)
# 在策略中使用
if position_sizer.can_open(ts_code, cash, price):
shares = position_sizer.calc_shares(ts_code, cash, price, df)
"""
from __future__ import annotations
from typing import Dict, Optional
import pandas as pd
from utils.logger import setup_logger
logger = setup_logger(__name__)
class PositionSizing:
"""持仓规模优化管理器。
支持多种方法:
- equal_weight: 等权分配
- kelly: Kelly公式
- volatility_target: 波动率目标
"""
def __init__(
self,
method: str = "equal_weight",
max_positions: int = 2,
kelly_risk_free: float = 0.03,
kelly_max_fraction: float = 0.25,
volatility_target: float = 0.15,
volatility_window: int = 20,
):
"""初始化持仓规模管理器。
参数:
method: 仓位管理方法 ("equal_weight", "kelly", "volatility_target")
max_positions: 最大持仓数
kelly_risk_free: Kelly公式的无风险利率
kelly_max_fraction: Kelly公式的最大仓位比例防止过度杠杆
volatility_target: 目标波动率(年化)
volatility_window: 计算波动率的窗口期
"""
self.method = method
self.max_positions = max_positions
self.kelly_risk_free = kelly_risk_free
self.kelly_max_fraction = kelly_max_fraction
self.volatility_target = volatility_target
self.volatility_window = volatility_window
# 记录每只股票的历史交易统计用于Kelly公式
self.trade_stats: Dict[str, Dict] = {}
def can_open(self, ts_code: str, cash: float, price: float, current_positions: int) -> bool:
"""判断是否可以开新仓位。
参数:
ts_code: 股票代码
cash: 当前可用现金
price: 当前价格
current_positions: 当前持仓数量
返回:
bool: 是否可以开仓
"""
# 检查仓位是否已满
if current_positions >= self.max_positions:
return False
# 检查资金是否足够买入至少 100 股
min_cost = price * 100
if cash < min_cost:
return False
return True
def calc_shares(
self,
ts_code: str,
cash: float,
price: float,
remain_slots: int,
df: Optional[pd.DataFrame] = None,
) -> int:
"""计算应该买入的股数。
参数:
ts_code: 股票代码
cash: 当前可用现金
price: 当前价格
remain_slots: 剩余可用仓位数
df: 股票历史数据(用于计算波动率等指标)
返回:
int: 买入股数已取整到100的整数倍
"""
if self.method == "equal_weight":
return self._equal_weight_shares(cash, price, remain_slots)
elif self.method == "kelly":
return self._kelly_shares(ts_code, cash, price, remain_slots)
elif self.method == "volatility_target":
return self._volatility_target_shares(ts_code, cash, price, remain_slots, df)
else:
logger.warning(f"未知的仓位管理方法: {self.method},使用等权分配")
return self._equal_weight_shares(cash, price, remain_slots)
def _equal_weight_shares(self, cash: float, price: float, remain_slots: int) -> int:
"""等权分配:平均分配现金到剩余仓位。
例如现金100万剩余2个仓位每个仓位分配50万
"""
if remain_slots <= 0:
return 0
cash_per_stock = cash / remain_slots
shares = int(cash_per_stock // price)
# A股规则向下取整到100的整数倍
shares = (shares // 100) * 100
return shares
def _kelly_shares(self, ts_code: str, cash: float, price: float, remain_slots: int) -> int:
"""Kelly公式根据历史胜率和盈亏比计算最优仓位。
Kelly% = (胜率 * 盈亏比 - 败率) / 盈亏比
注意:需要积累一定的交易历史才能准确计算。
如果没有历史数据,回退到等权分配。
"""
stats = self.trade_stats.get(ts_code)
if stats is None or stats.get("total_trades", 0) < 10:
# 交易次数不足,使用等权分配
logger.debug(f"{ts_code} 历史交易不足,使用等权分配")
return self._equal_weight_shares(cash, price, remain_slots)
win_rate = stats.get("win_rate", 0.5)
avg_win = stats.get("avg_win", 0.05)
avg_loss = stats.get("avg_loss", 0.05)
if avg_loss <= 0:
avg_loss = 0.01 # 避免除零
profit_loss_ratio = avg_win / avg_loss
kelly_fraction = (win_rate * profit_loss_ratio - (1 - win_rate)) / profit_loss_ratio
# Kelly公式可能给出负值或过大值需要限制
kelly_fraction = max(0, min(kelly_fraction, self.kelly_max_fraction))
# 考虑剩余仓位数,分配资金
cash_per_stock = cash / remain_slots
kelly_cash = cash_per_stock * kelly_fraction
shares = int(kelly_cash // price)
shares = (shares // 100) * 100
logger.debug(
f"{ts_code} Kelly仓位: {kelly_fraction*100:.2f}%, "
f"胜率={win_rate*100:.1f}%, 盈亏比={profit_loss_ratio:.2f}"
)
return shares
def _volatility_target_shares(
self,
ts_code: str,
cash: float,
price: float,
remain_slots: int,
df: Optional[pd.DataFrame],
) -> int:
"""波动率目标:根据股票波动率调整仓位。
波动率高的股票减小仓位,波动率低的股票增大仓位。
目标:使每个仓位的波动率贡献接近目标波动率。
"""
if df is None or df.empty:
logger.debug(f"{ts_code} 缺少历史数据,使用等权分配")
return self._equal_weight_shares(cash, price, remain_slots)
# 计算历史波动率(日收益率标准差 * sqrt(252)
if "close" not in df.columns or len(df) < self.volatility_window:
logger.debug(f"{ts_code} 数据不足,使用等权分配")
return self._equal_weight_shares(cash, price, remain_slots)
returns = df["close"].pct_change().dropna()
if len(returns) < self.volatility_window:
logger.debug(f"{ts_code} 数据不足,使用等权分配")
return self._equal_weight_shares(cash, price, remain_slots)
# 使用最近 volatility_window 天的数据计算波动率
recent_volatility = returns.tail(self.volatility_window).std() * (252 ** 0.5)
if recent_volatility <= 0:
logger.debug(f"{ts_code} 波动率为0使用等权分配")
return self._equal_weight_shares(cash, price, remain_slots)
# 仓位调整因子 = 目标波动率 / 实际波动率
volatility_factor = self.volatility_target / recent_volatility
volatility_factor = max(0.5, min(volatility_factor, 2.0)) # 限制在 [0.5, 2.0]
# 基础等权分配 * 波动率因子
base_shares = self._equal_weight_shares(cash, price, remain_slots)
adjusted_shares = int(base_shares * volatility_factor)
adjusted_shares = (adjusted_shares // 100) * 100
logger.debug(
f"{ts_code} 波动率调整: 实际={recent_volatility*100:.2f}%, "
f"目标={self.volatility_target*100:.2f}%, 因子={volatility_factor:.2f}"
)
return adjusted_shares
def update_trade_stats(self, ts_code: str, profit_pct: float) -> None:
"""更新交易统计用于Kelly公式
参数:
ts_code: 股票代码
profit_pct: 本次交易盈亏比例(正数为盈利,负数为亏损)
"""
if ts_code not in self.trade_stats:
self.trade_stats[ts_code] = {
"total_trades": 0,
"wins": 0,
"losses": 0,
"total_win": 0.0,
"total_loss": 0.0,
}
stats = self.trade_stats[ts_code]
stats["total_trades"] += 1
if profit_pct > 0:
stats["wins"] += 1
stats["total_win"] += profit_pct
else:
stats["losses"] += 1
stats["total_loss"] += abs(profit_pct)
# 计算平均值
stats["win_rate"] = stats["wins"] / stats["total_trades"]
stats["avg_win"] = stats["total_win"] / stats["wins"] if stats["wins"] > 0 else 0.05
stats["avg_loss"] = stats["total_loss"] / stats["losses"] if stats["losses"] > 0 else 0.05
def get_stats(self, ts_code: str) -> Optional[Dict]:
"""获取某只股票的交易统计。"""
return self.trade_stats.get(ts_code)
def clear_stats(self) -> None:
"""清空所有交易统计。"""
self.trade_stats.clear()