新建回测系统,并提交
This commit is contained in:
5
utils/__init__.py
Normal file
5
utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""工具包初始化文件。"""
|
||||
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
BIN
utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/data_loader.cpython-310.pyc
Normal file
BIN
utils/__pycache__/data_loader.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/logger.cpython-310.pyc
Normal file
BIN
utils/__pycache__/logger.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/performance.cpython-310.pyc
Normal file
BIN
utils/__pycache__/performance.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/plotter.cpython-310.pyc
Normal file
BIN
utils/__pycache__/plotter.cpython-310.pyc
Normal file
Binary file not shown.
82
utils/data_loader.py
Normal file
82
utils/data_loader.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""行情数据加载模块。
|
||||
|
||||
唯一数据入口,负责从 data/day 目录加载 TXT 日线行情,并返回 pandas.DataFrame。
|
||||
|
||||
约束:
|
||||
- 文件不存在或为空时,返回空 DataFrame,不抛异常,并输出 warning 日志;
|
||||
- 解析失败时返回空 DataFrame,并输出 error 日志;
|
||||
- 日期统一升序排列;
|
||||
- 首行必须包含 trade_date/open/high/low/close/vol 列,否则视为格式错误。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
# Columns every daily-quote file must provide; load_single_stock checks for
# them after stripping and lower-casing the header names.
REQUIRED_COLUMNS = ["trade_date", "open", "high", "low", "close", "vol"]
|
||||
|
||||
|
||||
def load_single_stock(
    data_dir: str,
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Load the daily quotes of one stock from a tab-separated text file.

    The file is looked up as ``<data_dir>/<ts_code>_daily_data.txt``.

    Args:
        data_dir: Directory holding the quote files, e.g. ``"data/day"``.
        ts_code: Stock code including its suffix, e.g. ``"000001.SZ"``.
        start_date: Inclusive lower bound in ``YYYYMMDD`` form, or ``None``
            for no lower bound.
        end_date: Inclusive upper bound in ``YYYYMMDD`` form, or ``None``
            for no upper bound.

    Returns:
        A DataFrame sorted by ``trade_date`` ascending with a fresh 0..n-1
        index and lower-cased column names. An empty DataFrame is returned
        (and a warning or error is logged) when the file is missing, empty,
        lacks a required column, or cannot be parsed. This function never
        raises for those cases.
    """
    filename = f"{ts_code}_daily_data.txt"
    path = os.path.join(data_dir, filename)

    if not os.path.exists(path) or os.path.getsize(path) == 0:
        logger.warning(f"{filename} 不存在或为空,跳过")
        return pd.DataFrame()

    try:
        # Tab-separated text. "utf-8-sig" reads plain UTF-8 as well as files
        # saved with a byte-order mark; with plain "utf-8" the BOM would stay
        # attached to the first header name and the required-column check
        # below would wrongly reject the file.
        df = pd.read_csv(path, sep="\t", encoding="utf-8-sig")
        if df.empty:
            logger.warning(f"{filename} 不存在或为空,跳过")
            return pd.DataFrame()

        # Normalise header names so the required-column check is
        # case- and whitespace-insensitive.
        df.columns = [c.strip().lower() for c in df.columns]

        for col in REQUIRED_COLUMNS:
            if col not in df.columns:
                logger.error(f"{ts_code} 数据缺少必要列: {col}")
                return pd.DataFrame()

        # The parser may have inferred int or str; force str so that the
        # range filter and the sort both use plain string comparison.
        df["trade_date"] = df["trade_date"].astype(str)

        # Bounds are coerced with str() so an int such as 20200101 filters
        # correctly instead of raising TypeError on a str/int comparison.
        if start_date is not None:
            df = df[df["trade_date"] >= str(start_date)]
        if end_date is not None:
            df = df[df["trade_date"] <= str(end_date)]

        df = df.sort_values("trade_date").reset_index(drop=True)

        # Success is logged at DEBUG only, to keep the log file small.
        logger.debug(f"加载 {ts_code} 成功,共 {len(df)} 条记录")
        return df
    except Exception as e:  # noqa: BLE001
        # Deliberate best-effort: any read/parse failure yields an empty frame.
        logger.error(f"{ts_code} 数据解析失败: {e}")
        return pd.DataFrame()
