import pandas as pd import numpy as np import logging from typing import Dict, List, Tuple, Optional from .base_strategy import BaseStrategy from .market_filter import MarketFilter class YiYangStrategy(BaseStrategy): """一阳穿三线策略实现""" def __init__(self, config: Dict): super().__init__() self.market_filter = MarketFilter(config) # 使用主日志记录器,确保日志能被正确记录到文件 self.logger = logging.getLogger('backtrader') # 策略参数 self.min_pct_change = config.get('min_pct_change', 3.0) self.min_volume_ratio = config.get('min_volume_ratio', 1.5) self.max_ma_gap_pct = config.get('max_ma_gap_pct', 5.0) self.min_entity_ratio = config.get('min_entity_ratio', 0.6) self.confirm_days = config.get('confirm_days', 3) # 均线组合设置 self.ma_combinations = [ ('ma5', 'ma10', 'ma20'), # 短线组合 ('ma10', 'ma20', 'ma60') # 中线组合 ] def initialize(self, data): """ 初始化策略,处理数据,计算指标等 Args: data: 股票数据DataFrame """ # 初始化信号数据 self.signals = None # 初始化市场过滤器 start_date = data.index.min().strftime('%Y%m%d') if hasattr(data.index.min(), 'strftime') else str(data.index.min()) end_date = data.index.max().strftime('%Y%m%d') if hasattr(data.index.max(), 'strftime') else str(data.index.max()) self.market_filter.initialize(start_date=start_date, end_date=end_date) def generate_signals(self, data): """ 生成交易信号(向量化优化版本) Args: data: 股票数据DataFrame Returns: pandas.DataFrame: 添加了信号列的DataFrame """ if data.empty or len(data) < 60: return pd.DataFrame() # 确保数据按日期从旧到新排序 data = data.sort_index() # 使用向量化方法批量检查条件 signals = self._vectorized_signal_generation(data) # 保存信号 self.signals = signals return signals def _vectorized_signal_generation(self, data: pd.DataFrame) -> pd.DataFrame: """向量化信号生成(核心优化)""" n = len(data) signals = pd.DataFrame(index=data.index) signals['signal'] = 0 signals['strength'] = 0.0 signals['reason'] = '' if n < 60: return signals # === 向量化基础条件检查 === # 1. 涨幅条件 pct_check = data['pct_change'].fillna(0) >= self.min_pct_change # 2. 量能条件 vol_check = data['volume_ratio'].fillna(0) >= self.min_volume_ratio # 3. 实体比例条件 entity_check = data['entity_ratio'].fillna(0) >= self.min_entity_ratio # 4. 阳线条件 yang_check = data['close'] > data['open'] # 组合基础条件 basic_cond = pct_check & vol_check & entity_check & yang_check # === 向量化均线穿越检查 === # 检查两组均线 ma_cross_1 = self._vectorized_ma_cross(data, ('ma5', 'ma10', 'ma20')) ma_cross_2 = self._vectorized_ma_cross(data, ('ma10', 'ma20', 'ma60')) ma_cross = ma_cross_1 | ma_cross_2 # === 向量化均线粘合检查 === ma_conv = self._vectorized_ma_convergence(data) # 综合条件(暂不检查后续确认) candidate_signals = basic_cond & ma_cross & ma_conv # 对候选信号逐个进行后续确认(这部分难以向量化) for idx in np.where(candidate_signals)[0]: if idx < 60: # 跳过前60天 continue # 后续确认检查 if self._check_confirmation(data, idx): signals.iloc[idx, 0] = 1 # signal signals.iloc[idx, 1] = self._calculate_signal_strength_vectorized(data, idx) # strength signals.iloc[idx, 2] = '一阳穿三线' # reason return signals def _vectorized_ma_cross(self, data: pd.DataFrame, ma_combo: Tuple) -> pd.Series: """向量化均线穿越检查""" ma1, ma2, ma3 = ma_combo # 开盘价低于至少两条均线 open_below_1 = data['open'] < data[ma1] open_below_2 = data['open'] < data[ma2] open_below_3 = data['open'] < data[ma3] open_below_count = open_below_1.astype(int) + open_below_2.astype(int) + open_below_3.astype(int) # 收盘价高于至少两条均线 close_above_1 = data['close'] > data[ma1] close_above_2 = data['close'] > data[ma2] close_above_3 = data['close'] > data[ma3] close_above_count = close_above_1.astype(int) + close_above_2.astype(int) + close_above_3.astype(int) return (open_below_count >= 2) & (close_above_count >= 2) def _vectorized_ma_convergence(self, data: pd.DataFrame) -> pd.Series: """向量化均线粘合度检查""" # 计算三条均线的最大最小值 ma5 = data['ma5'].fillna(0) ma10 = data['ma10'].fillna(0) ma20 = data['ma20'].fillna(0) # 三个均线的最大最小值 max_ma = pd.concat([ma5, ma10, ma20], axis=1).max(axis=1) min_ma = pd.concat([ma5, ma10, ma20], axis=1).min(axis=1) # 计算间距百分比 gap_pct = (max_ma - min_ma) / min_ma.replace(0, np.nan) * 100 return (gap_pct <= self.max_ma_gap_pct) & (min_ma > 0) def _calculate_signal_strength_vectorized(self, data: pd.DataFrame, idx: int) -> float: """向量化计算信号强度""" row = data.iloc[idx] strength = 0.0 # 1. 涨幅贡献 strength += min(row.get('pct_change', 0) / 10.0, 0.3) # 2. 量能贡献 strength += min((row.get('volume_ratio', 1) - 1) / 3.0, 0.3) # 3. 均线粘合贡献 mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)] if all(ma > 0 for ma in mas): gap_pct = (max(mas) - min(mas)) / min(mas) * 100 strength += max(0, 0.2 * (1 - gap_pct / self.max_ma_gap_pct)) # 4. 位置贡献 ma60 = row.get('ma60', 0) if ma60 > 0: ratio = row['close'] / ma60 if ratio < 1.1: strength += 0.2 elif ratio < 1.3: strength += 0.1 return min(strength, 1.0) def on_signal(self, date, signal, price): """ 处理交易信号 Args: date: 日期 signal: 信号类型 (1=买入, 0=持有, -1=卖出) price: 价格 Returns: dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int} """ if signal == 1: # 买入信号 return {'action': 'buy', 'quantity': 100} # 示例:买入100股 elif signal == -1: # 卖出信号 return {'action': 'sell', 'quantity': 100} # 示例:卖出100股 else: # 持有 return {'action': 'hold', 'quantity': 0} def _check_yiyang_signal(self, data: pd.DataFrame, idx: int) -> bool: """检查是否符合一阳穿三线条件""" # 检查基本条件 if not self._check_basic_conditions(data, idx): return False # 检查是否穿越至少一组均线 signal_found = False for ma_combo in self.ma_combinations: if self._check_ma_cross(data, idx, ma_combo): signal_found = True break if not signal_found: return False # 检查均线粘合度 if not self._check_ma_convergence(data, idx): return False # 暂时跳过大盘环境检查,以便测试策略核心逻辑 # if not self.market_filter.is_good_market(): # return False # 检查后续确认 if not self._check_confirmation(data, idx): return False return True def _check_basic_conditions(self, data: pd.DataFrame, idx: int) -> bool: """检查基础条件""" row = data.iloc[idx] prev_row = data.iloc[idx-1] # 1. 涨幅条件 - 使用安全的get方法获取值 if row.get('pct_change', 0) < self.min_pct_change: return False # 2. 量能条件 - 使用安全的get方法获取值 if row.get('volume_ratio', 0) < self.min_volume_ratio: return False # 3. 阳线实体条件 - 使用安全的get方法获取值 if row.get('entity_ratio', 0) < self.min_entity_ratio: return False # 4. 必须是阳线 if row['close'] <= row['open']: return False return True def _check_ma_cross(self, data: pd.DataFrame, idx: int, ma_combo: Tuple) -> bool: """检查是否穿越指定均线组合""" ma1, ma2, ma3 = ma_combo row = data.iloc[idx] prev_row = data.iloc[idx-1] # 放宽条件:开盘价低于至少两条均线 open_below_count = 0 if row['open'] < row[ma1]: open_below_count += 1 if row['open'] < row[ma2]: open_below_count += 1 if row['open'] < row[ma3]: open_below_count += 1 # 放宽条件:收盘价高于至少两条均线 close_above_count = 0 if row['close'] > row[ma1]: close_above_count += 1 if row['close'] > row[ma2]: close_above_count += 1 if row['close'] > row[ma3]: close_above_count += 1 return open_below_count >= 2 and close_above_count >= 2 def _check_ma_convergence(self, data: pd.DataFrame, idx: int) -> bool: """检查均线粘合度""" row = data.iloc[idx] # 计算三条短期均线的最大值和最小值 - 使用安全的get方法获取值 mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)] # 确保ma值有效 if all(ma > 0 for ma in mas): max_ma = max(mas) min_ma = min(mas) # 计算间距百分比 gap_pct = (max_ma - min_ma) / min_ma * 100 return gap_pct <= self.max_ma_gap_pct return False # 如果均线值无效,返回False def _check_confirmation(self, data: pd.DataFrame, idx: int) -> bool: """检查后续确认""" if idx + self.confirm_days >= len(data): return True # 获取基准均线值 - 使用安全的get方法获取值 base_ma20 = data.iloc[idx].get('ma20', 0) if base_ma20 <= 0: return False # 如果ma20无效,无法进行确认检查 # 检查后续self.confirm_days天是否维持在均线之上 for i in range(1, min(self.confirm_days + 1, len(data) - idx)): check_row = data.iloc[idx + i] if check_row['close'] < base_ma20: return False return True def _calculate_signal_strength(self, data: pd.DataFrame, idx: int) -> float: """计算信号强度(0-1)""" strength = 0.0 row = data.iloc[idx] # 1. 涨幅贡献 (0-0.3分) - 使用安全的get方法获取值 pct_strength = min(row.get('pct_change', 0) / 10.0, 0.3) strength += pct_strength # 2. 量能贡献 (0-0.3分) - 使用安全的get方法获取值 volume_strength = min((row.get('volume_ratio', 1) - 1) / 3.0, 0.3) strength += volume_strength # 3. 均线粘合贡献 (0-0.2分) - 使用安全的get方法获取值 mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)] # 确保ma值有效 if all(ma > 0 for ma in mas): max_ma = max(mas) min_ma = min(mas) gap_pct = (max_ma - min_ma) / min_ma * 100 ma_strength = max(0, 0.2 * (1 - gap_pct / self.max_ma_gap_pct)) strength += ma_strength # 4. 位置贡献 (0-0.2分) # 计算相对于60日均线的位置 - 使用安全的get方法获取值 ma60 = row.get('ma60', 0) if ma60 > 0: price_ma60_ratio = row['close'] / ma60 if price_ma60_ratio < 1.1: # 在60日线附近 position_strength = 0.2 elif price_ma60_ratio < 1.3: # 略高于60日线 position_strength = 0.1 else: position_strength = 0.0 strength += position_strength return min(strength, 1.0)