Files
backtrader/strategy/yiyang_strategy.py
2026-01-17 21:21:30 +08:00

356 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import pandas as pd
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
from .base_strategy import BaseStrategy
from .market_filter import MarketFilter
class YiYangStrategy(BaseStrategy):
"""一阳穿三线策略实现"""
def __init__(self, config: Dict):
super().__init__()
self.market_filter = MarketFilter(config)
# 使用主日志记录器,确保日志能被正确记录到文件
self.logger = logging.getLogger('backtrader')
# 策略参数
self.min_pct_change = config.get('min_pct_change', 3.0)
self.min_volume_ratio = config.get('min_volume_ratio', 1.5)
self.max_ma_gap_pct = config.get('max_ma_gap_pct', 5.0)
self.min_entity_ratio = config.get('min_entity_ratio', 0.6)
self.confirm_days = config.get('confirm_days', 3)
# 均线组合设置
self.ma_combinations = [
('ma5', 'ma10', 'ma20'), # 短线组合
('ma10', 'ma20', 'ma60') # 中线组合
]
def initialize(self, data):
"""
初始化策略,处理数据,计算指标等
Args:
data: 股票数据DataFrame
"""
# 初始化信号数据
self.signals = None
# 初始化市场过滤器
start_date = data.index.min().strftime('%Y%m%d') if hasattr(data.index.min(), 'strftime') else str(data.index.min())
end_date = data.index.max().strftime('%Y%m%d') if hasattr(data.index.max(), 'strftime') else str(data.index.max())
self.market_filter.initialize(start_date=start_date, end_date=end_date)
def generate_signals(self, data):
"""
生成交易信号(向量化优化版本)
Args:
data: 股票数据DataFrame
Returns:
pandas.DataFrame: 添加了信号列的DataFrame
"""
if data.empty or len(data) < 60:
return pd.DataFrame()
# 确保数据按日期从旧到新排序
data = data.sort_index()
# 使用向量化方法批量检查条件
signals = self._vectorized_signal_generation(data)
# 保存信号
self.signals = signals
return signals
def _vectorized_signal_generation(self, data: pd.DataFrame) -> pd.DataFrame:
"""向量化信号生成(核心优化)"""
n = len(data)
signals = pd.DataFrame(index=data.index)
signals['signal'] = 0
signals['strength'] = 0.0
signals['reason'] = ''
if n < 60:
return signals
# === 向量化基础条件检查 ===
# 1. 涨幅条件
pct_check = data['pct_change'].fillna(0) >= self.min_pct_change
# 2. 量能条件
vol_check = data['volume_ratio'].fillna(0) >= self.min_volume_ratio
# 3. 实体比例条件
entity_check = data['entity_ratio'].fillna(0) >= self.min_entity_ratio
# 4. 阳线条件
yang_check = data['close'] > data['open']
# 组合基础条件
basic_cond = pct_check & vol_check & entity_check & yang_check
# === 向量化均线穿越检查 ===
# 检查两组均线
ma_cross_1 = self._vectorized_ma_cross(data, ('ma5', 'ma10', 'ma20'))
ma_cross_2 = self._vectorized_ma_cross(data, ('ma10', 'ma20', 'ma60'))
ma_cross = ma_cross_1 | ma_cross_2
# === 向量化均线粘合检查 ===
ma_conv = self._vectorized_ma_convergence(data)
# 综合条件(暂不检查后续确认)
candidate_signals = basic_cond & ma_cross & ma_conv
# 对候选信号逐个进行后续确认(这部分难以向量化)
for idx in np.where(candidate_signals)[0]:
if idx < 60: # 跳过前60天
continue
# 后续确认检查
if self._check_confirmation(data, idx):
signals.iloc[idx, 0] = 1 # signal
signals.iloc[idx, 1] = self._calculate_signal_strength_vectorized(data, idx) # strength
signals.iloc[idx, 2] = '一阳穿三线' # reason
return signals
def _vectorized_ma_cross(self, data: pd.DataFrame, ma_combo: Tuple) -> pd.Series:
"""向量化均线穿越检查"""
ma1, ma2, ma3 = ma_combo
# 开盘价低于至少两条均线
open_below_1 = data['open'] < data[ma1]
open_below_2 = data['open'] < data[ma2]
open_below_3 = data['open'] < data[ma3]
open_below_count = open_below_1.astype(int) + open_below_2.astype(int) + open_below_3.astype(int)
# 收盘价高于至少两条均线
close_above_1 = data['close'] > data[ma1]
close_above_2 = data['close'] > data[ma2]
close_above_3 = data['close'] > data[ma3]
close_above_count = close_above_1.astype(int) + close_above_2.astype(int) + close_above_3.astype(int)
return (open_below_count >= 2) & (close_above_count >= 2)
def _vectorized_ma_convergence(self, data: pd.DataFrame) -> pd.Series:
"""向量化均线粘合度检查"""
# 计算三条均线的最大最小值
ma5 = data['ma5'].fillna(0)
ma10 = data['ma10'].fillna(0)
ma20 = data['ma20'].fillna(0)
# 三个均线的最大最小值
max_ma = pd.concat([ma5, ma10, ma20], axis=1).max(axis=1)
min_ma = pd.concat([ma5, ma10, ma20], axis=1).min(axis=1)
# 计算间距百分比
gap_pct = (max_ma - min_ma) / min_ma.replace(0, np.nan) * 100
return (gap_pct <= self.max_ma_gap_pct) & (min_ma > 0)
def _calculate_signal_strength_vectorized(self, data: pd.DataFrame, idx: int) -> float:
"""向量化计算信号强度"""
row = data.iloc[idx]
strength = 0.0
# 1. 涨幅贡献
strength += min(row.get('pct_change', 0) / 10.0, 0.3)
# 2. 量能贡献
strength += min((row.get('volume_ratio', 1) - 1) / 3.0, 0.3)
# 3. 均线粘合贡献
mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)]
if all(ma > 0 for ma in mas):
gap_pct = (max(mas) - min(mas)) / min(mas) * 100
strength += max(0, 0.2 * (1 - gap_pct / self.max_ma_gap_pct))
# 4. 位置贡献
ma60 = row.get('ma60', 0)
if ma60 > 0:
ratio = row['close'] / ma60
if ratio < 1.1:
strength += 0.2
elif ratio < 1.3:
strength += 0.1
return min(strength, 1.0)
def on_signal(self, date, signal, price):
"""
处理交易信号
Args:
date: 日期
signal: 信号类型 (1=买入, 0=持有, -1=卖出)
price: 价格
Returns:
dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int}
"""
if signal == 1:
# 买入信号
return {'action': 'buy', 'quantity': 100} # 示例买入100股
elif signal == -1:
# 卖出信号
return {'action': 'sell', 'quantity': 100} # 示例卖出100股
else:
# 持有
return {'action': 'hold', 'quantity': 0}
def _check_yiyang_signal(self, data: pd.DataFrame, idx: int) -> bool:
"""检查是否符合一阳穿三线条件"""
# 检查基本条件
if not self._check_basic_conditions(data, idx):
return False
# 检查是否穿越至少一组均线
signal_found = False
for ma_combo in self.ma_combinations:
if self._check_ma_cross(data, idx, ma_combo):
signal_found = True
break
if not signal_found:
return False
# 检查均线粘合度
if not self._check_ma_convergence(data, idx):
return False
# 暂时跳过大盘环境检查,以便测试策略核心逻辑
# if not self.market_filter.is_good_market():
# return False
# 检查后续确认
if not self._check_confirmation(data, idx):
return False
return True
def _check_basic_conditions(self, data: pd.DataFrame, idx: int) -> bool:
"""检查基础条件"""
row = data.iloc[idx]
prev_row = data.iloc[idx-1]
# 1. 涨幅条件 - 使用安全的get方法获取值
if row.get('pct_change', 0) < self.min_pct_change:
return False
# 2. 量能条件 - 使用安全的get方法获取值
if row.get('volume_ratio', 0) < self.min_volume_ratio:
return False
# 3. 阳线实体条件 - 使用安全的get方法获取值
if row.get('entity_ratio', 0) < self.min_entity_ratio:
return False
# 4. 必须是阳线
if row['close'] <= row['open']:
return False
return True
def _check_ma_cross(self, data: pd.DataFrame, idx: int, ma_combo: Tuple) -> bool:
"""检查是否穿越指定均线组合"""
ma1, ma2, ma3 = ma_combo
row = data.iloc[idx]
prev_row = data.iloc[idx-1]
# 放宽条件:开盘价低于至少两条均线
open_below_count = 0
if row['open'] < row[ma1]:
open_below_count += 1
if row['open'] < row[ma2]:
open_below_count += 1
if row['open'] < row[ma3]:
open_below_count += 1
# 放宽条件:收盘价高于至少两条均线
close_above_count = 0
if row['close'] > row[ma1]:
close_above_count += 1
if row['close'] > row[ma2]:
close_above_count += 1
if row['close'] > row[ma3]:
close_above_count += 1
return open_below_count >= 2 and close_above_count >= 2
def _check_ma_convergence(self, data: pd.DataFrame, idx: int) -> bool:
"""检查均线粘合度"""
row = data.iloc[idx]
# 计算三条短期均线的最大值和最小值 - 使用安全的get方法获取值
mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)]
# 确保ma值有效
if all(ma > 0 for ma in mas):
max_ma = max(mas)
min_ma = min(mas)
# 计算间距百分比
gap_pct = (max_ma - min_ma) / min_ma * 100
return gap_pct <= self.max_ma_gap_pct
return False # 如果均线值无效返回False
def _check_confirmation(self, data: pd.DataFrame, idx: int) -> bool:
"""检查后续确认"""
if idx + self.confirm_days >= len(data):
return True
# 获取基准均线值 - 使用安全的get方法获取值
base_ma20 = data.iloc[idx].get('ma20', 0)
if base_ma20 <= 0:
return False # 如果ma20无效无法进行确认检查
# 检查后续self.confirm_days天是否维持在均线之上
for i in range(1, min(self.confirm_days + 1, len(data) - idx)):
check_row = data.iloc[idx + i]
if check_row['close'] < base_ma20:
return False
return True
def _calculate_signal_strength(self, data: pd.DataFrame, idx: int) -> float:
"""计算信号强度(0-1)"""
strength = 0.0
row = data.iloc[idx]
# 1. 涨幅贡献 (0-0.3分) - 使用安全的get方法获取值
pct_strength = min(row.get('pct_change', 0) / 10.0, 0.3)
strength += pct_strength
# 2. 量能贡献 (0-0.3分) - 使用安全的get方法获取值
volume_strength = min((row.get('volume_ratio', 1) - 1) / 3.0, 0.3)
strength += volume_strength
# 3. 均线粘合贡献 (0-0.2分) - 使用安全的get方法获取值
mas = [row.get('ma5', 0), row.get('ma10', 0), row.get('ma20', 0)]
# 确保ma值有效
if all(ma > 0 for ma in mas):
max_ma = max(mas)
min_ma = min(mas)
gap_pct = (max_ma - min_ma) / min_ma * 100
ma_strength = max(0, 0.2 * (1 - gap_pct / self.max_ma_gap_pct))
strength += ma_strength
# 4. 位置贡献 (0-0.2分)
# 计算相对于60日均线的位置 - 使用安全的get方法获取值
ma60 = row.get('ma60', 0)
if ma60 > 0:
price_ma60_ratio = row['close'] / ma60
if price_ma60_ratio < 1.1: # 在60日线附近
position_strength = 0.2
elif price_ma60_ratio < 1.3: # 略高于60日线
position_strength = 0.1
else:
position_strength = 0.0
strength += position_strength
return min(strength, 1.0)