Files
backtrader/data/data_fetcher.py
2026-01-17 21:21:30 +08:00

145 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
数据获取基类
定义所有数据源实现需要遵循的接口
"""
from abc import ABC, abstractmethod
import pandas as pd
class BaseDataFetcher(ABC):
"""
数据获取基类
所有数据源实现需要继承此类并实现抽象方法
"""
def __init__(self):
"""初始化数据获取器"""
self.data_cache = {}
@abstractmethod
def get_stock_data(self, stock_code, start_date, end_date):
"""
获取股票历史数据
Args:
stock_code: 股票代码
start_date: 开始日期 (格式: YYYYMMDD)
end_date: 结束日期 (格式: YYYYMMDD)
Returns:
pandas.DataFrame: 包含股票数据的DataFrame列包括
- date: 交易日期
- open: 开盘价
- high: 最高价
- low: 最低价
- close: 收盘价
- volume: 成交量
- amount: 成交额
"""
pass
@abstractmethod
def get_index_data(self, index_code, start_date, end_date):
"""
获取指数历史数据
Args:
index_code: 指数代码
start_date: 开始日期 (格式: YYYYMMDD)
end_date: 结束日期 (格式: YYYYMMDD)
Returns:
pandas.DataFrame: 包含指数数据的DataFrame列包括
- date: 交易日期
- open: 开盘价
- high: 最高价
- low: 最低价
- close: 收盘价
- volume: 成交量
- amount: 成交额
"""
pass
@abstractmethod
def get_stock_basic_info(self, stock_code):
"""
获取股票基本信息
Args:
stock_code: 股票代码
Returns:
dict: 股票基本信息
"""
pass
def _format_data(self, df):
"""
格式化数据为标准格式
Args:
df: 原始数据DataFrame
Returns:
pandas.DataFrame: 格式化后的数据
"""
# 转换日期列为日期类型
if 'date' in df.columns:
df['date'] = pd.to_datetime(df['date'])
elif 'trade_date' in df.columns:
# 尝试多种日期格式解析
if df['trade_date'].dtype == 'object':
# 首先尝试YYYYMMDD格式
try:
df['date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')
except ValueError:
# 如果失败,尝试其他格式
df['date'] = pd.to_datetime(df['trade_date'])
else:
# 处理数值类型的trade_date先转为字符串避免被解析为时间戳
trade_date_str = df['trade_date'].astype(str)
try:
df['date'] = pd.to_datetime(trade_date_str, format='%Y%m%d')
except ValueError:
# 如果失败,尝试其他格式
df['date'] = pd.to_datetime(trade_date_str)
df = df.drop('trade_date', axis=1)
# 确保列名标准化
if 'vol' in df.columns:
df = df.rename(columns={'vol': 'volume'})
# 按日期排序
if 'date' in df.columns:
df = df.sort_values('date')
# 设置日期为索引
df = df.set_index('date')
return df
def _cache_data(self, key, data):
"""
缓存获取的数据
Args:
key: 缓存键
data: 要缓存的数据
"""
self.data_cache[key] = data
def _get_cached_data(self, key):
"""
获取缓存的数据
Args:
key: 缓存键
Returns:
缓存的数据如果不存在返回None
"""
return self.data_cache.get(key)