Files
strategy_backtest/utils/data_loader.py

83 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""行情数据加载模块。
唯一数据入口,负责从 data/day 目录加载 TXT 日线行情,并返回 pandas.DataFrame。
约束:
- 文件不存在或为空时,返回空 DataFrame不抛异常并输出 warning 日志;
- 解析失败时返回空 DataFrame并输出 error 日志;
- 日期统一升序排列;
- 首行必须包含 trade_date/open/high/low/close/vol 列,否则视为格式错误。
"""
from __future__ import annotations
import os
from typing import Optional
import pandas as pd
from utils.logger import setup_logger
logger = setup_logger(__name__)
REQUIRED_COLUMNS = ["trade_date", "open", "high", "low", "close", "vol"]
def load_single_stock(
data_dir: str,
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> pd.DataFrame:
"""加载单只股票日线数据。
参数:
- data_dir: 行情数据根目录,如 "data/day"
- ts_code: 股票代码(带后缀),如 "000001.SZ"
- start_date, end_date: 过滤区间,格式 YYYYMMDD可为 None。
返回:
- 若文件存在且格式正确,返回包含必要列的 DataFrame日期升序
- 若文件不存在、为空或格式错误,返回空 DataFrame不抛异常。
"""
filename = f"{ts_code}_daily_data.txt"
path = os.path.join(data_dir, filename)
if not os.path.exists(path) or os.path.getsize(path) == 0:
logger.warning(f"{filename} 不存在或为空,跳过")
return pd.DataFrame()
try:
# DAK 文件以制表符分隔UTF-8 编码
df = pd.read_csv(path, sep="\t", encoding="utf-8")
if df.empty:
logger.warning(f"{filename} 不存在或为空,跳过")
return pd.DataFrame()
# 列名统一为小写
df.columns = [c.strip().lower() for c in df.columns]
# 检查必需列
for col in REQUIRED_COLUMNS:
if col not in df.columns:
logger.error(f"{ts_code} 数据缺少必要列: {col}")
return pd.DataFrame()
# 确保 trade_date 列为字符串类型(源数据可能是 int 或 str
df["trade_date"] = df["trade_date"].astype(str)
# 按需过滤日期(统一为字符串比较,格式 YYYYMMDD
if start_date is not None:
df = df[df["trade_date"] >= start_date]
if end_date is not None:
df = df[df["trade_date"] <= end_date]
# 日期升序
df = df.sort_values("trade_date").reset_index(drop=True)
# 只在 DEBUG 级别记录成功日志,避免 log 文件过大
logger.debug(f"加载 {ts_code} 成功,共 {len(df)} 条记录")
return df
except Exception as e: # noqa: BLE001
logger.error(f"{ts_code} 数据解析失败: {e}")
return pd.DataFrame()