Files
backtrader/strategy/dual_moving_average_strategy.py
2026-01-17 21:21:30 +08:00

141 lines
5.1 KiB
Python

# -*- coding: utf-8 -*-
"""
双均线策略实现
当短期均线上穿长期均线时买入,当短期均线下穿长期均线时卖出
"""
import pandas as pd
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
from .base_strategy import BaseStrategy
class DualMovingAverageStrategy(BaseStrategy):
"""双均线策略实现"""
def __init__(self, config: Dict):
super().__init__()
# 使用主日志记录器,确保日志能被正确记录到文件
self.logger = logging.getLogger('backtrader')
# 策略参数
self.short_period = config.get('short_period', 5) # 短期均线周期
self.long_period = config.get('long_period', 20) # 长期均线周期
self.trade_quantity = config.get('trade_quantity', 100) # 每次交易数量
# 保存生成的信号
self.signals = None
self.logger.info(f'双均线策略初始化完成: 短期周期={self.short_period}, 长期周期={self.long_period}')
def initialize(self, data):
"""
初始化策略,处理数据,计算指标等
Args:
data: 股票数据DataFrame
"""
# 初始化信号数据
self.signals = None
# 计算移动平均线
self._calculate_moving_averages(data)
def generate_signals(self, data):
"""
生成交易信号
Args:
data: 股票数据DataFrame
Returns:
pandas.DataFrame: 添加了信号列的DataFrame
"""
if data.empty or len(data) < self.long_period:
return pd.DataFrame()
# 确保数据按日期从旧到新排序
data = data.sort_index()
# 计算移动平均线
self._calculate_moving_averages(data)
signals = pd.DataFrame(index=data.index)
signals['signal'] = 0 # 0:无信号, 1:买入信号, -1:卖出信号
signals['reason'] = '' # 信号原因
signals['short_ma'] = data[f'sma{self.short_period}']
signals['long_ma'] = data[f'sma{self.long_period}']
signals['strength'] = 0.0 # 信号强度
# 生成交易信号
for i in range(1, len(data)):
# 获取当前和前一天的均线值
short_ma_current = data[f'sma{self.short_period}'].iloc[i]
long_ma_current = data[f'sma{self.long_period}'].iloc[i]
short_ma_prev = data[f'sma{self.short_period}'].iloc[i-1]
long_ma_prev = data[f'sma{self.long_period}'].iloc[i-1]
signal = 0
strength = 0.0
reason = ''
# 短期均线上穿长期均线 - 买入信号
if short_ma_prev <= long_ma_prev and short_ma_current > long_ma_current:
signal = 1
strength = abs((short_ma_current - long_ma_current) / long_ma_current * 100)
reason = f'短期均线上穿长期均线 ({short_ma_current:.2f} > {long_ma_current:.2f})'
# 短期均线下穿长期均线 - 卖出信号
elif short_ma_prev >= long_ma_prev and short_ma_current < long_ma_current:
signal = -1
strength = abs((long_ma_current - short_ma_current) / long_ma_current * 100)
reason = f'短期均线下穿长期均线 ({short_ma_current:.2f} < {long_ma_current:.2f})'
# 记录信号
signals.loc[data.index[i], 'signal'] = signal
signals.loc[data.index[i], 'strength'] = strength
signals.loc[data.index[i], 'reason'] = reason
# 保存信号
self.signals = signals
return signals
def on_signal(self, date, signal, price):
"""
处理交易信号
Args:
date: 日期
signal: 信号类型 (1=买入, 0=持有, -1=卖出)
price: 价格
Returns:
dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int}
"""
if signal == 1:
# 买入信号
self.logger.info(f'{date} 买入信号: 价格={price:.2f}, 数量={self.trade_quantity}')
return {'action': 'buy', 'quantity': self.trade_quantity}
elif signal == -1:
# 卖出信号
self.logger.info(f'{date} 卖出信号: 价格={price:.2f}, 数量={self.trade_quantity}')
return {'action': 'sell', 'quantity': self.trade_quantity}
else:
# 持有
return {'action': 'hold', 'quantity': 0}
def _calculate_moving_averages(self, data):
"""
计算移动平均线
Args:
data: 股票数据DataFrame
"""
# 计算短期移动平均线
data[f'sma{self.short_period}'] = data['close'].rolling(window=self.short_period).mean()
# 计算长期移动平均线
data[f'sma{self.long_period}'] = data['close'].rolling(window=self.long_period).mean()
self.logger.debug(f'已计算移动平均线: SMA{self.short_period} 和 SMA{self.long_period}')