import pandas as pd import numpy as np import matplotlib.pyplot as plt import matplotlib.dates as mdates from matplotlib.ticker import FuncFormatter import seaborn as sns import os class ChartGenerator: """图表生成器,用于创建各种回测分析图表""" def __init__(self, output_dir='charts', strategy_name=None): """初始化图表生成器""" # 根据策略名称创建子目录结构 self.output_dir = output_dir if strategy_name: self.output_dir = os.path.join(output_dir, strategy_name) os.makedirs(self.output_dir, exist_ok=True) # 设置图表风格 plt.style.use('ggplot') plt.rcParams['font.sans-serif'] = ['SimHei'] # 支持中文 plt.rcParams['axes.unicode_minus'] = False # 支持负号 def _save_chart(self, fig, filename, dpi=300): """保存图表""" filepath = os.path.join(self.output_dir, filename) fig.savefig(filepath, dpi=dpi, bbox_inches='tight') plt.close(fig) return filepath def plot_equity_curve(self, results, strategy_name, benchmark_data=None): """1. 策略收益曲线图""" fig, ax = plt.subplots(figsize=(12, 6)) equity_df = results['equity_curve'] # 检查权益曲线是否为空且包含'equity'列 if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns: ax.plot(equity_df.index, equity_df['equity'], label=strategy_name) # 如果有基准数据,将基准收益曲线叠加到策略收益曲线上 if benchmark_data is not None: # 计算基准收益率 benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1 # 将基准收益率转换为权益曲线(基于初始资金) initial_capital = equity_df['equity'].iloc[0] benchmark_equity = initial_capital * (1 + benchmark_returns) # 确保基准数据与策略数据日期对齐 benchmark_equity = benchmark_equity.reindex(equity_df.index).ffill().bfill() ax.plot(equity_df.index, benchmark_equity, label='上证指数 (000001.SH)', linestyle='--') else: # 如果没有有效的权益曲线数据,显示一条水平直线(初始资金) ax.plot([], [], label=strategy_name, color='blue') if benchmark_data is not None: # 计算基准收益率 benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1 # 将基准收益率转换为权益曲线(基于初始资金) initial_capital = 1000000 # 默认初始资金 benchmark_equity = initial_capital * (1 + benchmark_returns) ax.plot(benchmark_equity.index, benchmark_equity, label='上证指数 (000001.SH)', linestyle='--') ax.set_title(f'{strategy_name} 收益曲线与基准对比') ax.set_xlabel('日期') ax.set_ylabel('权益 (元)') ax.grid(True) ax.legend() # 格式化日期 ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) fig.autofmt_xdate() return self._save_chart(fig, f'equity_curve_{strategy_name}.png') def plot_benchmark_comparison(self, results, benchmark_data, strategy_name): """2. 基准收益曲线图""" fig, ax = plt.subplots(figsize=(12, 6)) # 计算策略收益率 equity_df = results['equity_curve'] # 检查权益曲线是否有效 if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns: strategy_returns = equity_df['equity'] / equity_df['equity'].iloc[0] - 1 ax.plot(strategy_returns.index, strategy_returns * 100, label=strategy_name) else: # 如果没有有效的权益曲线数据,不绘制策略曲线 ax.plot([], [], label=strategy_name) # 计算基准收益率 benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1 ax.plot(benchmark_returns.index, benchmark_returns * 100, label='基准') ax.set_title(f'{strategy_name} 与基准收益对比') ax.set_xlabel('日期') ax.set_ylabel('累计收益率 (%)') ax.grid(True) ax.legend() # 格式化日期 ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) fig.autofmt_xdate() return self._save_chart(fig, f'benchmark_comparison_{strategy_name}.png') def plot_drawdown_curve(self, results, strategy_name): """3. 回撤曲线图""" fig, ax = plt.subplots(figsize=(12, 6)) equity_df = results['equity_curve'] # 检查权益曲线是否有效 if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns: equity_df['cummax'] = equity_df['equity'].cummax() equity_df['drawdown'] = (equity_df['equity'] - equity_df['cummax']) / equity_df['cummax'] * 100 ax.fill_between(equity_df.index, equity_df['drawdown'], 0, alpha=0.5, color='red') ax.plot(equity_df.index, equity_df['drawdown'], color='red', linewidth=1) ax.set_ylim(bottom=equity_df['drawdown'].min() * 1.1) else: # 如果没有有效的权益曲线数据,显示一个空图表 ax.set_ylim(bottom=-20) ax.set_title(f'{strategy_name} 回撤曲线') ax.set_xlabel('日期') ax.set_ylabel('回撤 (%)') ax.grid(True) # 格式化日期 ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) fig.autofmt_xdate() return self._save_chart(fig, f'drawdown_curve_{strategy_name}.png') def plot_return_distribution(self, results, strategy_name): """4. 收益分布直方图""" fig, ax = plt.subplots(figsize=(10, 6)) equity_df = results['equity_curve'] # 检查权益曲线是否有效 if equity_df is not None and not equity_df.empty and 'returns' in equity_df.columns: daily_returns = equity_df['returns'].dropna() * 100 if not daily_returns.empty: sns.histplot(daily_returns, bins=50, kde=True, ax=ax) else: ax.text(0.5, 0.5, '无收益率数据', ha='center', va='center', transform=ax.transAxes) else: # 如果没有有效的权益曲线数据,显示提示 ax.text(0.5, 0.5, '无收益率数据', ha='center', va='center', transform=ax.transAxes) ax.set_title(f'{strategy_name} 日收益率分布') ax.set_xlabel('日收益率 (%)') ax.set_ylabel('频率') ax.grid(True) return self._save_chart(fig, f'return_distribution_{strategy_name}.png') def plot_monthly_returns(self, results, strategy_name): """5. 月度收益柱状图""" fig, ax = plt.subplots(figsize=(12, 6)) equity_df = results['equity_curve'] # 检查权益曲线是否有效 if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns: monthly_returns = equity_df['equity'].resample('M').last().pct_change() * 100 monthly_returns = monthly_returns.dropna() if not monthly_returns.empty: colors = ['green' if ret > 0 else 'red' for ret in monthly_returns] ax.bar(monthly_returns.index, monthly_returns, color=colors) else: ax.text(0.5, 0.5, '无月度收益数据', ha='center', va='center', transform=ax.transAxes) else: # 如果没有有效的权益曲线数据,显示提示 ax.text(0.5, 0.5, '无月度收益数据', ha='center', va='center', transform=ax.transAxes) ax.set_title(f'{strategy_name} 月度收益率') ax.set_xlabel('月份') ax.set_ylabel('月度收益率 (%)') ax.grid(True, axis='y') # 格式化日期 ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) fig.autofmt_xdate() return self._save_chart(fig, f'monthly_returns_{strategy_name}.png') def plot_holding_value(self, results, strategy_name): """6. 持仓市值曲线图""" fig, ax = plt.subplots(figsize=(12, 6)) equity_df = results['equity_curve'] # 检查权益曲线是否有效 if equity_df is not None and not equity_df.empty: # 绘制持仓市值 if 'positions_value' in equity_df.columns: ax.plot(equity_df.index, equity_df['positions_value'], label='持仓市值') else: ax.plot([], [], label='持仓市值') # 绘制现金 if 'cash' in equity_df.columns: ax.plot(equity_df.index, equity_df['cash'], label='现金') else: ax.plot([], [], label='现金') else: # 如果没有有效的权益曲线数据,绘制空图表 ax.plot([], [], label='持仓市值') ax.plot([], [], label='现金') ax.set_title(f'{strategy_name} 持仓市值变化') ax.set_xlabel('日期') ax.set_ylabel('金额 (元)') ax.grid(True) ax.legend() # 格式化日期 ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) fig.autofmt_xdate() return self._save_chart(fig, f'holding_value_{strategy_name}.png') def plot_turnover(self, results, strategy_name): """7. 换手率柱状图""" # 换手率需要从交易数据中计算 fig, ax = plt.subplots(figsize=(12, 6)) trades_df = results['trades'] equity_df = results['equity_curve'] if trades_df.empty or equity_df is None or equity_df.empty or 'positions_value' not in equity_df.columns: ax.text(0.5, 0.5, '无换手率数据', ha='center', va='center', transform=ax.transAxes) ax.set_title(f'{strategy_name} 月度换手率') ax.set_xlabel('月份') ax.set_ylabel('换手率 (%)') ax.grid(True, axis='y') return self._save_chart(fig, f'turnover_{strategy_name}.png') # 按月份分组,计算每个月的交易金额 trades_df['month'] = trades_df['date'].dt.to_period('M') monthly_trades = trades_df.groupby('month')['price'].sum() * trades_df.groupby('month')['shares'].sum() # 计算平均持仓市值作为分母 monthly_holding = equity_df['positions_value'].resample('M').mean() # 计算换手率 turnover = monthly_trades / monthly_holding * 100 turnover = turnover.dropna() if not turnover.empty: ax.bar(turnover.index.astype(str), turnover) else: ax.text(0.5, 0.5, '无换手率数据', ha='center', va='center', transform=ax.transAxes) ax.set_title(f'{strategy_name} 月度换手率') ax.set_xlabel('月份') ax.set_ylabel('换手率 (%)') ax.grid(True, axis='y') fig.autofmt_xdate() return self._save_chart(fig, f'turnover_{strategy_name}.png') def plot_sector_contribution(self, results, strategy_name): """8. 行业/板块收益贡献图""" # 行业贡献需要股票的行业信息,这里假设trades_df包含sector列 if 'sector' not in results['trades'].columns: return None fig, ax = plt.subplots(figsize=(10, 6)) trades_df = results['trades'] sector_returns = {} # 这里需要根据实际数据计算行业贡献 # 示例实现,需要根据实际数据调整 if not trades_df.empty: # 按行业分组,计算每个行业的总盈亏 sector_pnl = trades_df.groupby('sector')['pnl'].sum() # 绘制饼图 fig, ax = plt.subplots(figsize=(10, 8)) ax.pie(sector_pnl, labels=sector_pnl.index, autopct='%1.1f%%', startangle=90) ax.axis('equal') # 保证饼图是圆形 ax.set_title(f'{strategy_name} 行业收益贡献') return self._save_chart(fig, f'sector_contribution_{strategy_name}.png') return None def plot_factor_contribution(self, factor_data, strategy_name): """9. 因子收益贡献图""" # 因子贡献需要因子数据 if not factor_data: return None fig, ax = plt.subplots(figsize=(10, 6)) # 示例实现,需要根据实际因子数据调整 factors = list(factor_data.keys()) contributions = list(factor_data.values()) colors = ['green' if contrib > 0 else 'red' for contrib in contributions] ax.bar(factors, contributions, color=colors) ax.set_title(f'{strategy_name} 因子收益贡献') ax.set_xlabel('因子') ax.set_ylabel('贡献度 (%)') ax.grid(True, axis='y') return self._save_chart(fig, f'factor_contribution_{strategy_name}.png') def plot_signal_distribution(self, results, strategy_name): """10. 交易信号分布图""" fig, ax = plt.subplots(figsize=(12, 6)) trades_df = results['trades'] if trades_df.empty: return None # 计算每个月的买入和卖出信号数量 trades_df['month'] = trades_df['date'].dt.to_period('M') signal_counts = trades_df.groupby(['month', 'action']).size().unstack(fill_value=0) signal_counts.plot(kind='bar', ax=ax) ax.set_title(f'{strategy_name} 交易信号分布') ax.set_xlabel('月份') ax.set_ylabel('信号数量') ax.grid(True, axis='y') ax.legend() fig.autofmt_xdate() return self._save_chart(fig, f'signal_distribution_{strategy_name}.png') def plot_all_trades(self, results, strategy_name): """11. 所有交易记录表格""" trades_df = results['trades'] fig, ax = plt.subplots(figsize=(14, max(8, len(trades_df) * 0.3))) ax.axis('tight') ax.axis('off') if not trades_df.empty: # 选择需要显示的列 display_columns = ['date', 'ts_code', 'action', 'price', 'shares', 'commission', 'pnl', 'pnl_pct', 'reason'] # 确保所有需要的列都存在 available_columns = [col for col in display_columns if col in trades_df.columns] trades_data = trades_df[available_columns].copy() # 格式化数据 if 'price' in trades_data.columns: trades_data['price'] = trades_data['price'].round(2) if 'commission' in trades_data.columns: trades_data['commission'] = trades_data['commission'].round(2) if 'pnl' in trades_data.columns: trades_data['pnl'] = trades_data['pnl'].round(2) if 'pnl_pct' in trades_data.columns: trades_data['pnl_pct'] = (trades_data['pnl_pct'] * 100).round(2) # 创建表格 table = ax.table(cellText=trades_data.values, colLabels=trades_data.columns, cellLoc='center', loc='center') # 设置表格样式 table.auto_set_font_size(False) table.set_fontsize(9) table.scale(1, 1.5) # 设置表头背景色 for i, key in enumerate(trades_data.columns): table[(0, i)].set_facecolor('#40466e') table[(0, i)].set_text_props(color='white') table[(0, i)].set_edgecolor('black') else: ax.text(0.5, 0.5, '无交易记录', ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.set_title(f'{strategy_name} 所有交易记录', fontsize=14, pad=20) return self._save_chart(fig, f'all_trades_{strategy_name}.png') def plot_daily_trades(self, results, strategy_name): """12. 每日买卖记录表格""" trades_df = results['trades'] fig, ax = plt.subplots(figsize=(14, max(8, len(trades_df) * 0.3))) ax.axis('tight') ax.axis('off') if not trades_df.empty: # 按日期分组 daily_trades = trades_df.groupby('date') # 准备数据 daily_data = [] for date, group in daily_trades: # 分别统计买入和卖出 buys = group[group['action'] == 'BUY'] sells = group[group['action'] == 'SELL'] # 买入信息 buy_count = len(buys) buy_volume = buys['shares'].sum() if 'shares' in buys.columns else 0 buy_amount = (buys['price'] * buys['shares']).sum() if 'price' in buys.columns and 'shares' in buys.columns else 0 # 卖出信息 sell_count = len(sells) sell_volume = sells['shares'].sum() if 'shares' in sells.columns else 0 sell_amount = (sells['price'] * sells['shares']).sum() if 'price' in sells.columns and 'shares' in sells.columns else 0 sell_pnl = sells['pnl'].sum() if 'pnl' in sells.columns else 0 daily_data.append({ '日期': date, '买入次数': buy_count, '买入股数': buy_volume, '买入金额': buy_amount, '卖出次数': sell_count, '卖出股数': sell_volume, '卖出金额': sell_amount, '当日盈亏': sell_pnl }) # 创建DataFrame daily_df = pd.DataFrame(daily_data) # 格式化数据 daily_df['买入金额'] = daily_df['买入金额'].round(2) daily_df['卖出金额'] = daily_df['卖出金额'].round(2) daily_df['当日盈亏'] = daily_df['当日盈亏'].round(2) # 创建表格 table = ax.table(cellText=daily_df.values, colLabels=daily_df.columns, cellLoc='center', loc='center') # 设置表格样式 table.auto_set_font_size(False) table.set_fontsize(9) table.scale(1, 1.5) # 设置表头背景色 for i, key in enumerate(daily_df.columns): table[(0, i)].set_facecolor('#40466e') table[(0, i)].set_text_props(color='white') table[(0, i)].set_edgecolor('black') else: ax.text(0.5, 0.5, '无交易记录', ha='center', va='center', transform=ax.transAxes, fontsize=12) ax.set_title(f'{strategy_name} 每日买卖记录', fontsize=14, pad=20) return self._save_chart(fig, f'daily_trades_{strategy_name}.png') def generate_all_charts(self, results, strategy_name, benchmark_data=None, factor_data=None): """生成所有图表""" charts = [] # 生成策略收益曲线图(包含基准对比) charts.append(('收益曲线', self.plot_equity_curve(results, strategy_name, benchmark_data))) # 生成基准对比图(如果有基准数据)- 保持独立的基准对比图 if benchmark_data is not None: charts.append(('基准对比', self.plot_benchmark_comparison(results, benchmark_data, strategy_name))) # 生成回撤曲线图 charts.append(('回撤曲线', self.plot_drawdown_curve(results, strategy_name))) # 生成收益分布直方图 charts.append(('收益分布', self.plot_return_distribution(results, strategy_name))) # 生成月度收益柱状图 charts.append(('月度收益', self.plot_monthly_returns(results, strategy_name))) # 生成持仓市值曲线图 charts.append(('持仓市值', self.plot_holding_value(results, strategy_name))) # 生成换手率柱状图 turnover_chart = self.plot_turnover(results, strategy_name) if turnover_chart: charts.append(('换手率', turnover_chart)) # 生成行业贡献图 sector_chart = self.plot_sector_contribution(results, strategy_name) if sector_chart: charts.append(('行业贡献', sector_chart)) # 生成因子贡献图(如果有因子数据) if factor_data is not None: factor_chart = self.plot_factor_contribution(factor_data, strategy_name) if factor_chart: charts.append(('因子贡献', factor_chart)) # 生成交易信号分布图 signal_chart = self.plot_signal_distribution(results, strategy_name) if signal_chart: charts.append(('信号分布', signal_chart)) # 生成所有交易记录图表 all_trades_chart = self.plot_all_trades(results, strategy_name) charts.append(('所有交易记录', all_trades_chart)) # 生成每日买卖记录图表 daily_trades_chart = self.plot_daily_trades(results, strategy_name) charts.append(('每日买卖记录', daily_trades_chart)) return charts