"""资金曲线绘图模块。 从资金曲线 CSV 生成 PNG 图表,保存到 results/plots 目录。 支持策略与基准指数的对比图。 """ from __future__ import annotations import os from typing import Optional import matplotlib.pyplot as plt import pandas as pd from utils.logger import setup_logger logger = setup_logger(__name__) def plot_equity_curve( csv_path: str, output_path: str, benchmark_df: Optional[pd.DataFrame] = None, benchmark_name: str = "基准指数", ) -> None: """根据资金曲线 CSV 绘制资金曲线图。 参数: - csv_path: 资金曲线 CSV 路径,包含列 ['trade_date', 'total_asset']; - output_path: 输出 PNG 文件路径。 - benchmark_df: 基准指数数据(可选),包含 ['trade_date', 'benchmark_value'] - benchmark_name: 基准指数名称 """ if not os.path.exists(csv_path): logger.warning(f"资金曲线文件不存在: {csv_path}") return df = pd.read_csv(csv_path) if df.empty: logger.warning("资金曲线文件为空,无法绘图") return df = df.sort_values("trade_date") # 归一化策略收益(以初始资产为1) initial_asset = df["total_asset"].iloc[0] df["strategy_value"] = df["total_asset"] / initial_asset os.makedirs(os.path.dirname(output_path), exist_ok=True) # 设置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'] plt.rcParams['axes.unicode_minus'] = False plt.figure(figsize=(14, 7)) # 绘制策略收益曲线 plt.plot(range(len(df)), df["strategy_value"], label="策略收益", linewidth=2, color='blue') # 如果有基准数据,绘制基准收益曲线 if benchmark_df is not None and not benchmark_df.empty: # 确保 trade_date 类型一致(都转为字符串) df["trade_date"] = df["trade_date"].astype(str) benchmark_df_copy = benchmark_df.copy() benchmark_df_copy["trade_date"] = benchmark_df_copy["trade_date"].astype(str) # 合并数据(按交易日) merged = pd.merge( df[["trade_date", "strategy_value"]], benchmark_df_copy[["trade_date", "benchmark_value"]], on="trade_date", how="inner", ) if not merged.empty: plt.plot( range(len(merged)), merged["benchmark_value"], label=benchmark_name, linewidth=2, color='red', alpha=0.7, ) # 计算超额收益 merged["excess_return"] = merged["strategy_value"] - merged["benchmark_value"] final_excess = merged["excess_return"].iloc[-1] logger.info(f"策略相对于{benchmark_name}的超额收益: {final_excess*100:+.2f}%") plt.xlabel("交易日", fontsize=12) plt.ylabel("收益率(归一化)", fontsize=12) plt.title("策略收益曲线 vs 基准指数", fontsize=14, fontweight='bold') plt.grid(True, alpha=0.3) plt.legend(fontsize=11, loc='upper left') plt.tight_layout() plt.savefig(output_path, dpi=150) plt.close() logger.info(f"资金曲线图保存到 {output_path}")