515 lines
22 KiB
Python
515 lines
22 KiB
Python
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.dates as mdates
|
||
from matplotlib.ticker import FuncFormatter
|
||
import seaborn as sns
|
||
import os
|
||
|
||
class ChartGenerator:
|
||
"""图表生成器,用于创建各种回测分析图表"""
|
||
|
||
def __init__(self, output_dir='charts', strategy_name=None):
|
||
"""初始化图表生成器"""
|
||
# 根据策略名称创建子目录结构
|
||
self.output_dir = output_dir
|
||
if strategy_name:
|
||
self.output_dir = os.path.join(output_dir, strategy_name)
|
||
os.makedirs(self.output_dir, exist_ok=True)
|
||
|
||
# 设置图表风格
|
||
plt.style.use('ggplot')
|
||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 支持中文
|
||
plt.rcParams['axes.unicode_minus'] = False # 支持负号
|
||
|
||
def _save_chart(self, fig, filename, dpi=300):
|
||
"""保存图表"""
|
||
filepath = os.path.join(self.output_dir, filename)
|
||
fig.savefig(filepath, dpi=dpi, bbox_inches='tight')
|
||
plt.close(fig)
|
||
return filepath
|
||
|
||
def plot_equity_curve(self, results, strategy_name, benchmark_data=None):
|
||
"""1. 策略收益曲线图"""
|
||
fig, ax = plt.subplots(figsize=(12, 6))
|
||
|
||
equity_df = results['equity_curve']
|
||
|
||
# 检查权益曲线是否为空且包含'equity'列
|
||
if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns:
|
||
ax.plot(equity_df.index, equity_df['equity'], label=strategy_name)
|
||
|
||
# 如果有基准数据,将基准收益曲线叠加到策略收益曲线上
|
||
if benchmark_data is not None:
|
||
# 计算基准收益率
|
||
benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1
|
||
# 将基准收益率转换为权益曲线(基于初始资金)
|
||
initial_capital = equity_df['equity'].iloc[0]
|
||
benchmark_equity = initial_capital * (1 + benchmark_returns)
|
||
# 确保基准数据与策略数据日期对齐
|
||
benchmark_equity = benchmark_equity.reindex(equity_df.index).ffill().bfill()
|
||
ax.plot(equity_df.index, benchmark_equity, label='上证指数 (000001.SH)', linestyle='--')
|
||
else:
|
||
# 如果没有有效的权益曲线数据,显示一条水平直线(初始资金)
|
||
ax.plot([], [], label=strategy_name, color='blue')
|
||
if benchmark_data is not None:
|
||
# 计算基准收益率
|
||
benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1
|
||
# 将基准收益率转换为权益曲线(基于初始资金)
|
||
initial_capital = 1000000 # 默认初始资金
|
||
benchmark_equity = initial_capital * (1 + benchmark_returns)
|
||
ax.plot(benchmark_equity.index, benchmark_equity, label='上证指数 (000001.SH)', linestyle='--')
|
||
|
||
ax.set_title(f'{strategy_name} 收益曲线与基准对比')
|
||
ax.set_xlabel('日期')
|
||
ax.set_ylabel('权益 (元)')
|
||
ax.grid(True)
|
||
ax.legend()
|
||
|
||
# 格式化日期
|
||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
|
||
fig.autofmt_xdate()
|
||
|
||
return self._save_chart(fig, f'equity_curve_{strategy_name}.png')
|
||
|
||
def plot_benchmark_comparison(self, results, benchmark_data, strategy_name):
|
||
"""2. 基准收益曲线图"""
|
||
fig, ax = plt.subplots(figsize=(12, 6))
|
||
|
||
# 计算策略收益率
|
||
equity_df = results['equity_curve']
|
||
|
||
# 检查权益曲线是否有效
|
||
if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns:
|
||
strategy_returns = equity_df['equity'] / equity_df['equity'].iloc[0] - 1
|
||
ax.plot(strategy_returns.index, strategy_returns * 100, label=strategy_name)
|
||
else:
|
||
# 如果没有有效的权益曲线数据,不绘制策略曲线
|
||
ax.plot([], [], label=strategy_name)
|
||
|
||
# 计算基准收益率
|
||
benchmark_returns = benchmark_data['close'] / benchmark_data['close'].iloc[0] - 1
|
||
ax.plot(benchmark_returns.index, benchmark_returns * 100, label='基准')
|
||
|
||
ax.set_title(f'{strategy_name} 与基准收益对比')
|
||
ax.set_xlabel('日期')
|
||
ax.set_ylabel('累计收益率 (%)')
|
||
ax.grid(True)
|
||
ax.legend()
|
||
|
||
# 格式化日期
|
||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
|
||
fig.autofmt_xdate()
|
||
|
||
return self._save_chart(fig, f'benchmark_comparison_{strategy_name}.png')
|
||
|
||
def plot_drawdown_curve(self, results, strategy_name):
|
||
"""3. 回撤曲线图"""
|
||
fig, ax = plt.subplots(figsize=(12, 6))
|
||
|
||
equity_df = results['equity_curve']
|
||
|
||
# 检查权益曲线是否有效
|
||
if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns:
|
||
equity_df['cummax'] = equity_df['equity'].cummax()
|
||
equity_df['drawdown'] = (equity_df['equity'] - equity_df['cummax']) / equity_df['cummax'] * 100
|
||
|
||
ax.fill_between(equity_df.index, equity_df['drawdown'], 0, alpha=0.5, color='red')
|
||
ax.plot(equity_df.index, equity_df['drawdown'], color='red', linewidth=1)
|
||
|
||
ax.set_ylim(bottom=equity_df['drawdown'].min() * 1.1)
|
||
else:
|
||
# 如果没有有效的权益曲线数据,显示一个空图表
|
||
ax.set_ylim(bottom=-20)
|
||
|
||
ax.set_title(f'{strategy_name} 回撤曲线')
|
||
ax.set_xlabel('日期')
|
||
ax.set_ylabel('回撤 (%)')
|
||
ax.grid(True)
|
||
|
||
# 格式化日期
|
||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
|
||
fig.autofmt_xdate()
|
||
|
||
return self._save_chart(fig, f'drawdown_curve_{strategy_name}.png')
|
||
|
||
def plot_return_distribution(self, results, strategy_name):
|
||
"""4. 收益分布直方图"""
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
|
||
equity_df = results['equity_curve']
|
||
|
||
# 检查权益曲线是否有效
|
||
if equity_df is not None and not equity_df.empty and 'returns' in equity_df.columns:
|
||
daily_returns = equity_df['returns'].dropna() * 100
|
||
if not daily_returns.empty:
|
||
sns.histplot(daily_returns, bins=50, kde=True, ax=ax)
|
||
else:
|
||
ax.text(0.5, 0.5, '无收益率数据', ha='center', va='center', transform=ax.transAxes)
|
||
else:
|
||
# 如果没有有效的权益曲线数据,显示提示
|
||
ax.text(0.5, 0.5, '无收益率数据', ha='center', va='center', transform=ax.transAxes)
|
||
|
||
ax.set_title(f'{strategy_name} 日收益率分布')
|
||
ax.set_xlabel('日收益率 (%)')
|
||
ax.set_ylabel('频率')
|
||
ax.grid(True)
|
||
|
||
return self._save_chart(fig, f'return_distribution_{strategy_name}.png')
|
||
|
||
def plot_monthly_returns(self, results, strategy_name):
|
||
"""5. 月度收益柱状图"""
|
||
fig, ax = plt.subplots(figsize=(12, 6))
|
||
|
||
equity_df = results['equity_curve']
|
||
|
||
# 检查权益曲线是否有效
|
||
if equity_df is not None and not equity_df.empty and 'equity' in equity_df.columns:
|
||
monthly_returns = equity_df['equity'].resample('M').last().pct_change() * 100
|
||
monthly_returns = monthly_returns.dropna()
|
||
|
||
if not monthly_returns.empty:
|
||
colors = ['green' if ret > 0 else 'red' for ret in monthly_returns]
|
||
ax.bar(monthly_returns.index, monthly_returns, color=colors)
|
||
else:
|
||
ax.text(0.5, 0.5, '无月度收益数据', ha='center', va='center', transform=ax.transAxes)
|
||
else:
|
||
# 如果没有有效的权益曲线数据,显示提示
|
||
ax.text(0.5, 0.5, '无月度收益数据', ha='center', va='center', transform=ax.transAxes)
|
||
|
||
ax.set_title(f'{strategy_name} 月度收益率')
|
||
ax.set_xlabel('月份')
|
||
ax.set_ylabel('月度收益率 (%)')
|
||
ax.grid(True, axis='y')
|
||
|
||
# 格式化日期
|
||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
|
||
fig.autofmt_xdate()
|
||
|
||
return self._save_chart(fig, f'monthly_returns_{strategy_name}.png')
|
||
|
||
def plot_holding_value(self, results, strategy_name):
|
||
"""6. 持仓市值曲线图"""
|
||
fig, ax = plt.subplots(figsize=(12, 6))
|
||
|
||
equity_df = results['equity_curve']
|
||
|
||
# 检查权益曲线是否有效
|
||
if equity_df is not None and not equity_df.empty:
|
||
# 绘制持仓市值
|
||
if 'positions_value' in equity_df.columns:
|
||
ax.plot(equity_df.index, equity_df['positions_value'], label='持仓市值')
|
||
else:
|
||
ax.plot([], [], label='持仓市值')
|
||
|
||
# 绘制现金
|
||
if 'cash' in equity_df.columns:
|
||
ax.plot(equity_df.index, equity_df['cash'], label='现金')
|
||
else:
|
||
ax.plot([], [], label='现金')
|
||
else:
|
||
# 如果没有有效的权益曲线数据,绘制空图表
|
||
ax.plot([], [], label='持仓市值')
|
||
ax.plot([], [], label='现金')
|
||
|
||
ax.set_title(f'{strategy_name} 持仓市值变化')
|
||
ax.set_xlabel('日期')
|
||
ax.set_ylabel('金额 (元)')
|
||
ax.grid(True)
|
||
ax.legend()
|
||
|
||
# 格式化日期
|
||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
|
||
fig.autofmt_xdate()
|
||
|
||
return self._save_chart(fig, f'holding_value_{strategy_name}.png')
|
||
|
||
def plot_turnover(self, results, strategy_name):
|
||
"""7. 换手率柱状图"""
|
||
# 换手率需要从交易数据中计算
|
||
fig, ax = plt.subplots(figsize=(12, 6))
|
||
|
||
trades_df = results['trades']
|
||
equity_df = results['equity_curve']
|
||
|
||
if trades_df.empty or equity_df is None or equity_df.empty or 'positions_value' not in equity_df.columns:
|
||
ax.text(0.5, 0.5, '无换手率数据', ha='center', va='center', transform=ax.transAxes)
|
||
ax.set_title(f'{strategy_name} 月度换手率')
|
||
ax.set_xlabel('月份')
|
||
ax.set_ylabel('换手率 (%)')
|
||
ax.grid(True, axis='y')
|
||
return self._save_chart(fig, f'turnover_{strategy_name}.png')
|
||
|
||
# 按月份分组,计算每个月的交易金额
|
||
trades_df['month'] = trades_df['date'].dt.to_period('M')
|
||
monthly_trades = trades_df.groupby('month')['price'].sum() * trades_df.groupby('month')['shares'].sum()
|
||
|
||
# 计算平均持仓市值作为分母
|
||
monthly_holding = equity_df['positions_value'].resample('M').mean()
|
||
|
||
# 计算换手率
|
||
turnover = monthly_trades / monthly_holding * 100
|
||
turnover = turnover.dropna()
|
||
|
||
if not turnover.empty:
|
||
ax.bar(turnover.index.astype(str), turnover)
|
||
else:
|
||
ax.text(0.5, 0.5, '无换手率数据', ha='center', va='center', transform=ax.transAxes)
|
||
|
||
ax.set_title(f'{strategy_name} 月度换手率')
|
||
ax.set_xlabel('月份')
|
||
ax.set_ylabel('换手率 (%)')
|
||
ax.grid(True, axis='y')
|
||
|
||
fig.autofmt_xdate()
|
||
|
||
return self._save_chart(fig, f'turnover_{strategy_name}.png')
|
||
|
||
def plot_sector_contribution(self, results, strategy_name):
|
||
"""8. 行业/板块收益贡献图"""
|
||
# 行业贡献需要股票的行业信息,这里假设trades_df包含sector列
|
||
if 'sector' not in results['trades'].columns:
|
||
return None
|
||
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
|
||
trades_df = results['trades']
|
||
sector_returns = {} # 这里需要根据实际数据计算行业贡献
|
||
|
||
# 示例实现,需要根据实际数据调整
|
||
if not trades_df.empty:
|
||
# 按行业分组,计算每个行业的总盈亏
|
||
sector_pnl = trades_df.groupby('sector')['pnl'].sum()
|
||
|
||
# 绘制饼图
|
||
fig, ax = plt.subplots(figsize=(10, 8))
|
||
ax.pie(sector_pnl, labels=sector_pnl.index, autopct='%1.1f%%', startangle=90)
|
||
ax.axis('equal') # 保证饼图是圆形
|
||
|
||
ax.set_title(f'{strategy_name} 行业收益贡献')
|
||
|
||
return self._save_chart(fig, f'sector_contribution_{strategy_name}.png')
|
||
|
||
return None
|
||
|
||
def plot_factor_contribution(self, factor_data, strategy_name):
|
||
"""9. 因子收益贡献图"""
|
||
# 因子贡献需要因子数据
|
||
if not factor_data:
|
||
return None
|
||
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
|
||
# 示例实现,需要根据实际因子数据调整
|
||
factors = list(factor_data.keys())
|
||
contributions = list(factor_data.values())
|
||
|
||
colors = ['green' if contrib > 0 else 'red' for contrib in contributions]
|
||
ax.bar(factors, contributions, color=colors)
|
||
|
||
ax.set_title(f'{strategy_name} 因子收益贡献')
|
||
ax.set_xlabel('因子')
|
||
ax.set_ylabel('贡献度 (%)')
|
||
ax.grid(True, axis='y')
|
||
|
||
return self._save_chart(fig, f'factor_contribution_{strategy_name}.png')
|
||
|
||
def plot_signal_distribution(self, results, strategy_name):
|
||
"""10. 交易信号分布图"""
|
||
fig, ax = plt.subplots(figsize=(12, 6))
|
||
|
||
trades_df = results['trades']
|
||
if trades_df.empty:
|
||
return None
|
||
|
||
# 计算每个月的买入和卖出信号数量
|
||
trades_df['month'] = trades_df['date'].dt.to_period('M')
|
||
signal_counts = trades_df.groupby(['month', 'action']).size().unstack(fill_value=0)
|
||
|
||
signal_counts.plot(kind='bar', ax=ax)
|
||
|
||
ax.set_title(f'{strategy_name} 交易信号分布')
|
||
ax.set_xlabel('月份')
|
||
ax.set_ylabel('信号数量')
|
||
ax.grid(True, axis='y')
|
||
ax.legend()
|
||
|
||
fig.autofmt_xdate()
|
||
|
||
return self._save_chart(fig, f'signal_distribution_{strategy_name}.png')
|
||
|
||
def plot_all_trades(self, results, strategy_name):
|
||
"""11. 所有交易记录表格"""
|
||
trades_df = results['trades']
|
||
|
||
fig, ax = plt.subplots(figsize=(14, max(8, len(trades_df) * 0.3)))
|
||
ax.axis('tight')
|
||
ax.axis('off')
|
||
|
||
if not trades_df.empty:
|
||
# 选择需要显示的列
|
||
display_columns = ['date', 'ts_code', 'action', 'price', 'shares', 'commission',
|
||
'pnl', 'pnl_pct', 'reason']
|
||
|
||
# 确保所有需要的列都存在
|
||
available_columns = [col for col in display_columns if col in trades_df.columns]
|
||
trades_data = trades_df[available_columns].copy()
|
||
|
||
# 格式化数据
|
||
if 'price' in trades_data.columns:
|
||
trades_data['price'] = trades_data['price'].round(2)
|
||
if 'commission' in trades_data.columns:
|
||
trades_data['commission'] = trades_data['commission'].round(2)
|
||
if 'pnl' in trades_data.columns:
|
||
trades_data['pnl'] = trades_data['pnl'].round(2)
|
||
if 'pnl_pct' in trades_data.columns:
|
||
trades_data['pnl_pct'] = (trades_data['pnl_pct'] * 100).round(2)
|
||
|
||
# 创建表格
|
||
table = ax.table(cellText=trades_data.values,
|
||
colLabels=trades_data.columns,
|
||
cellLoc='center',
|
||
loc='center')
|
||
|
||
# 设置表格样式
|
||
table.auto_set_font_size(False)
|
||
table.set_fontsize(9)
|
||
table.scale(1, 1.5)
|
||
|
||
# 设置表头背景色
|
||
for i, key in enumerate(trades_data.columns):
|
||
table[(0, i)].set_facecolor('#40466e')
|
||
table[(0, i)].set_text_props(color='white')
|
||
table[(0, i)].set_edgecolor('black')
|
||
else:
|
||
ax.text(0.5, 0.5, '无交易记录', ha='center', va='center', transform=ax.transAxes, fontsize=12)
|
||
|
||
ax.set_title(f'{strategy_name} 所有交易记录', fontsize=14, pad=20)
|
||
|
||
return self._save_chart(fig, f'all_trades_{strategy_name}.png')
|
||
|
||
def plot_daily_trades(self, results, strategy_name):
|
||
"""12. 每日买卖记录表格"""
|
||
trades_df = results['trades']
|
||
|
||
fig, ax = plt.subplots(figsize=(14, max(8, len(trades_df) * 0.3)))
|
||
ax.axis('tight')
|
||
ax.axis('off')
|
||
|
||
if not trades_df.empty:
|
||
# 按日期分组
|
||
daily_trades = trades_df.groupby('date')
|
||
|
||
# 准备数据
|
||
daily_data = []
|
||
for date, group in daily_trades:
|
||
# 分别统计买入和卖出
|
||
buys = group[group['action'] == 'BUY']
|
||
sells = group[group['action'] == 'SELL']
|
||
|
||
# 买入信息
|
||
buy_count = len(buys)
|
||
buy_volume = buys['shares'].sum() if 'shares' in buys.columns else 0
|
||
buy_amount = (buys['price'] * buys['shares']).sum() if 'price' in buys.columns and 'shares' in buys.columns else 0
|
||
|
||
# 卖出信息
|
||
sell_count = len(sells)
|
||
sell_volume = sells['shares'].sum() if 'shares' in sells.columns else 0
|
||
sell_amount = (sells['price'] * sells['shares']).sum() if 'price' in sells.columns and 'shares' in sells.columns else 0
|
||
sell_pnl = sells['pnl'].sum() if 'pnl' in sells.columns else 0
|
||
|
||
daily_data.append({
|
||
'日期': date,
|
||
'买入次数': buy_count,
|
||
'买入股数': buy_volume,
|
||
'买入金额': buy_amount,
|
||
'卖出次数': sell_count,
|
||
'卖出股数': sell_volume,
|
||
'卖出金额': sell_amount,
|
||
'当日盈亏': sell_pnl
|
||
})
|
||
|
||
# 创建DataFrame
|
||
daily_df = pd.DataFrame(daily_data)
|
||
|
||
# 格式化数据
|
||
daily_df['买入金额'] = daily_df['买入金额'].round(2)
|
||
daily_df['卖出金额'] = daily_df['卖出金额'].round(2)
|
||
daily_df['当日盈亏'] = daily_df['当日盈亏'].round(2)
|
||
|
||
# 创建表格
|
||
table = ax.table(cellText=daily_df.values,
|
||
colLabels=daily_df.columns,
|
||
cellLoc='center',
|
||
loc='center')
|
||
|
||
# 设置表格样式
|
||
table.auto_set_font_size(False)
|
||
table.set_fontsize(9)
|
||
table.scale(1, 1.5)
|
||
|
||
# 设置表头背景色
|
||
for i, key in enumerate(daily_df.columns):
|
||
table[(0, i)].set_facecolor('#40466e')
|
||
table[(0, i)].set_text_props(color='white')
|
||
table[(0, i)].set_edgecolor('black')
|
||
else:
|
||
ax.text(0.5, 0.5, '无交易记录', ha='center', va='center', transform=ax.transAxes, fontsize=12)
|
||
|
||
ax.set_title(f'{strategy_name} 每日买卖记录', fontsize=14, pad=20)
|
||
|
||
return self._save_chart(fig, f'daily_trades_{strategy_name}.png')
|
||
|
||
def generate_all_charts(self, results, strategy_name, benchmark_data=None, factor_data=None):
|
||
"""生成所有图表"""
|
||
charts = []
|
||
|
||
# 生成策略收益曲线图(包含基准对比)
|
||
charts.append(('收益曲线', self.plot_equity_curve(results, strategy_name, benchmark_data)))
|
||
|
||
# 生成基准对比图(如果有基准数据)- 保持独立的基准对比图
|
||
if benchmark_data is not None:
|
||
charts.append(('基准对比', self.plot_benchmark_comparison(results, benchmark_data, strategy_name)))
|
||
|
||
# 生成回撤曲线图
|
||
charts.append(('回撤曲线', self.plot_drawdown_curve(results, strategy_name)))
|
||
|
||
# 生成收益分布直方图
|
||
charts.append(('收益分布', self.plot_return_distribution(results, strategy_name)))
|
||
|
||
# 生成月度收益柱状图
|
||
charts.append(('月度收益', self.plot_monthly_returns(results, strategy_name)))
|
||
|
||
# 生成持仓市值曲线图
|
||
charts.append(('持仓市值', self.plot_holding_value(results, strategy_name)))
|
||
|
||
# 生成换手率柱状图
|
||
turnover_chart = self.plot_turnover(results, strategy_name)
|
||
if turnover_chart:
|
||
charts.append(('换手率', turnover_chart))
|
||
|
||
# 生成行业贡献图
|
||
sector_chart = self.plot_sector_contribution(results, strategy_name)
|
||
if sector_chart:
|
||
charts.append(('行业贡献', sector_chart))
|
||
|
||
# 生成因子贡献图(如果有因子数据)
|
||
if factor_data is not None:
|
||
factor_chart = self.plot_factor_contribution(factor_data, strategy_name)
|
||
if factor_chart:
|
||
charts.append(('因子贡献', factor_chart))
|
||
|
||
# 生成交易信号分布图
|
||
signal_chart = self.plot_signal_distribution(results, strategy_name)
|
||
if signal_chart:
|
||
charts.append(('信号分布', signal_chart))
|
||
|
||
# 生成所有交易记录图表
|
||
all_trades_chart = self.plot_all_trades(results, strategy_name)
|
||
charts.append(('所有交易记录', all_trades_chart))
|
||
|
||
# 生成每日买卖记录图表
|
||
daily_trades_chart = self.plot_daily_trades(results, strategy_name)
|
||
charts.append(('每日买卖记录', daily_trades_chart))
|
||
|
||
return charts |