# -*- coding: utf-8 -*- """ 双均线策略实现 当短期均线上穿长期均线时买入,当短期均线下穿长期均线时卖出 """ import pandas as pd import numpy as np import logging from typing import Dict, List, Tuple, Optional from .base_strategy import BaseStrategy class DualMovingAverageStrategy(BaseStrategy): """双均线策略实现""" def __init__(self, config: Dict): super().__init__() # 使用主日志记录器,确保日志能被正确记录到文件 self.logger = logging.getLogger('backtrader') # 策略参数 self.short_period = config.get('short_period', 5) # 短期均线周期 self.long_period = config.get('long_period', 20) # 长期均线周期 self.trade_quantity = config.get('trade_quantity', 100) # 每次交易数量 # 保存生成的信号 self.signals = None self.logger.info(f'双均线策略初始化完成: 短期周期={self.short_period}, 长期周期={self.long_period}') def initialize(self, data): """ 初始化策略,处理数据,计算指标等 Args: data: 股票数据DataFrame """ # 初始化信号数据 self.signals = None # 计算移动平均线 self._calculate_moving_averages(data) def generate_signals(self, data): """ 生成交易信号 Args: data: 股票数据DataFrame Returns: pandas.DataFrame: 添加了信号列的DataFrame """ if data.empty or len(data) < self.long_period: return pd.DataFrame() # 确保数据按日期从旧到新排序 data = data.sort_index() # 计算移动平均线 self._calculate_moving_averages(data) signals = pd.DataFrame(index=data.index) signals['signal'] = 0 # 0:无信号, 1:买入信号, -1:卖出信号 signals['reason'] = '' # 信号原因 signals['short_ma'] = data[f'sma{self.short_period}'] signals['long_ma'] = data[f'sma{self.long_period}'] signals['strength'] = 0.0 # 信号强度 # 生成交易信号 for i in range(1, len(data)): # 获取当前和前一天的均线值 short_ma_current = data[f'sma{self.short_period}'].iloc[i] long_ma_current = data[f'sma{self.long_period}'].iloc[i] short_ma_prev = data[f'sma{self.short_period}'].iloc[i-1] long_ma_prev = data[f'sma{self.long_period}'].iloc[i-1] signal = 0 strength = 0.0 reason = '' # 短期均线上穿长期均线 - 买入信号 if short_ma_prev <= long_ma_prev and short_ma_current > long_ma_current: signal = 1 strength = abs((short_ma_current - long_ma_current) / long_ma_current * 100) reason = f'短期均线上穿长期均线 ({short_ma_current:.2f} > {long_ma_current:.2f})' # 短期均线下穿长期均线 - 卖出信号 elif short_ma_prev >= long_ma_prev and short_ma_current < long_ma_current: signal = -1 strength = abs((long_ma_current - short_ma_current) / long_ma_current * 100) reason = f'短期均线下穿长期均线 ({short_ma_current:.2f} < {long_ma_current:.2f})' # 记录信号 signals.loc[data.index[i], 'signal'] = signal signals.loc[data.index[i], 'strength'] = strength signals.loc[data.index[i], 'reason'] = reason # 保存信号 self.signals = signals return signals def on_signal(self, date, signal, price): """ 处理交易信号 Args: date: 日期 signal: 信号类型 (1=买入, 0=持有, -1=卖出) price: 价格 Returns: dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int} """ if signal == 1: # 买入信号 self.logger.info(f'{date} 买入信号: 价格={price:.2f}, 数量={self.trade_quantity}') return {'action': 'buy', 'quantity': self.trade_quantity} elif signal == -1: # 卖出信号 self.logger.info(f'{date} 卖出信号: 价格={price:.2f}, 数量={self.trade_quantity}') return {'action': 'sell', 'quantity': self.trade_quantity} else: # 持有 return {'action': 'hold', 'quantity': 0} def _calculate_moving_averages(self, data): """ 计算移动平均线 Args: data: 股票数据DataFrame """ # 计算短期移动平均线 data[f'sma{self.short_period}'] = data['close'].rolling(window=self.short_period).mean() # 计算长期移动平均线 data[f'sma{self.long_period}'] = data['close'].rolling(window=self.long_period).mean() self.logger.debug(f'已计算移动平均线: SMA{self.short_period} 和 SMA{self.long_period}')