# -*- coding: utf-8 -*- """ 策略基类 定义所有交易策略需要遵循的接口和通用功能 """ from abc import ABC, abstractmethod import pandas as pd import logging class BaseStrategy(ABC): """ 交易策略基类 所有具体策略需要继承此类并实现抽象方法 """ def __init__(self): """ 初始化策略 """ self.name = self.__class__.__name__ self.logger = logging.getLogger(f'strategy.{self.name}') self.signals = None # 存储生成的交易信号 self.positions = [] # 记录持仓历史 self.current_position = 0 # 当前持仓数量 self.params = {} self.logger.info(f'策略 {self.name} 初始化完成') @abstractmethod def initialize(self, data): """ 初始化策略,处理数据,计算指标等 Args: data: 股票数据DataFrame """ pass @abstractmethod def generate_signals(self, data): """ 生成交易信号 Args: data: 股票数据DataFrame Returns: pandas.DataFrame: 添加了信号列的DataFrame """ pass @abstractmethod def on_signal(self, date, signal, price): """ 处理交易信号 Args: date: 日期 signal: 信号类型 (1=买入, 0=持有, -1=卖出) price: 价格 Returns: dict: 交易指令 {action: 'buy'|'sell'|'hold', quantity: int} """ pass def set_params(self, **kwargs): """ 设置策略参数 Args: **kwargs: 策略参数 """ for key, value in kwargs.items(): self.params[key] = value self.logger.info(f'策略参数已更新: {self.params}') def calculate_indicators(self, data): """ 计算技术指标 Args: data: 原始数据 Returns: pandas.DataFrame: 添加了指标列的数据 """ # 默认实现,子类可以覆盖此方法 return data.copy() def _calculate_moving_average(self, data, period): """ 计算移动平均线 Args: data: 数据 period: 周期 Returns: pandas.Series: 移动平均线 """ return data['close'].rolling(window=period).mean() def _calculate_volume_ma(self, data, period): """ 计算成交量移动平均线 Args: data: 数据 period: 周期 Returns: pandas.Series: 成交量移动平均线 """ return data['volume'].rolling(window=period).mean() def _calculate_rsi(self, data, period=14): """ 计算RSI指标 Args: data: 数据 period: 周期 Returns: pandas.Series: RSI值 """ delta = data['close'].diff() gain = (delta.where(delta > 0, 0)).rolling(window=period).mean() loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean() rs = gain / loss rsi = 100 - (100 / (1 + rs)) return rsi def _calculate_macd(self, data, fast_period=12, slow_period=26, signal_period=9): """ 计算MACD指标 Args: data: 数据 fast_period: 快线周期 slow_period: 慢线周期 signal_period: 信号线周期 Returns: tuple: (MACD线, 信号线, 柱状图) """ exp1 = data['close'].ewm(span=fast_period, adjust=False).mean() exp2 = data['close'].ewm(span=slow_period, adjust=False).mean() macd = exp1 - exp2 signal = macd.ewm(span=signal_period, adjust=False).mean() histogram = macd - signal return macd, signal, histogram def record_position(self, date, action, quantity, price): """ 记录交易和持仓信息 Args: date: 交易日期 action: 交易动作 ('buy'|'sell'|'hold') quantity: 交易数量 price: 交易价格 """ if action == 'buy': self.current_position += quantity elif action == 'sell': self.current_position -= quantity # 确保持仓不会变成负数 if self.current_position < 0: self.current_position = 0 trade_record = { 'date': date, 'action': action, 'quantity': quantity, 'price': price, 'position': self.current_position } self.positions.append(trade_record) self.logger.debug(f'记录交易: {trade_record}') def get_trading_summary(self): """ 获取交易汇总信息 Returns: dict: 交易汇总 """ if not self.positions: return { 'total_trades': 0, 'buy_trades': 0, 'sell_trades': 0, 'current_position': 0 } buy_trades = sum(1 for p in self.positions if p['action'] == 'buy') sell_trades = sum(1 for p in self.positions if p['action'] == 'sell') return { 'total_trades': buy_trades + sell_trades, 'buy_trades': buy_trades, 'sell_trades': sell_trades, 'current_position': self.current_position, 'first_trade_date': self.positions[0]['date'], 'last_trade_date': self.positions[-1]['date'] } def reset(self): """ 重置策略状态 """ self.signals = None self.positions = [] self.current_position = 0 self.logger.info(f'策略 {self.name} 已重置')