223 lines
6.0 KiB
Python
223 lines
6.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
策略基类
|
|
定义所有交易策略需要遵循的接口和通用功能
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
import pandas as pd
|
|
import logging
|
|
|
|
|
|
class BaseStrategy(ABC):
|
|
"""
|
|
交易策略基类
|
|
所有具体策略需要继承此类并实现抽象方法
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""
|
|
初始化策略
|
|
"""
|
|
self.name = self.__class__.__name__
|
|
self.logger = logging.getLogger(f'strategy.{self.name}')
|
|
self.signals = None # 存储生成的交易信号
|
|
self.positions = [] # 记录持仓历史
|
|
self.current_position = 0 # 当前持仓数量
|
|
self.params = {}
|
|
|
|
self.logger.info(f'策略 {self.name} 初始化完成')
|
|
|
|
@abstractmethod
|
|
def initialize(self, data):
|
|
"""
|
|
初始化策略,处理数据,计算指标等
|
|
|
|
Args:
|
|
data: 股票数据DataFrame
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def generate_signals(self, data):
|
|
"""
|
|
生成交易信号
|
|
|
|
Args:
|
|
data: 股票数据DataFrame
|
|
|
|
Returns:
|
|
pandas.DataFrame: 添加了信号列的DataFrame
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def on_signal(self, date, signal, price):
|
|
"""
|
|
处理交易信号
|
|
|
|
Args:
|
|
date: 日期
|
|
signal: 信号类型 (1=买入, 0=持有, -1=卖出)
|
|
price: 价格
|
|
|
|
Returns:
|
|
dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int}
|
|
"""
|
|
pass
|
|
|
|
def set_params(self, **kwargs):
|
|
"""
|
|
设置策略参数
|
|
|
|
Args:
|
|
**kwargs: 策略参数
|
|
"""
|
|
for key, value in kwargs.items():
|
|
self.params[key] = value
|
|
self.logger.info(f'策略参数已更新: {self.params}')
|
|
|
|
def calculate_indicators(self, data):
|
|
"""
|
|
计算技术指标
|
|
|
|
Args:
|
|
data: 原始数据
|
|
|
|
Returns:
|
|
pandas.DataFrame: 添加了指标列的数据
|
|
"""
|
|
# 默认实现,子类可以覆盖此方法
|
|
return data.copy()
|
|
|
|
def _calculate_moving_average(self, data, period):
|
|
"""
|
|
计算移动平均线
|
|
|
|
Args:
|
|
data: 数据
|
|
period: 周期
|
|
|
|
Returns:
|
|
pandas.Series: 移动平均线
|
|
"""
|
|
return data['close'].rolling(window=period).mean()
|
|
|
|
def _calculate_volume_ma(self, data, period):
|
|
"""
|
|
计算成交量移动平均线
|
|
|
|
Args:
|
|
data: 数据
|
|
period: 周期
|
|
|
|
Returns:
|
|
pandas.Series: 成交量移动平均线
|
|
"""
|
|
return data['volume'].rolling(window=period).mean()
|
|
|
|
def _calculate_rsi(self, data, period=14):
|
|
"""
|
|
计算RSI指标
|
|
|
|
Args:
|
|
data: 数据
|
|
period: 周期
|
|
|
|
Returns:
|
|
pandas.Series: RSI值
|
|
"""
|
|
delta = data['close'].diff()
|
|
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
|
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
|
rs = gain / loss
|
|
rsi = 100 - (100 / (1 + rs))
|
|
return rsi
|
|
|
|
def _calculate_macd(self, data, fast_period=12, slow_period=26, signal_period=9):
|
|
"""
|
|
计算MACD指标
|
|
|
|
Args:
|
|
data: 数据
|
|
fast_period: 快线周期
|
|
slow_period: 慢线周期
|
|
signal_period: 信号线周期
|
|
|
|
Returns:
|
|
tuple: (MACD线, 信号线, 柱状图)
|
|
"""
|
|
exp1 = data['close'].ewm(span=fast_period, adjust=False).mean()
|
|
exp2 = data['close'].ewm(span=slow_period, adjust=False).mean()
|
|
macd = exp1 - exp2
|
|
signal = macd.ewm(span=signal_period, adjust=False).mean()
|
|
histogram = macd - signal
|
|
return macd, signal, histogram
|
|
|
|
def record_position(self, date, action, quantity, price):
|
|
"""
|
|
记录交易和持仓信息
|
|
|
|
Args:
|
|
date: 交易日期
|
|
action: 交易动作 ('buy'|'sell'|'hold')
|
|
quantity: 交易数量
|
|
price: 交易价格
|
|
"""
|
|
if action == 'buy':
|
|
self.current_position += quantity
|
|
elif action == 'sell':
|
|
self.current_position -= quantity
|
|
|
|
# 确保持仓不会变成负数
|
|
if self.current_position < 0:
|
|
self.current_position = 0
|
|
|
|
trade_record = {
|
|
'date': date,
|
|
'action': action,
|
|
'quantity': quantity,
|
|
'price': price,
|
|
'position': self.current_position
|
|
}
|
|
|
|
self.positions.append(trade_record)
|
|
self.logger.debug(f'记录交易: {trade_record}')
|
|
|
|
def get_trading_summary(self):
|
|
"""
|
|
获取交易汇总信息
|
|
|
|
Returns:
|
|
dict: 交易汇总
|
|
"""
|
|
if not self.positions:
|
|
return {
|
|
'total_trades': 0,
|
|
'buy_trades': 0,
|
|
'sell_trades': 0,
|
|
'current_position': 0
|
|
}
|
|
|
|
buy_trades = sum(1 for p in self.positions if p['action'] == 'buy')
|
|
sell_trades = sum(1 for p in self.positions if p['action'] == 'sell')
|
|
|
|
return {
|
|
'total_trades': buy_trades + sell_trades,
|
|
'buy_trades': buy_trades,
|
|
'sell_trades': sell_trades,
|
|
'current_position': self.current_position,
|
|
'first_trade_date': self.positions[0]['date'],
|
|
'last_trade_date': self.positions[-1]['date']
|
|
}
|
|
|
|
def reset(self):
|
|
"""
|
|
重置策略状态
|
|
"""
|
|
self.signals = None
|
|
self.positions = []
|
|
self.current_position = 0
|
|
self.logger.info(f'策略 {self.name} 已重置')
|