|
||||
63
utils/logger.py
Normal file
63
utils/logger.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""统一日志模块。
|
||||
|
||||
提供 setup_logger 函数,所有模块共用统一格式和输出目标:
|
||||
- 格式:2025-12-20 18:05:30 [INFO] data_loader.py:42 - 消息内容
|
||||
- 输出:同时打印到 stdout 和写入 results/logs/app.log
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
# Loggers already configured by setup_logger, keyed by logger name.
_LOGGER_CACHE: dict[str, logging.Logger] = {}
|
||||
|
||||
|
||||
class DataLoaderFilter(logging.Filter):
    """Drop WARNING-and-above records emitted from data_loader source files.

    Attached to the console handler so that those records are kept out of the
    console output.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False to suppress the record, True to let it through."""
        from_data_loader = "data_loader" in record.filename
        warning_or_worse = record.levelno >= logging.WARNING
        return not (from_data_loader and warning_or_worse)
|
||||
|
||||
|
||||
# Console and file handlers shared by every logger that setup_logger
# configures. Filled lazily on first use.
_SHARED_HANDLERS: list[logging.Handler] = []


def _get_shared_handlers() -> list[logging.Handler]:
    """Build the console and rotating-file handlers once and return them."""
    if _SHARED_HANDLERS:
        return _SHARED_HANDLERS

    import sys  # local import: only needed when the handlers are first built

    # The log directory is relative to the current working directory.
    log_dir = os.path.join("results", "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, "app.log")

    formatter = logging.Formatter(
        fmt="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Console output goes to stdout, as the module documentation promises
    # (a bare StreamHandler() would write to stderr). data_loader warnings
    # and errors are filtered out of the console.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    console_handler.addFilter(DataLoaderFilter())

    # Rotating file output: 10 MiB per file, five backups.
    file_handler = RotatingFileHandler(
        log_path,
        maxBytes=10 * 1024 * 1024,
        backupCount=5,
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    _SHARED_HANDLERS.extend([console_handler, file_handler])
    return _SHARED_HANDLERS


def setup_logger(name: str) -> logging.Logger:
    """Create, or fetch from the cache, the logger registered under ``name``.

    Every logger returned by this function:
    - has level INFO and does not propagate to ancestor loggers;
    - formats records as time, level, file name, line number, message;
    - writes to stdout and to ``results/logs/app.log`` (rotating).

    All loggers share one console handler and one file handler. Giving each
    logger its own ``RotatingFileHandler`` on the same path would open the
    file once per module and break rollover.

    Args:
        name: Logger name, normally the calling module's ``__name__``.

    Returns:
        The configured ``logging.Logger``.
    """
    cached = _LOGGER_CACHE.get(name)
    if cached is not None:
        return cached

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False

    # Leave loggers that already carry handlers untouched.
    if not logger.handlers:
        for handler in _get_shared_handlers():
            logger.addHandler(handler)

    _LOGGER_CACHE[name] = logger
    return logger
|
||||
145
utils/performance.py
Normal file
145
utils/performance.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""绩效计算模块。
|
||||
|
||||
根据资金曲线计算收益、年化收益、夏普比率、最大回撤等指标。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def _summarize_trades(trade_history: list | None) -> tuple[float, float, int, int]:
    """Return (win_rate, profit_loss_ratio, win_count, loss_count) for the trades."""
    win_count = 0
    loss_count = 0
    total_win_pct = 0.0
    total_loss_pct = 0.0

    for trade in trade_history or []:
        # Any trade not flagged as a win is counted as a loss.
        if trade.get("is_win", False):
            win_count += 1
            total_win_pct += trade.get("profit_pct", 0.0)
        else:
            loss_count += 1
            total_loss_pct += abs(trade.get("profit_pct", 0.0))

    closed_trades = win_count + loss_count
    if closed_trades == 0:
        return 0.0, 0.0, 0, 0

    win_rate = win_count / closed_trades

    # Profit/loss ratio = average winning percentage / average losing percentage.
    avg_win = total_win_pct / win_count if win_count > 0 else 0.0
    avg_loss = total_loss_pct / loss_count if loss_count > 0 else 0.0
    if avg_loss > 0:
        profit_loss_ratio = avg_win / avg_loss
    else:
        # No measurable losses: infinite ratio unless there were no gains either.
        profit_loss_ratio = 0.0 if avg_win == 0 else float("inf")

    return win_rate, profit_loss_ratio, win_count, loss_count


def calc_performance(
    equity_df: pd.DataFrame,
    trade_count: int = 0,
    trade_history: list | None = None,
    trading_days_per_year: int = 252,
) -> dict:
    """Compute performance metrics from an equity curve and log a summary.

    Args:
        equity_df: Equity curve with at least ``trade_date`` and
            ``total_asset`` columns. ``market_value`` is used for capital
            utilisation when present.
        trade_count: Total number of trades (buys plus sells).
        trade_history: Optional list of dicts read through the keys
            ``is_win`` and ``profit_pct``; used for win rate and
            profit/loss ratio.
        trading_days_per_year: Trading days per year used for annualisation.

    Returns:
        An empty dict when ``equity_df`` is empty. Otherwise a dict with keys
        ``cum_return``, ``ann_return``, ``sharpe``, ``max_drawdown``,
        ``avg_capital_utilization``, ``total_trades``,
        ``avg_trades_per_year``, ``backtest_years``, ``win_rate``,
        ``profit_loss_ratio``, ``win_count`` and ``loss_count``.
    """
    if equity_df.empty:
        logger.warning("资金曲线为空,无法计算绩效")
        return {}

    df = equity_df.copy()
    df = df.sort_values("trade_date").reset_index(drop=True)
    df["ret"] = df["total_asset"].pct_change().fillna(0.0)

    # Cumulative return over the whole curve.
    cum_return = df["total_asset"].iloc[-1] / df["total_asset"].iloc[0] - 1

    # Annualised return, based on the number of rows in the curve.
    n = len(df)
    if n <= 1:
        ann_return = 0.0
        years = 0.0
    else:
        ann_return = (1 + cum_return) ** (trading_days_per_year / n) - 1
        years = n / trading_days_per_year

    # Sharpe ratio with a zero risk-free rate.
    ret_mean = df["ret"].mean()
    ret_std = df["ret"].std(ddof=1)
    # `not ret_std > 0` is true for both 0 and NaN. The sample standard
    # deviation of a single-row curve is NaN, which an `== 0` test would let
    # through and turn the Sharpe ratio into NaN.
    if not ret_std > 0:
        sharpe = 0.0
    else:
        sharpe = (ret_mean * trading_days_per_year) / (ret_std * (trading_days_per_year**0.5))

    # Maximum drawdown relative to the running peak (zero or negative).
    cummax = df["total_asset"].cummax()
    drawdown = df["total_asset"] / cummax - 1
    max_drawdown = float(drawdown.min())

    # Capital utilisation: daily position market value / total assets.
    if "market_value" in df.columns and "total_asset" in df.columns:
        df["capital_utilization"] = df["market_value"] / df["total_asset"]
        avg_capital_utilization = df["capital_utilization"].mean()
    else:
        avg_capital_utilization = 0.0

    # Trade frequency.
    total_trades = trade_count
    avg_trades_per_year = total_trades / years if years > 0 else 0.0

    # Win rate and profit/loss ratio from the trade history.
    win_rate, profit_loss_ratio, win_count, loss_count = _summarize_trades(trade_history)

    res = {
        "cum_return": float(cum_return),
        "ann_return": float(ann_return),
        "sharpe": float(sharpe),
        "max_drawdown": max_drawdown,
        "avg_capital_utilization": float(avg_capital_utilization),
        "total_trades": int(total_trades),
        "avg_trades_per_year": float(avg_trades_per_year),
        "backtest_years": float(years),
        "win_rate": float(win_rate),
        "profit_loss_ratio": float(profit_loss_ratio),
        "win_count": int(win_count),
        "loss_count": int(loss_count),
    }

    # Human-readable summary: one metric per line, percentages where relevant.
    logger.info("=" * 60)
    logger.info("回测绩效指标汇总")
    logger.info("=" * 60)
    logger.info(f"回测年数: {years:.2f} 年")
    logger.info(f"总交易次数: {total_trades} 次")
    logger.info(f"年平均交易次数: {avg_trades_per_year:.2f} 次/年")
    logger.info("-" * 60)
    logger.info(f"累计收益率: {cum_return * 100:+.2f}%")
    logger.info(f"年化收益率: {ann_return * 100:+.2f}%")
    logger.info(f"夏普比率: {sharpe:.4f}")
    logger.info(f"最大回撤: {max_drawdown * 100:.2f}%")
    logger.info(f"平均资金利用率: {avg_capital_utilization * 100:.2f}%")
    logger.info("-" * 60)
    logger.info(f"胜率: {win_rate * 100:.2f}% ({win_count}胜 / {loss_count}败)")
    if profit_loss_ratio == float("inf"):
        logger.info("平均盈亏比: ∞ (无亏损交易)")
    else:
        logger.info(f"平均盈亏比: {profit_loss_ratio:.2f}")
    logger.info("=" * 60)

    return res
|
||||
99
utils/plotter.py
Normal file
99
utils/plotter.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""资金曲线绘图模块。
|
||||
|
||||
从资金曲线 CSV 生成 PNG 图表,保存到 results/plots 目录。
|
||||
支持策略与基准指数的对比图。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
from utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def plot_equity_curve(
    csv_path: str,
    output_path: str,
    benchmark_df: Optional[pd.DataFrame] = None,
    benchmark_name: str = "基准指数",
) -> None:
    """Plot the normalised equity curve stored in a CSV file and save it as an image.

    Args:
        csv_path: Equity-curve CSV with ``trade_date`` and ``total_asset``
            columns.
        output_path: Destination image path; its parent directory is created
            when the path has one.
        benchmark_df: Optional benchmark data with ``trade_date`` and
            ``benchmark_value`` columns. Rows are matched to the strategy by
            ``trade_date`` and drawn at the strategy's x positions.
        benchmark_name: Legend label for the benchmark curve.

    Returns:
        None. A warning is logged and nothing is drawn when the CSV is
        missing or empty.
    """
    if not os.path.exists(csv_path):
        logger.warning(f"资金曲线文件不存在: {csv_path}")
        return

    # A zero-byte file would make pd.read_csv raise; report it as empty instead.
    if os.path.getsize(csv_path) == 0:
        logger.warning("资金曲线文件为空,无法绘图")
        return

    df = pd.read_csv(csv_path)
    if df.empty:
        logger.warning("资金曲线文件为空,无法绘图")
        return

    df = df.sort_values("trade_date").reset_index(drop=True)

    # Normalise the strategy curve so the first row equals 1.
    initial_asset = df["total_asset"].iloc[0]
    df["strategy_value"] = df["total_asset"] / initial_asset

    # os.makedirs("") raises, so only create a directory when the output path
    # actually has a directory component.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Fonts able to render the Chinese labels.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS']
    plt.rcParams['axes.unicode_minus'] = False

    fig = plt.figure(figsize=(14, 7))
    try:
        # Strategy curve, one x position per trading day.
        plt.plot(range(len(df)), df["strategy_value"], label="策略收益", linewidth=2, color='blue')

        if benchmark_df is not None and not benchmark_df.empty:
            # Align key dtypes (both to str) before merging.
            df["trade_date"] = df["trade_date"].astype(str)
            benchmark_df_copy = benchmark_df.copy()
            benchmark_df_copy["trade_date"] = benchmark_df_copy["trade_date"].astype(str)

            # Carry each strategy row's x position through the merge so the
            # benchmark stays aligned with the strategy curve even when some
            # trading days are missing from the benchmark.
            strategy_part = df[["trade_date", "strategy_value"]].copy()
            strategy_part["plot_pos"] = range(len(strategy_part))

            merged = pd.merge(
                strategy_part,
                benchmark_df_copy[["trade_date", "benchmark_value"]],
                on="trade_date",
                how="inner",
            )

            if not merged.empty:
                plt.plot(
                    merged["plot_pos"],
                    merged["benchmark_value"],
                    label=benchmark_name,
                    linewidth=2,
                    color='red',
                    alpha=0.7,
                )

                # Excess return on the last common trading day.
                merged["excess_return"] = merged["strategy_value"] - merged["benchmark_value"]
                final_excess = merged["excess_return"].iloc[-1]

                logger.info(f"策略相对于{benchmark_name}的超额收益: {final_excess*100:+.2f}%")

        plt.xlabel("交易日", fontsize=12)
        plt.ylabel("收益率(归一化)", fontsize=12)
        plt.title("策略收益曲线 vs 基准指数", fontsize=14, fontweight='bold')
        plt.grid(True, alpha=0.3)
        plt.legend(fontsize=11, loc='upper left')
        plt.tight_layout()

        plt.savefig(output_path, dpi=150)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

    logger.info(f"资金曲线图保存到 {output_path}")
|
||||
Reference in New Issue
Block a user