Files
backtrader/strategy/base_strategy.py
2026-01-17 21:21:30 +08:00

223 lines
6.0 KiB
Python

# -*- coding: utf-8 -*-
"""
策略基类
定义所有交易策略需要遵循的接口和通用功能
"""
from abc import ABC, abstractmethod
import pandas as pd
import logging
class BaseStrategy(ABC):
"""
交易策略基类
所有具体策略需要继承此类并实现抽象方法
"""
def __init__(self):
"""
初始化策略
"""
self.name = self.__class__.__name__
self.logger = logging.getLogger(f'strategy.{self.name}')
self.signals = None # 存储生成的交易信号
self.positions = [] # 记录持仓历史
self.current_position = 0 # 当前持仓数量
self.params = {}
self.logger.info(f'策略 {self.name} 初始化完成')
@abstractmethod
def initialize(self, data):
"""
初始化策略,处理数据,计算指标等
Args:
data: 股票数据DataFrame
"""
pass
@abstractmethod
def generate_signals(self, data):
"""
生成交易信号
Args:
data: 股票数据DataFrame
Returns:
pandas.DataFrame: 添加了信号列的DataFrame
"""
pass
@abstractmethod
def on_signal(self, date, signal, price):
"""
处理交易信号
Args:
date: 日期
signal: 信号类型 (1=买入, 0=持有, -1=卖出)
price: 价格
Returns:
dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int}
"""
pass
def set_params(self, **kwargs):
"""
设置策略参数
Args:
**kwargs: 策略参数
"""
for key, value in kwargs.items():
self.params[key] = value
self.logger.info(f'策略参数已更新: {self.params}')
def calculate_indicators(self, data):
"""
计算技术指标
Args:
data: 原始数据
Returns:
pandas.DataFrame: 添加了指标列的数据
"""
# 默认实现,子类可以覆盖此方法
return data.copy()
def _calculate_moving_average(self, data, period):
"""
计算移动平均线
Args:
data: 数据
period: 周期
Returns:
pandas.Series: 移动平均线
"""
return data['close'].rolling(window=period).mean()
def _calculate_volume_ma(self, data, period):
"""
计算成交量移动平均线
Args:
data: 数据
period: 周期
Returns:
pandas.Series: 成交量移动平均线
"""
return data['volume'].rolling(window=period).mean()
def _calculate_rsi(self, data, period=14):
"""
计算RSI指标
Args:
data: 数据
period: 周期
Returns:
pandas.Series: RSI值
"""
delta = data['close'].diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
rsi = 100 - (100 / (1 + rs))
return rsi
def _calculate_macd(self, data, fast_period=12, slow_period=26, signal_period=9):
"""
计算MACD指标
Args:
data: 数据
fast_period: 快线周期
slow_period: 慢线周期
signal_period: 信号线周期
Returns:
tuple: (MACD线, 信号线, 柱状图)
"""
exp1 = data['close'].ewm(span=fast_period, adjust=False).mean()
exp2 = data['close'].ewm(span=slow_period, adjust=False).mean()
macd = exp1 - exp2
signal = macd.ewm(span=signal_period, adjust=False).mean()
histogram = macd - signal
return macd, signal, histogram
def record_position(self, date, action, quantity, price):
"""
记录交易和持仓信息
Args:
date: 交易日期
action: 交易动作 ('buy'|'sell'|'hold')
quantity: 交易数量
price: 交易价格
"""
if action == 'buy':
self.current_position += quantity
elif action == 'sell':
self.current_position -= quantity
# 确保持仓不会变成负数
if self.current_position < 0:
self.current_position = 0
trade_record = {
'date': date,
'action': action,
'quantity': quantity,
'price': price,
'position': self.current_position
}
self.positions.append(trade_record)
self.logger.debug(f'记录交易: {trade_record}')
def get_trading_summary(self):
"""
获取交易汇总信息
Returns:
dict: 交易汇总
"""
if not self.positions:
return {
'total_trades': 0,
'buy_trades': 0,
'sell_trades': 0,
'current_position': 0
}
buy_trades = sum(1 for p in self.positions if p['action'] == 'buy')
sell_trades = sum(1 for p in self.positions if p['action'] == 'sell')
return {
'total_trades': buy_trades + sell_trades,
'buy_trades': buy_trades,
'sell_trades': sell_trades,
'current_position': self.current_position,
'first_trade_date': self.positions[0]['date'],
'last_trade_date': self.positions[-1]['date']
}
def reset(self):
"""
重置策略状态
"""
self.signals = None
self.positions = []
self.current_position = 0
self.logger.info(f'策略 {self.name} 已重置')