141 lines
5.1 KiB
Python
141 lines
5.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
双均线策略实现
|
|
当短期均线上穿长期均线时买入,当短期均线下穿长期均线时卖出
|
|
"""
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
import logging
|
|
from typing import Dict, List, Tuple, Optional
|
|
from .base_strategy import BaseStrategy
|
|
|
|
class DualMovingAverageStrategy(BaseStrategy):
|
|
"""双均线策略实现"""
|
|
|
|
def __init__(self, config: Dict):
|
|
super().__init__()
|
|
# 使用主日志记录器,确保日志能被正确记录到文件
|
|
self.logger = logging.getLogger('backtrader')
|
|
|
|
# 策略参数
|
|
self.short_period = config.get('short_period', 5) # 短期均线周期
|
|
self.long_period = config.get('long_period', 20) # 长期均线周期
|
|
self.trade_quantity = config.get('trade_quantity', 100) # 每次交易数量
|
|
|
|
# 保存生成的信号
|
|
self.signals = None
|
|
|
|
self.logger.info(f'双均线策略初始化完成: 短期周期={self.short_period}, 长期周期={self.long_period}')
|
|
|
|
def initialize(self, data):
|
|
"""
|
|
初始化策略,处理数据,计算指标等
|
|
|
|
Args:
|
|
data: 股票数据DataFrame
|
|
"""
|
|
# 初始化信号数据
|
|
self.signals = None
|
|
|
|
# 计算移动平均线
|
|
self._calculate_moving_averages(data)
|
|
|
|
def generate_signals(self, data):
|
|
"""
|
|
生成交易信号
|
|
|
|
Args:
|
|
data: 股票数据DataFrame
|
|
|
|
Returns:
|
|
pandas.DataFrame: 添加了信号列的DataFrame
|
|
"""
|
|
if data.empty or len(data) < self.long_period:
|
|
return pd.DataFrame()
|
|
|
|
# 确保数据按日期从旧到新排序
|
|
data = data.sort_index()
|
|
|
|
# 计算移动平均线
|
|
self._calculate_moving_averages(data)
|
|
|
|
signals = pd.DataFrame(index=data.index)
|
|
signals['signal'] = 0 # 0:无信号, 1:买入信号, -1:卖出信号
|
|
signals['reason'] = '' # 信号原因
|
|
signals['short_ma'] = data[f'sma{self.short_period}']
|
|
signals['long_ma'] = data[f'sma{self.long_period}']
|
|
signals['strength'] = 0.0 # 信号强度
|
|
|
|
# 生成交易信号
|
|
for i in range(1, len(data)):
|
|
# 获取当前和前一天的均线值
|
|
short_ma_current = data[f'sma{self.short_period}'].iloc[i]
|
|
long_ma_current = data[f'sma{self.long_period}'].iloc[i]
|
|
short_ma_prev = data[f'sma{self.short_period}'].iloc[i-1]
|
|
long_ma_prev = data[f'sma{self.long_period}'].iloc[i-1]
|
|
|
|
signal = 0
|
|
strength = 0.0
|
|
reason = ''
|
|
|
|
# 短期均线上穿长期均线 - 买入信号
|
|
if short_ma_prev <= long_ma_prev and short_ma_current > long_ma_current:
|
|
signal = 1
|
|
strength = abs((short_ma_current - long_ma_current) / long_ma_current * 100)
|
|
reason = f'短期均线上穿长期均线 ({short_ma_current:.2f} > {long_ma_current:.2f})'
|
|
|
|
# 短期均线下穿长期均线 - 卖出信号
|
|
elif short_ma_prev >= long_ma_prev and short_ma_current < long_ma_current:
|
|
signal = -1
|
|
strength = abs((long_ma_current - short_ma_current) / long_ma_current * 100)
|
|
reason = f'短期均线下穿长期均线 ({short_ma_current:.2f} < {long_ma_current:.2f})'
|
|
|
|
# 记录信号
|
|
signals.loc[data.index[i], 'signal'] = signal
|
|
signals.loc[data.index[i], 'strength'] = strength
|
|
signals.loc[data.index[i], 'reason'] = reason
|
|
|
|
# 保存信号
|
|
self.signals = signals
|
|
return signals
|
|
|
|
def on_signal(self, date, signal, price):
|
|
"""
|
|
处理交易信号
|
|
|
|
Args:
|
|
date: 日期
|
|
signal: 信号类型 (1=买入, 0=持有, -1=卖出)
|
|
price: 价格
|
|
|
|
Returns:
|
|
dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int}
|
|
"""
|
|
if signal == 1:
|
|
# 买入信号
|
|
self.logger.info(f'{date} 买入信号: 价格={price:.2f}, 数量={self.trade_quantity}')
|
|
return {'action': 'buy', 'quantity': self.trade_quantity}
|
|
elif signal == -1:
|
|
# 卖出信号
|
|
self.logger.info(f'{date} 卖出信号: 价格={price:.2f}, 数量={self.trade_quantity}')
|
|
return {'action': 'sell', 'quantity': self.trade_quantity}
|
|
else:
|
|
# 持有
|
|
return {'action': 'hold', 'quantity': 0}
|
|
|
|
def _calculate_moving_averages(self, data):
|
|
"""
|
|
计算移动平均线
|
|
|
|
Args:
|
|
data: 股票数据DataFrame
|
|
"""
|
|
# 计算短期移动平均线
|
|
data[f'sma{self.short_period}'] = data['close'].rolling(window=self.short_period).mean()
|
|
|
|
# 计算长期移动平均线
|
|
data[f'sma{self.long_period}'] = data['close'].rolling(window=self.long_period).mean()
|
|
|
|
self.logger.debug(f'已计算移动平均线: SMA{self.short_period} 和 SMA{self.long_period}')
|