Files
strategy_backtest/benchmark/benchmark_loader.py

193 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""基准指数加载与收益计算模块。
专门负责基准指数(如上证指数 000001.SH的行情读取和收益率计算。
与策略完全解耦,仅用于最后图表对比。
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional, Tuple
import pandas as pd
from utils.logger import setup_logger
logger = setup_logger(__name__)
def load_benchmark(
    benchmark_file: Path,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Load benchmark index quotes from a tab-separated file.

    The file is read as UTF-8 text with every column as a string; it must
    have a header row containing at least ``trade_date`` and ``close``.
    Rows whose ``trade_date`` is missing or whose ``close`` cannot be parsed
    as a number are dropped, so the first returned row is always usable as
    the normalization base by downstream code.

    Args:
        benchmark_file: Path to the benchmark data file (``Path`` or ``str``).
        start_date: Inclusive lower bound, compared as a string against
            ``trade_date`` (the original docstring specifies YYYYMMDD).
            A falsy value disables the lower bound.
        end_date: Inclusive upper bound, same format. A falsy value disables
            the upper bound.

    Returns:
        A DataFrame with columns ``['trade_date', 'close']`` sorted by
        ``trade_date`` with a fresh 0..n-1 index. An empty DataFrame is
        returned when the file is missing, lacks a required column, or
        cannot be read/parsed (the failure is logged, not raised).
    """
    # Accept plain string paths as well as Path objects.
    benchmark_file = Path(benchmark_file)
    if not benchmark_file.exists():
        logger.warning(f"基准指数文件不存在: {benchmark_file}")
        return pd.DataFrame()

    try:
        df = pd.read_csv(
            benchmark_file,
            sep="\t",
            encoding="utf-8",
            dtype=str,
        )

        if "trade_date" not in df.columns or "close" not in df.columns:
            logger.error(f"基准指数文件缺少必要列: {benchmark_file}")
            return pd.DataFrame()

        df = df[["trade_date", "close"]].copy()
        df["close"] = pd.to_numeric(df["close"], errors="coerce")

        # Drop unusable rows BEFORE casting dates to str: a missing date
        # would otherwise become the literal "nan" and could pass the date
        # filter, and a NaN close would poison the normalized curve.
        rows_before = len(df)
        df = df.dropna(subset=["trade_date", "close"])
        dropped = rows_before - len(df)
        if dropped:
            logger.warning(f"基准指数数据存在无效行,已剔除 {dropped} 条")

        df["trade_date"] = df["trade_date"].astype(str).str.strip()

        # String comparison is chronological for fixed-width YYYYMMDD dates.
        if start_date:
            df = df[df["trade_date"] >= start_date]
        if end_date:
            df = df[df["trade_date"] <= end_date]

        df = df.sort_values("trade_date").reset_index(drop=True)

        logger.info(f"基准指数数据加载成功: {len(df)} 条记录")
        return df
    except Exception as e:
        # Deliberate best-effort boundary: the benchmark only feeds the final
        # comparison chart, so a bad file must not abort the caller.
        logger.error(f"加载基准指数数据失败: {e}")
        return pd.DataFrame()
def calc_benchmark_return(
    benchmark_df: pd.DataFrame,
    calendar: Optional[list] = None,
    periods_per_year: int = 252,
) -> Tuple[pd.DataFrame, dict]:
    """Compute the normalized value curve and summary statistics of a benchmark.

    Args:
        benchmark_df: Benchmark quotes with ``trade_date`` and ``close``
            columns, expected in chronological order (the first row is the
            normalization base).
        calendar: Optional collection of trade dates (list, set, NumPy array,
            pandas Index, ...). When given and non-empty, only rows whose
            ``trade_date`` is in it are used. ``None`` or an empty collection
            disables the filter.
        periods_per_year: Number of trading periods per year used for
            annualization. Defaults to 252.

    Returns:
        ``(equity_curve, stats)``:
            - ``equity_curve``: DataFrame ``['trade_date', 'benchmark_value']``
              where the first value is 1.0.
            - ``stats``: dict with keys ``cum_return``, ``ann_return``,
              ``max_drawdown``, ``volatility``, ``sharpe`` (risk-free rate
              taken as 0), ``trading_days`` and ``years``; numeric values are
              plain Python numbers.
        ``(empty DataFrame, {})`` is returned when there is no usable data
        (empty input, no overlap with ``calendar``, no numeric close, or a
        non-positive first close).

    Raises:
        ValueError: If ``periods_per_year`` is not positive.
    """
    if periods_per_year <= 0:
        raise ValueError("periods_per_year must be positive")

    if benchmark_df.empty:
        logger.warning("基准指数数据为空,无法计算收益")
        return pd.DataFrame(), {}

    df = benchmark_df.copy()

    # Test None/len explicitly: truthiness of a NumPy array or pandas
    # Index/Series raises ValueError.
    if calendar is not None and len(calendar) > 0:
        df = df[df["trade_date"].isin(calendar)].copy()
        if df.empty:
            logger.warning("基准指数数据与交易日历无交集")
            return pd.DataFrame(), {}

    # A NaN close (e.g. from coerced bad input) in the first row would turn
    # the whole normalized curve into NaN, so unusable rows are removed.
    df["close"] = pd.to_numeric(df["close"], errors="coerce")
    df = df.dropna(subset=["close"])
    if df.empty:
        logger.warning("基准指数数据为空,无法计算收益")
        return pd.DataFrame(), {}

    initial_close = df["close"].iloc[0]
    if not initial_close > 0:
        # Dividing by a zero/negative base would yield inf or sign-flipped values.
        logger.warning("基准指数初始收盘价无效,无法计算收益")
        return pd.DataFrame(), {}

    # Normalized curve: first value is 1.
    df["benchmark_value"] = df["close"] / initial_close

    final_value = df["benchmark_value"].iloc[-1]
    cum_return = float(final_value - 1.0)

    # df is non-empty and periods_per_year > 0 here, so years is always > 0.
    n_days = len(df)
    years = n_days / float(periods_per_year)
    ann_return = float(final_value ** (1.0 / years) - 1.0)

    # Maximum drawdown: worst drop relative to the running peak.
    running_peak = df["benchmark_value"].cummax()
    max_drawdown = float((df["benchmark_value"] / running_peak - 1.0).min())

    # Annualized volatility and Sharpe ratio (risk-free rate assumed 0).
    daily_return = df["close"].pct_change()
    std_return = daily_return.std()
    annualizer = periods_per_year ** 0.5
    if std_return > 0:
        volatility = float(std_return * annualizer)
        sharpe = float(
            (daily_return.mean() * periods_per_year) / (std_return * annualizer)
        )
    else:
        # std is 0 for a flat series and NaN with fewer than two returns;
        # report 0.0 for both metrics instead of leaking NaN into stats.
        volatility = 0.0
        sharpe = 0.0

    stats = {
        "cum_return": cum_return,
        "ann_return": ann_return,
        "max_drawdown": max_drawdown,
        "volatility": volatility,
        "sharpe": sharpe,
        "trading_days": n_days,
        "years": years,
    }

    result_df = df[["trade_date", "benchmark_value"]].copy()

    logger.info("基准指数收益计算完成:")
    logger.info(f" 累计收益: {cum_return*100:+.2f}%")
    logger.info(f" 年化收益: {ann_return*100:+.2f}%")
    logger.info(f" 最大回撤: {max_drawdown*100:.2f}%")
    logger.info(f" 夏普比率: {sharpe:.4f}")
    return result_df, stats
def merge_with_strategy(
    equity_df: pd.DataFrame,
    benchmark_df: pd.DataFrame,
    initial_cash: float = 1_000_000,
) -> pd.DataFrame:
    """Join the normalized strategy curve with the benchmark curve by date.

    Args:
        equity_df: Strategy equity curve with ``trade_date`` and
            ``total_asset`` columns.
        benchmark_df: Benchmark curve with ``trade_date`` and
            ``benchmark_value`` columns.
        initial_cash: Starting capital; ``total_asset`` is divided by it so
            the strategy curve is expressed relative to the initial capital.

    Returns:
        DataFrame ``['trade_date', 'strategy_value', 'benchmark_value']``
        containing only the dates present in both inputs (inner join), or an
        empty DataFrame when either input is empty. The inputs are not
        modified.
    """
    # Guard clause: nothing to align when either side has no rows.
    if any(frame.empty for frame in (equity_df, benchmark_df)):
        logger.warning("策略或基准数据为空,无法合并")
        return pd.DataFrame()

    # Express the strategy equity as a multiple of the initial capital.
    scaled_asset = equity_df["total_asset"] / initial_cash
    strategy_side = equity_df[["trade_date"]].assign(strategy_value=scaled_asset)
    benchmark_side = benchmark_df[["trade_date", "benchmark_value"]]

    # Inner join keeps only the dates both curves share.
    combined = strategy_side.merge(benchmark_side, on="trade_date", how="inner")

    logger.info(f"策略与基准数据合并完成,共 {len(combined)} 个交易日")
    return combined