Files
backtrader/analysis/chart_generator.py
2026-01-17 21:21:30 +08:00

515 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.ticker import FuncFormatter
import seaborn as sns
import os
class ChartGenerator:
"""图表生成器,用于创建各种回测分析图表"""
def __init__(self, output_dir='charts', strategy_name=None):
"""初始化图表生成器"""
# 根据策略名称创建子目录结构
self.output_dir = output_dir
if strategy_name:
self.output_dir = os.path.join(output_dir, strategy_name)
os.makedirs(self.output_dir, exist_ok=True)
# 设置图表风格
plt.style.use('ggplot')
plt.rcParams['font.sans-serif'] = ['SimHei'] # 支持中文
plt.rcParams['axes.unicode_minus'] = False # 支持负号
def _save_chart(self, fig, filename, dpi=300):
"""保存图表"""
filepath = os.path.join(self.output_dir, filename)
fig.savefig(filepath, dpi=dpi, bbox_inches='tight')
plt.close(fig)
return filepath
def plot_equity_curve(self, results, strategy_name, benchmark_data=None):
"""1. 策略收益曲线图"""
fig, ax = plt.subplots(figsize=(12, 6))
equity_df = results['equity_curve']
# 检查权益曲线是否为空且包含'equity'列
if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns:
ax.plot(equity_df.index, equity_df['equity'], label=strategy_name)
# 如果有基准数据,将基准收益曲线叠加到策略收益曲线上
if benchmark_data is not None:
# 计算基准收益率
benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1
# 将基准收益率转换为权益曲线(基于初始资金)
initial_capital = equity_df['equity'].iloc[0]
benchmark_equity = initial_capital * (1 + benchmark_returns)
# 确保基准数据与策略数据日期对齐
benchmark_equity = benchmark_equity.reindex(equity_df.index).ffill().bfill()
ax.plot(equity_df.index, benchmark_equity, label='上证指数 (000001.SH)', linestyle='--')
else:
# 如果没有有效的权益曲线数据,显示一条水平直线(初始资金)
ax.plot([], [], label=strategy_name, color='blue')
if benchmark_data is not None:
# 计算基准收益率
benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1
# 将基准收益率转换为权益曲线(基于初始资金)
initial_capital = 1000000 # 默认初始资金
benchmark_equity = initial_capital * (1 + benchmark_returns)
ax.plot(benchmark_equity.index, benchmark_equity, label='上证指数 (000001.SH)', linestyle='--')
ax.set_title(f'{strategy_name} 收益曲线与基准对比')
ax.set_xlabel('日期')
ax.set_ylabel('权益 (元)')
ax.grid(True)
ax.legend()
# 格式化日期
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
fig.autofmt_xdate()
return self._save_chart(fig, f'equity_curve_{strategy_name}.png')
def plot_benchmark_comparison(self, results, benchmark_data, strategy_name):
"""2. 基准收益曲线图"""
fig, ax = plt.subplots(figsize=(12, 6))
# 计算策略收益率
equity_df = results['equity_curve']
# 检查权益曲线是否有效
if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns:
strategy_returns = equity_df['equity'] / equity_df['equity'].iloc[0] - 1
ax.plot(strategy_returns.index, strategy_returns * 100, label=strategy_name)
else:
# 如果没有有效的权益曲线数据,不绘制策略曲线
ax.plot([], [], label=strategy_name)
# 计算基准收益率
benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1
ax.plot(benchmark_returns.index, benchmark_returns * 100, label='基准')
ax.set_title(f'{strategy_name} 与基准收益对比')
ax.set_xlabel('日期')
ax.set_ylabel('累计收益率 (%)')
ax.grid(True)
ax.legend()
# 格式化日期
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
fig.autofmt_xdate()
return self._save_chart(fig, f'benchmark_comparison_{strategy_name}.png')
def plot_drawdown_curve(self, results, strategy_name):
"""3. 回撤曲线图"""
fig, ax = plt.subplots(figsize=(12, 6))
equity_df = results['equity_curve']
# 检查权益曲线是否有效
if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns:
equity_df['cummax'] = equity_df['equity'].cummax()
equity_df['drawdown'] = (equity_df['equity'] - equity_df['cummax']) / equity_df['cummax'] * 100
ax.fill_between(equity_df.index, equity_df['drawdown'], 0, alpha=0.5, color='red')
ax.plot(equity_df.index, equity_df['drawdown'], color='red', linewidth=1)
ax.set_ylim(bottom=equity_df['drawdown'].min() * 1.1)
else:
# 如果没有有效的权益曲线数据,显示一个空图表
ax.set_ylim(bottom=-20)
ax.set_title(f'{strategy_name} 回撤曲线')
ax.set_xlabel('日期')
ax.set_ylabel('回撤 (%)')
ax.grid(True)
# 格式化日期
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
fig.autofmt_xdate()
return self._save_chart(fig, f'drawdown_curve_{strategy_name}.png')
def plot_return_distribution(self, results, strategy_name):
"""4. 收益分布直方图"""
fig, ax = plt.subplots(figsize=(10, 6))
equity_df = results['equity_curve']
# 检查权益曲线是否有效
if equity_df is not None and not equity_df.empty and 'returns' in equity_df.columns:
daily_returns = equity_df['returns'].dropna() * 100
if not daily_returns.empty:
sns.histplot(daily_returns, bins=50, kde=True, ax=ax)
else:
ax.text(0.5, 0.5, '无收益率数据', ha='center', va='center', transform=ax.transAxes)
else:
# 如果没有有效的权益曲线数据,显示提示
ax.text(0.5, 0.5, '无收益率数据', ha='center', va='center', transform=ax.transAxes)
ax.set_title(f'{strategy_name} 日收益率分布')
ax.set_xlabel('日收益率 (%)')
ax.set_ylabel('频率')
ax.grid(True)
return self._save_chart(fig, f'return_distribution_{strategy_name}.png')
def plot_monthly_returns(self, results, strategy_name):
"""5. 月度收益柱状图"""
fig, ax = plt.subplots(figsize=(12, 6))
equity_df = results['equity_curve']
# 检查权益曲线是否有效
if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns:
monthly_returns = equity_df['equity'].resample('M').last().pct_change() * 100
monthly_returns = monthly_returns.dropna()
if not monthly_returns.empty:
colors = ['green' if ret > 0 else 'red' for ret in monthly_returns]
ax.bar(monthly_returns.index, monthly_returns, color=colors)
else:
ax.text(0.5, 0.5, '无月度收益数据', ha='center', va='center', transform=ax.transAxes)
else:
# 如果没有有效的权益曲线数据,显示提示
ax.text(0.5, 0.5, '无月度收益数据', ha='center', va='center', transform=ax.transAxes)
ax.set_title(f'{strategy_name} 月度收益率')
ax.set_xlabel('月份')
ax.set_ylabel('月度收益率 (%)')
ax.grid(True, axis='y')
# 格式化日期
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
fig.autofmt_xdate()
return self._save_chart(fig, f'monthly_returns_{strategy_name}.png')
def plot_holding_value(self, results, strategy_name):
"""6. 持仓市值曲线图"""
fig, ax = plt.subplots(figsize=(12, 6))
equity_df = results['equity_curve']
# 检查权益曲线是否有效
if equity_df is not None and not equity_df.empty:
# 绘制持仓市值
if 'positions_value' in equity_df.columns:
ax.plot(equity_df.index, equity_df['positions_value'], label='持仓市值')
else:
ax.plot([], [], label='持仓市值')
# 绘制现金
if 'cash' in equity_df.columns:
ax.plot(equity_df.index, equity_df['cash'], label='现金')
else:
ax.plot([], [], label='现金')
else:
# 如果没有有效的权益曲线数据,绘制空图表
ax.plot([], [], label='持仓市值')
ax.plot([], [], label='现金')
ax.set_title(f'{strategy_name} 持仓市值变化')
ax.set_xlabel('日期')
ax.set_ylabel('金额 (元)')
ax.grid(True)
ax.legend()
# 格式化日期
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
fig.autofmt_xdate()
return self._save_chart(fig, f'holding_value_{strategy_name}.png')
def plot_turnover(self, results, strategy_name):
"""7. 换手率柱状图"""
# 换手率需要从交易数据中计算
fig, ax = plt.subplots(figsize=(12, 6))
trades_df = results['trades']
equity_df = results['equity_curve']
if trades_df.empty or equity_df is None or equity_df.empty or 'positions_value' not in equity_df.columns:
ax.text(0.5, 0.5, '无换手率数据', ha='center', va='center', transform=ax.transAxes)
ax.set_title(f'{strategy_name} 月度换手率')
ax.set_xlabel('月份')
ax.set_ylabel('换手率 (%)')
ax.grid(True, axis='y')
return self._save_chart(fig, f'turnover_{strategy_name}.png')
# 按月份分组,计算每个月的交易金额
trades_df['month'] = trades_df['date'].dt.to_period('M')
monthly_trades = trades_df.groupby('month')['price'].sum() * trades_df.groupby('month')['shares'].sum()
# 计算平均持仓市值作为分母
monthly_holding = equity_df['positions_value'].resample('M').mean()
# 计算换手率
turnover = monthly_trades / monthly_holding * 100
turnover = turnover.dropna()
if not turnover.empty:
ax.bar(turnover.index.astype(str), turnover)
else:
ax.text(0.5, 0.5, '无换手率数据', ha='center', va='center', transform=ax.transAxes)
ax.set_title(f'{strategy_name} 月度换手率')
ax.set_xlabel('月份')
ax.set_ylabel('换手率 (%)')
ax.grid(True, axis='y')
fig.autofmt_xdate()
return self._save_chart(fig, f'turnover_{strategy_name}.png')
def plot_sector_contribution(self, results, strategy_name):
"""8. 行业/板块收益贡献图"""
# 行业贡献需要股票的行业信息这里假设trades_df包含sector列
if 'sector' not in results['trades'].columns:
return None
fig, ax = plt.subplots(figsize=(10, 6))
trades_df = results['trades']
sector_returns = {} # 这里需要根据实际数据计算行业贡献
# 示例实现,需要根据实际数据调整
if not trades_df.empty:
# 按行业分组,计算每个行业的总盈亏
sector_pnl = trades_df.groupby('sector')['pnl'].sum()
# 绘制饼图
fig, ax = plt.subplots(figsize=(10, 8))
ax.pie(sector_pnl, labels=sector_pnl.index, autopct='%1.1f%%', startangle=90)
ax.axis('equal') # 保证饼图是圆形
ax.set_title(f'{strategy_name} 行业收益贡献')
return self._save_chart(fig, f'sector_contribution_{strategy_name}.png')
return None
def plot_factor_contribution(self, factor_data, strategy_name):
"""9. 因子收益贡献图"""
# 因子贡献需要因子数据
if not factor_data:
return None
fig, ax = plt.subplots(figsize=(10, 6))
# 示例实现,需要根据实际因子数据调整
factors = list(factor_data.keys())
contributions = list(factor_data.values())
colors = ['green' if contrib > 0 else 'red' for contrib in contributions]
ax.bar(factors, contributions, color=colors)
ax.set_title(f'{strategy_name} 因子收益贡献')
ax.set_xlabel('因子')
ax.set_ylabel('贡献度 (%)')
ax.grid(True, axis='y')
return self._save_chart(fig, f'factor_contribution_{strategy_name}.png')
def plot_signal_distribution(self, results, strategy_name):
"""10. 交易信号分布图"""
fig, ax = plt.subplots(figsize=(12, 6))
trades_df = results['trades']
if trades_df.empty:
return None
# 计算每个月的买入和卖出信号数量
trades_df['month'] = trades_df['date'].dt.to_period('M')
signal_counts = trades_df.groupby(['month', 'action']).size().unstack(fill_value=0)
signal_counts.plot(kind='bar', ax=ax)
ax.set_title(f'{strategy_name} 交易信号分布')
ax.set_xlabel('月份')
ax.set_ylabel('信号数量')
ax.grid(True, axis='y')
ax.legend()
fig.autofmt_xdate()
return self._save_chart(fig, f'signal_distribution_{strategy_name}.png')
def plot_all_trades(self, results, strategy_name):
"""11. 所有交易记录表格"""
trades_df = results['trades']
fig, ax = plt.subplots(figsize=(14, max(8, len(trades_df) * 0.3)))
ax.axis('tight')
ax.axis('off')
if not trades_df.empty:
# 选择需要显示的列
display_columns = ['date', 'ts_code', 'action', 'price', 'shares', 'commission',
'pnl', 'pnl_pct', 'reason']
# 确保所有需要的列都存在
available_columns = [col for col in display_columns if col in trades_df.columns]
trades_data = trades_df[available_columns].copy()
# 格式化数据
if 'price' in trades_data.columns:
trades_data['price'] = trades_data['price'].round(2)
if 'commission' in trades_data.columns:
trades_data['commission'] = trades_data['commission'].round(2)
if 'pnl' in trades_data.columns:
trades_data['pnl'] = trades_data['pnl'].round(2)
if 'pnl_pct' in trades_data.columns:
trades_data['pnl_pct'] = (trades_data['pnl_pct'] * 100).round(2)
# 创建表格
table = ax.table(cellText=trades_data.values,
colLabels=trades_data.columns,
cellLoc='center',
loc='center')
# 设置表格样式
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1, 1.5)
# 设置表头背景色
for i, key in enumerate(trades_data.columns):
table[(0, i)].set_facecolor('#40466e')
table[(0, i)].set_text_props(color='white')
table[(0, i)].set_edgecolor('black')
else:
ax.text(0.5, 0.5, '无交易记录', ha='center', va='center', transform=ax.transAxes, fontsize=12)
ax.set_title(f'{strategy_name} 所有交易记录', fontsize=14, pad=20)
return self._save_chart(fig, f'all_trades_{strategy_name}.png')
def plot_daily_trades(self, results, strategy_name):
"""12. 每日买卖记录表格"""
trades_df = results['trades']
fig, ax = plt.subplots(figsize=(14, max(8, len(trades_df) * 0.3)))
ax.axis('tight')
ax.axis('off')
if not trades_df.empty:
# 按日期分组
daily_trades = trades_df.groupby('date')
# 准备数据
daily_data = []
for date, group in daily_trades:
# 分别统计买入和卖出
buys = group[group['action'] == 'BUY']
sells = group[group['action'] == 'SELL']
# 买入信息
buy_count = len(buys)
buy_volume = buys['shares'].sum() if 'shares' in buys.columns else 0
buy_amount = (buys['price'] * buys['shares']).sum() if 'price' in buys.columns and 'shares' in buys.columns else 0
# 卖出信息
sell_count = len(sells)
sell_volume = sells['shares'].sum() if 'shares' in sells.columns else 0
sell_amount = (sells['price'] * sells['shares']).sum() if 'price' in sells.columns and 'shares' in sells.columns else 0
sell_pnl = sells['pnl'].sum() if 'pnl' in sells.columns else 0
daily_data.append({
'日期': date,
'买入次数': buy_count,
'买入股数': buy_volume,
'买入金额': buy_amount,
'卖出次数': sell_count,
'卖出股数': sell_volume,
'卖出金额': sell_amount,
'当日盈亏': sell_pnl
})
# 创建DataFrame
daily_df = pd.DataFrame(daily_data)
# 格式化数据
daily_df['买入金额'] = daily_df['买入金额'].round(2)
daily_df['卖出金额'] = daily_df['卖出金额'].round(2)
daily_df['当日盈亏'] = daily_df['当日盈亏'].round(2)
# 创建表格
table = ax.table(cellText=daily_df.values,
colLabels=daily_df.columns,
cellLoc='center',
loc='center')
# 设置表格样式
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1, 1.5)
# 设置表头背景色
for i, key in enumerate(daily_df.columns):
table[(0, i)].set_facecolor('#40466e')
table[(0, i)].set_text_props(color='white')
table[(0, i)].set_edgecolor('black')
else:
ax.text(0.5, 0.5, '无交易记录', ha='center', va='center', transform=ax.transAxes, fontsize=12)
ax.set_title(f'{strategy_name} 每日买卖记录', fontsize=14, pad=20)
return self._save_chart(fig, f'daily_trades_{strategy_name}.png')
def generate_all_charts(self, results, strategy_name, benchmark_data=None, factor_data=None):
"""生成所有图表"""
charts = []
# 生成策略收益曲线图(包含基准对比)
charts.append(('收益曲线', self.plot_equity_curve(results, strategy_name, benchmark_data)))
# 生成基准对比图(如果有基准数据)- 保持独立的基准对比图
if benchmark_data is not None:
charts.append(('基准对比', self.plot_benchmark_comparison(results, benchmark_data, strategy_name)))
# 生成回撤曲线图
charts.append(('回撤曲线', self.plot_drawdown_curve(results, strategy_name)))
# 生成收益分布直方图
charts.append(('收益分布', self.plot_return_distribution(results, strategy_name)))
# 生成月度收益柱状图
charts.append(('月度收益', self.plot_monthly_returns(results, strategy_name)))
# 生成持仓市值曲线图
charts.append(('持仓市值', self.plot_holding_value(results, strategy_name)))
# 生成换手率柱状图
turnover_chart = self.plot_turnover(results, strategy_name)
if turnover_chart:
charts.append(('换手率', turnover_chart))
# 生成行业贡献图
sector_chart = self.plot_sector_contribution(results, strategy_name)
if sector_chart:
charts.append(('行业贡献', sector_chart))
# 生成因子贡献图(如果有因子数据)
if factor_data is not None:
factor_chart = self.plot_factor_contribution(factor_data, strategy_name)
if factor_chart:
charts.append(('因子贡献', factor_chart))
# 生成交易信号分布图
signal_chart = self.plot_signal_distribution(results, strategy_name)
if signal_chart:
charts.append(('信号分布', signal_chart))
# 生成所有交易记录图表
all_trades_chart = self.plot_all_trades(results, strategy_name)
charts.append(('所有交易记录', all_trades_chart))
# 生成每日买卖记录图表
daily_trades_chart = self.plot_daily_trades(results, strategy_name)
charts.append(('每日买卖记录', daily_trades_chart))
return charts