356 lines
13 KiB
Python
356 lines
13 KiB
Python
import pandas as pd
|
||
import numpy as np
|
||
import logging
|
||
from typing import Dict, List, Tuple, Optional
|
||
from .base_strategy import BaseStrategy
|
||
from .market_filter import MarketFilter
|
||
|
||
class YiYangStrategy(BaseStrategy):
|
||
"""一阳穿三线策略实现"""
|
||
|
||
def __init__(self, config: Dict):
|
||
super().__init__()
|
||
self.market_filter = MarketFilter(config)
|
||
# 使用主日志记录器,确保日志能被正确记录到文件
|
||
self.logger = logging.getLogger('backtrader')
|
||
|
||
# 策略参数
|
||
self.min_pct_change = config.get('min_pct_change', 3.0)
|
||
self.min_volume_ratio = config.get('min_volume_ratio', 1.5)
|
||
self.max_ma_gap_pct = config.get('max_ma_gap_pct', 5.0)
|
||
self.min_entity_ratio = config.get('min_entity_ratio', 0.6)
|
||
self.confirm_days = config.get('confirm_days', 3)
|
||
|
||
# 均线组合设置
|
||
self.ma_combinations = [
|
||
('ma5', 'ma10', 'ma20'), # 短线组合
|
||
('ma10', 'ma20', 'ma60') # 中线组合
|
||
]
|
||
|
||
def initialize(self, data):
|
||
"""
|
||
初始化策略,处理数据,计算指标等
|
||
|
||
Args:
|
||
data: 股票数据DataFrame
|
||
"""
|
||
# 初始化信号数据
|
||
self.signals = None
|
||
|
||
# 初始化市场过滤器
|
||
start_date = data.index.min().strftime('%Y%m%d') if hasattr(data.index.min(), 'strftime') else str(data.index.min())
|
||
end_date = data.index.max().strftime('%Y%m%d') if hasattr(data.index.max(), 'strftime') else str(data.index.max())
|
||
self.market_filter.initialize(start_date=start_date, end_date=end_date)
|
||
|
||
def generate_signals(self, data):
|
||
"""
|
||
生成交易信号(向量化优化版本)
|
||
Args:
|
||
data: 股票数据DataFrame
|
||
|
||
Returns:
|
||
pandas.DataFrame: 添加了信号列的DataFrame
|
||
"""
|
||
if data.empty or len(data) < 60:
|
||
return pd.DataFrame()
|
||
|
||
# 确保数据按日期从旧到新排序
|
||
data = data.sort_index()
|
||
|
||
# 使用向量化方法批量检查条件
|
||
signals = self._vectorized_signal_generation(data)
|
||
|
||
# 保存信号
|
||
self.signals = signals
|
||
return signals
|
||
|
||
def _vectorized_signal_generation(self, data: pd.DataFrame) -> pd.DataFrame:
|
||
"""向量化信号生成(核心优化)"""
|
||
n = len(data)
|
||
signals = pd.DataFrame(index=data.index)
|
||
signals['signal'] = 0
|
||
signals['strength'] = 0.0
|
||
signals['reason'] = ''
|
||
|
||
if n < 60:
|
||
return signals
|
||
|
||
# === 向量化基础条件检查 ===
|
||
# 1. 涨幅条件
|
||
pct_check = data['pct_change'].fillna(0) >= self.min_pct_change
|
||
|
||
# 2. 量能条件
|
||
vol_check = data['volume_ratio'].fillna(0) >= self.min_volume_ratio
|
||
|
||
# 3. 实体比例条件
|
||
entity_check = data['entity_ratio'].fillna(0) >= self.min_entity_ratio
|
||
|
||
# 4. 阳线条件
|
||
yang_check = data['close'] > data['open']
|
||
|
||
# 组合基础条件
|
||
basic_cond = pct_check & vol_check & entity_check & yang_check
|
||
|
||
# === 向量化均线穿越检查 ===
|
||
# 检查两组均线
|
||
ma_cross_1 = self._vectorized_ma_cross(data, ('ma5', 'ma10', 'ma20'))
|
||
ma_cross_2 = self._vectorized_ma_cross(data, ('ma10', 'ma20', 'ma60'))
|
||
ma_cross = ma_cross_1 | ma_cross_2
|
||
|
||
# === 向量化均线粘合检查 ===
|
||
ma_conv = self._vectorized_ma_convergence(data)
|
||
|
||
# 综合条件(暂不检查后续确认)
|
||
candidate_signals = basic_cond & ma_cross & ma_conv
|
||
|
||
# 对候选信号逐个进行后续确认(这部分难以向量化)
|
||
for idx in np.where(candidate_signals)[0]:
|
||
if idx < 60: # 跳过前60天
|
||
continue
|
||
|
||
# 后续确认检查
|
||
if self._check_confirmation(data, idx):
|
||
signals.iloc[idx, 0] = 1 # signal
|
||
signals.iloc[idx, 1] = self._calculate_signal_strength_vectorized(data, idx) # strength
|
||
signals.iloc[idx, 2] = '一阳穿三线' # reason
|
||
|
||
return signals
|
||
|
||
def _vectorized_ma_cross(self, data: pd.DataFrame, ma_combo: Tuple) -> pd.Series:
|
||
"""向量化均线穿越检查"""
|
||
ma1, ma2, ma3 = ma_combo
|
||
|
||
# 开盘价低于至少两条均线
|
||
open_below_1 = data['open'] < data[ma1]
|
||
open_below_2 = data['open'] < data[ma2]
|
||
open_below_3 = data['open'] < data[ma3]
|
||
open_below_count = open_below_1.astype(int) + open_below_2.astype(int) + open_below_3.astype(int)
|
||
|
||
# 收盘价高于至少两条均线
|
||
close_above_1 = data['close'] > data[ma1]
|
||
close_above_2 = data['close'] > data[ma2]
|
||
close_above_3 = data['close'] > data[ma3]
|
||
close_above_count = close_above_1.astype(int) + close_above_2.astype(int) + close_above_3.astype(int)
|
||
|
||
return (open_below_count >= 2) & (close_above_count >= 2)
|
||
|
||
def _vectorized_ma_convergence(self, data: pd.DataFrame) -> pd.Series:
|
||
"""向量化均线粘合度检查"""
|
||
# 计算三条均线的最大最小值
|
||
ma5 = data['ma5'].fillna(0)
|
||
ma10 = data['ma10'].fillna(0)
|
||
ma20 = data['ma20'].fillna(0)
|
||
|
||
# 三个均线的最大最小值
|
||
max_ma = pd.concat([ma5, ma10, ma20], axis=1).max(axis=1)
|
||
min_ma = pd.concat([ma5, ma10, ma20], axis=1).min(axis=1)
|
||
|
||
# 计算间距百分比
|
||
gap_pct = (max_ma - min_ma) / min_ma.replace(0, np.nan) * 100
|
||
|
||
return (gap_pct <= self.max_ma_gap_pct) & (min_ma > 0)
|
||
|
||
def _calculate_signal_strength_vectorized(self, data: pd.DataFrame, idx: int) -> float:
|
||
"""向量化计算信号强度"""
|
||
row = data.iloc[idx]
|
||
strength = 0.0
|
||
|
||
# 1. 涨幅贡献
|
||
strength += min(row.get('pct_change', 0) / 10.0, 0.3)
|
||
|
||
# 2. 量能贡献
|
||
strength += min((row.get('volume_ratio', 1) - 1) / 3.0, 0.3)
|
||
|
||
# 3. 均线粘合贡献
|
||
mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)]
|
||
if all(ma > 0 for ma in mas):
|
||
gap_pct = (max(mas) - min(mas)) / min(mas) * 100
|
||
strength += max(0, 0.2 * (1 - gap_pct / self.max_ma_gap_pct))
|
||
|
||
# 4. 位置贡献
|
||
ma60 = row.get('ma60', 0)
|
||
if ma60 > 0:
|
||
ratio = row['close'] / ma60
|
||
if ratio < 1.1:
|
||
strength += 0.2
|
||
elif ratio < 1.3:
|
||
strength += 0.1
|
||
|
||
return min(strength, 1.0)
|
||
|
||
def on_signal(self, date, signal, price):
|
||
"""
|
||
处理交易信号
|
||
|
||
Args:
|
||
date: 日期
|
||
signal: 信号类型 (1=买入, 0=持有, -1=卖出)
|
||
price: 价格
|
||
|
||
Returns:
|
||
dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int}
|
||
"""
|
||
if signal == 1:
|
||
# 买入信号
|
||
return {'action': 'buy', 'quantity': 100} # 示例:买入100股
|
||
elif signal == -1:
|
||
# 卖出信号
|
||
return {'action': 'sell', 'quantity': 100} # 示例:卖出100股
|
||
else:
|
||
# 持有
|
||
return {'action': 'hold', 'quantity': 0}
|
||
|
||
def _check_yiyang_signal(self, data: pd.DataFrame, idx: int) -> bool:
|
||
"""检查是否符合一阳穿三线条件"""
|
||
# 检查基本条件
|
||
if not self._check_basic_conditions(data, idx):
|
||
return False
|
||
|
||
# 检查是否穿越至少一组均线
|
||
signal_found = False
|
||
for ma_combo in self.ma_combinations:
|
||
if self._check_ma_cross(data, idx, ma_combo):
|
||
signal_found = True
|
||
break
|
||
|
||
if not signal_found:
|
||
return False
|
||
|
||
# 检查均线粘合度
|
||
if not self._check_ma_convergence(data, idx):
|
||
return False
|
||
|
||
# 暂时跳过大盘环境检查,以便测试策略核心逻辑
|
||
# if not self.market_filter.is_good_market():
|
||
# return False
|
||
|
||
# 检查后续确认
|
||
if not self._check_confirmation(data, idx):
|
||
return False
|
||
|
||
return True
|
||
|
||
def _check_basic_conditions(self, data: pd.DataFrame, idx: int) -> bool:
|
||
"""检查基础条件"""
|
||
row = data.iloc[idx]
|
||
prev_row = data.iloc[idx-1]
|
||
|
||
# 1. 涨幅条件 - 使用安全的get方法获取值
|
||
if row.get('pct_change', 0) < self.min_pct_change:
|
||
return False
|
||
|
||
# 2. 量能条件 - 使用安全的get方法获取值
|
||
if row.get('volume_ratio', 0) < self.min_volume_ratio:
|
||
return False
|
||
|
||
# 3. 阳线实体条件 - 使用安全的get方法获取值
|
||
if row.get('entity_ratio', 0) < self.min_entity_ratio:
|
||
return False
|
||
|
||
# 4. 必须是阳线
|
||
if row['close'] <= row['open']:
|
||
return False
|
||
|
||
return True
|
||
|
||
def _check_ma_cross(self, data: pd.DataFrame, idx: int, ma_combo: Tuple) -> bool:
|
||
"""检查是否穿越指定均线组合"""
|
||
ma1, ma2, ma3 = ma_combo
|
||
|
||
row = data.iloc[idx]
|
||
prev_row = data.iloc[idx-1]
|
||
|
||
# 放宽条件:开盘价低于至少两条均线
|
||
open_below_count = 0
|
||
if row['open'] < row[ma1]:
|
||
open_below_count += 1
|
||
if row['open'] < row[ma2]:
|
||
open_below_count += 1
|
||
if row['open'] < row[ma3]:
|
||
open_below_count += 1
|
||
|
||
# 放宽条件:收盘价高于至少两条均线
|
||
close_above_count = 0
|
||
if row['close'] > row[ma1]:
|
||
close_above_count += 1
|
||
if row['close'] > row[ma2]:
|
||
close_above_count += 1
|
||
if row['close'] > row[ma3]:
|
||
close_above_count += 1
|
||
|
||
return open_below_count >= 2 and close_above_count >= 2
|
||
|
||
def _check_ma_convergence(self, data: pd.DataFrame, idx: int) -> bool:
|
||
"""检查均线粘合度"""
|
||
row = data.iloc[idx]
|
||
|
||
# 计算三条短期均线的最大值和最小值 - 使用安全的get方法获取值
|
||
mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)]
|
||
|
||
# 确保ma值有效
|
||
if all(ma > 0 for ma in mas):
|
||
max_ma = max(mas)
|
||
min_ma = min(mas)
|
||
|
||
# 计算间距百分比
|
||
gap_pct = (max_ma - min_ma) / min_ma * 100
|
||
|
||
return gap_pct <= self.max_ma_gap_pct
|
||
|
||
return False # 如果均线值无效,返回False
|
||
|
||
def _check_confirmation(self, data: pd.DataFrame, idx: int) -> bool:
|
||
"""检查后续确认"""
|
||
if idx + self.confirm_days >= len(data):
|
||
return True
|
||
|
||
# 获取基准均线值 - 使用安全的get方法获取值
|
||
base_ma20 = data.iloc[idx].get('ma20', 0)
|
||
if base_ma20 <= 0:
|
||
return False # 如果ma20无效,无法进行确认检查
|
||
|
||
# 检查后续self.confirm_days天是否维持在均线之上
|
||
for i in range(1, min(self.confirm_days + 1, len(data) - idx)):
|
||
check_row = data.iloc[idx + i]
|
||
if check_row['close'] < base_ma20:
|
||
return False
|
||
|
||
return True
|
||
|
||
def _calculate_signal_strength(self, data: pd.DataFrame, idx: int) -> float:
|
||
"""计算信号强度(0-1)"""
|
||
strength = 0.0
|
||
row = data.iloc[idx]
|
||
|
||
# 1. 涨幅贡献 (0-0.3分) - 使用安全的get方法获取值
|
||
pct_strength = min(row.get('pct_change', 0) / 10.0, 0.3)
|
||
strength += pct_strength
|
||
|
||
# 2. 量能贡献 (0-0.3分) - 使用安全的get方法获取值
|
||
volume_strength = min((row.get('volume_ratio', 1) - 1) / 3.0, 0.3)
|
||
strength += volume_strength
|
||
|
||
# 3. 均线粘合贡献 (0-0.2分) - 使用安全的get方法获取值
|
||
mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)]
|
||
# 确保ma值有效
|
||
if all(ma > 0 for ma in mas):
|
||
max_ma = max(mas)
|
||
min_ma = min(mas)
|
||
gap_pct = (max_ma - min_ma) / min_ma * 100
|
||
ma_strength = max(0, 0.2 * (1 - gap_pct / self.max_ma_gap_pct))
|
||
strength += ma_strength
|
||
|
||
# 4. 位置贡献 (0-0.2分)
|
||
# 计算相对于60日均线的位置 - 使用安全的get方法获取值
|
||
ma60 = row.get('ma60', 0)
|
||
if ma60 > 0:
|
||
price_ma60_ratio = row['close'] / ma60
|
||
if price_ma60_ratio < 1.1: # 在60日线附近
|
||
position_strength = 0.2
|
||
elif price_ma60_ratio < 1.3: # 略高于60日线
|
||
position_strength = 0.1
|
||
else:
|
||
position_strength = 0.0
|
||
strength += position_strength
|
||
|
||
return min(strength, 1.0) |