Files
strategy_backtest/utils/plotter.py

100 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""资金曲线绘图模块。
从资金曲线 CSV 生成 PNG 图表,保存到 results/plots 目录。
支持策略与基准指数的对比图。
"""
from __future__ import annotations
import os
from typing import Optional
import matplotlib.pyplot as plt
import pandas as pd
from utils.logger import setup_logger
logger = setup_logger(__name__)
def plot_equity_curve(
csv_path: str,
output_path: str,
benchmark_df: Optional[pd.DataFrame] = None,
benchmark_name: str = "基准指数",
) -> None:
"""根据资金曲线 CSV 绘制资金曲线图。
参数:
- csv_path: 资金曲线 CSV 路径,包含列 ['trade_date', 'total_asset']
- output_path: 输出 PNG 文件路径。
- benchmark_df: 基准指数数据(可选),包含 ['trade_date', 'benchmark_value']
- benchmark_name: 基准指数名称
"""
if not os.path.exists(csv_path):
logger.warning(f"资金曲线文件不存在: {csv_path}")
return
df = pd.read_csv(csv_path)
if df.empty:
logger.warning("资金曲线文件为空,无法绘图")
return
df = df.sort_values("trade_date")
# 归一化策略收益以初始资产为1
initial_asset = df["total_asset"].iloc[0]
df["strategy_value"] = df["total_asset"] / initial_asset
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(14, 7))
# 绘制策略收益曲线
plt.plot(range(len(df)), df["strategy_value"], label="策略收益", linewidth=2, color='blue')
# 如果有基准数据,绘制基准收益曲线
if benchmark_df is not None and not benchmark_df.empty:
# 确保 trade_date 类型一致(都转为字符串)
df["trade_date"] = df["trade_date"].astype(str)
benchmark_df_copy = benchmark_df.copy()
benchmark_df_copy["trade_date"] = benchmark_df_copy["trade_date"].astype(str)
# 合并数据(按交易日)
merged = pd.merge(
df[["trade_date", "strategy_value"]],
benchmark_df_copy[["trade_date", "benchmark_value"]],
on="trade_date",
how="inner",
)
if not merged.empty:
plt.plot(
range(len(merged)),
merged["benchmark_value"],
label=benchmark_name,
linewidth=2,
color='red',
alpha=0.7,
)
# 计算超额收益
merged["excess_return"] = merged["strategy_value"] - merged["benchmark_value"]
final_excess = merged["excess_return"].iloc[-1]
logger.info(f"策略相对于{benchmark_name}的超额收益: {final_excess*100:+.2f}%")
plt.xlabel("交易日", fontsize=12)
plt.ylabel("收益率(归一化)", fontsize=12)
plt.title("策略收益曲线 vs 基准指数", fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.legend(fontsize=11, loc='upper left')
plt.tight_layout()
plt.savefig(output_path, dpi=150)
plt.close()
logger.info(f"资金曲线图保存到 {output_path}")