100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
"""资金曲线绘图模块。
|
||
|
||
从资金曲线 CSV 生成 PNG 图表,保存到 results/plots 目录。
|
||
支持策略与基准指数的对比图。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
from typing import Optional
|
||
|
||
import matplotlib.pyplot as plt
|
||
import pandas as pd
|
||
|
||
from utils.logger import setup_logger
|
||
|
||
logger = setup_logger(__name__)
|
||
|
||
|
||
def plot_equity_curve(
|
||
csv_path: str,
|
||
output_path: str,
|
||
benchmark_df: Optional[pd.DataFrame] = None,
|
||
benchmark_name: str = "基准指数",
|
||
) -> None:
|
||
"""根据资金曲线 CSV 绘制资金曲线图。
|
||
|
||
参数:
|
||
- csv_path: 资金曲线 CSV 路径,包含列 ['trade_date', 'total_asset'];
|
||
- output_path: 输出 PNG 文件路径。
|
||
- benchmark_df: 基准指数数据(可选),包含 ['trade_date', 'benchmark_value']
|
||
- benchmark_name: 基准指数名称
|
||
"""
|
||
if not os.path.exists(csv_path):
|
||
logger.warning(f"资金曲线文件不存在: {csv_path}")
|
||
return
|
||
|
||
df = pd.read_csv(csv_path)
|
||
if df.empty:
|
||
logger.warning("资金曲线文件为空,无法绘图")
|
||
return
|
||
|
||
df = df.sort_values("trade_date")
|
||
|
||
# 归一化策略收益(以初始资产为1)
|
||
initial_asset = df["total_asset"].iloc[0]
|
||
df["strategy_value"] = df["total_asset"] / initial_asset
|
||
|
||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||
|
||
# 设置中文字体
|
||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS']
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
|
||
plt.figure(figsize=(14, 7))
|
||
|
||
# 绘制策略收益曲线
|
||
plt.plot(range(len(df)), df["strategy_value"], label="策略收益", linewidth=2, color='blue')
|
||
|
||
# 如果有基准数据,绘制基准收益曲线
|
||
if benchmark_df is not None and not benchmark_df.empty:
|
||
# 确保 trade_date 类型一致(都转为字符串)
|
||
df["trade_date"] = df["trade_date"].astype(str)
|
||
benchmark_df_copy = benchmark_df.copy()
|
||
benchmark_df_copy["trade_date"] = benchmark_df_copy["trade_date"].astype(str)
|
||
|
||
# 合并数据(按交易日)
|
||
merged = pd.merge(
|
||
df[["trade_date", "strategy_value"]],
|
||
benchmark_df_copy[["trade_date", "benchmark_value"]],
|
||
on="trade_date",
|
||
how="inner",
|
||
)
|
||
|
||
if not merged.empty:
|
||
plt.plot(
|
||
range(len(merged)),
|
||
merged["benchmark_value"],
|
||
label=benchmark_name,
|
||
linewidth=2,
|
||
color='red',
|
||
alpha=0.7,
|
||
)
|
||
|
||
# 计算超额收益
|
||
merged["excess_return"] = merged["strategy_value"] - merged["benchmark_value"]
|
||
final_excess = merged["excess_return"].iloc[-1]
|
||
|
||
logger.info(f"策略相对于{benchmark_name}的超额收益: {final_excess*100:+.2f}%")
|
||
|
||
plt.xlabel("交易日", fontsize=12)
|
||
plt.ylabel("收益率(归一化)", fontsize=12)
|
||
plt.title("策略收益曲线 vs 基准指数", fontsize=14, fontweight='bold')
|
||
plt.grid(True, alpha=0.3)
|
||
plt.legend(fontsize=11, loc='upper left')
|
||
plt.tight_layout()
|
||
|
||
plt.savefig(output_path, dpi=150)
|
||
plt.close()
|
||
logger.info(f"资金曲线图保存到 {output_path}")
|