83 lines
2.9 KiB
Python
83 lines
2.9 KiB
Python
"""行情数据加载模块。
|
||
|
||
唯一数据入口,负责从 data/day 目录加载 TXT 日线行情,并返回 pandas.DataFrame。
|
||
|
||
约束:
|
||
- 文件不存在或为空时,返回空 DataFrame,不抛异常,并输出 warning 日志;
|
||
- 解析失败时返回空 DataFrame,并输出 error 日志;
|
||
- 日期统一升序排列;
|
||
- 首行必须包含 trade_date/open/high/low/close/vol 列,否则视为格式错误。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
from typing import Optional
|
||
|
||
import pandas as pd
|
||
|
||
from utils.logger import setup_logger
|
||
|
||
logger = setup_logger(__name__)
|
||
|
||
REQUIRED_COLUMNS = ["trade_date", "open", "high", "low", "close", "vol"]
|
||
|
||
|
||
def load_single_stock(
|
||
data_dir: str,
|
||
ts_code: str,
|
||
start_date: Optional[str] = None,
|
||
end_date: Optional[str] = None,
|
||
) -> pd.DataFrame:
|
||
"""加载单只股票日线数据。
|
||
|
||
参数:
|
||
- data_dir: 行情数据根目录,如 "data/day";
|
||
- ts_code: 股票代码(带后缀),如 "000001.SZ";
|
||
- start_date, end_date: 过滤区间,格式 YYYYMMDD,可为 None。
|
||
|
||
返回:
|
||
- 若文件存在且格式正确,返回包含必要列的 DataFrame,日期升序;
|
||
- 若文件不存在、为空或格式错误,返回空 DataFrame,不抛异常。
|
||
"""
|
||
filename = f"{ts_code}_daily_data.txt"
|
||
path = os.path.join(data_dir, filename)
|
||
|
||
if not os.path.exists(path) or os.path.getsize(path) == 0:
|
||
logger.warning(f"{filename} 不存在或为空,跳过")
|
||
return pd.DataFrame()
|
||
|
||
try:
|
||
# DAK 文件以制表符分隔,UTF-8 编码
|
||
df = pd.read_csv(path, sep="\t", encoding="utf-8")
|
||
if df.empty:
|
||
logger.warning(f"{filename} 不存在或为空,跳过")
|
||
return pd.DataFrame()
|
||
|
||
# 列名统一为小写
|
||
df.columns = [c.strip().lower() for c in df.columns]
|
||
|
||
# 检查必需列
|
||
for col in REQUIRED_COLUMNS:
|
||
if col not in df.columns:
|
||
logger.error(f"{ts_code} 数据缺少必要列: {col}")
|
||
return pd.DataFrame()
|
||
|
||
# 确保 trade_date 列为字符串类型(源数据可能是 int 或 str)
|
||
df["trade_date"] = df["trade_date"].astype(str)
|
||
|
||
# 按需过滤日期(统一为字符串比较,格式 YYYYMMDD)
|
||
if start_date is not None:
|
||
df = df[df["trade_date"] >= start_date]
|
||
if end_date is not None:
|
||
df = df[df["trade_date"] <= end_date]
|
||
|
||
# 日期升序
|
||
df = df.sort_values("trade_date").reset_index(drop=True)
|
||
|
||
# 只在 DEBUG 级别记录成功日志,避免 log 文件过大
|
||
logger.debug(f"加载 {ts_code} 成功,共 {len(df)} 条记录")
|
||
return df
|
||
except Exception as e: # noqa: BLE001
|
||
logger.error(f"{ts_code} 数据解析失败: {e}")
|
||
return pd.DataFrame()
|