145 lines
4.0 KiB
Python
145 lines
4.0 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
数据获取基类
|
||
定义所有数据源实现需要遵循的接口
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
import pandas as pd
|
||
|
||
|
||
class BaseDataFetcher(ABC):
|
||
"""
|
||
数据获取基类
|
||
所有数据源实现需要继承此类并实现抽象方法
|
||
"""
|
||
|
||
def __init__(self):
|
||
"""初始化数据获取器"""
|
||
self.data_cache = {}
|
||
|
||
@abstractmethod
|
||
def get_stock_data(self, stock_code, start_date, end_date):
|
||
"""
|
||
获取股票历史数据
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
start_date: 开始日期 (格式: YYYYMMDD)
|
||
end_date: 结束日期 (格式: YYYYMMDD)
|
||
|
||
Returns:
|
||
pandas.DataFrame: 包含股票数据的DataFrame,列包括:
|
||
- date: 交易日期
|
||
- open: 开盘价
|
||
- high: 最高价
|
||
- low: 最低价
|
||
- close: 收盘价
|
||
- volume: 成交量
|
||
- amount: 成交额
|
||
"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
def get_index_data(self, index_code, start_date, end_date):
|
||
"""
|
||
获取指数历史数据
|
||
|
||
Args:
|
||
index_code: 指数代码
|
||
start_date: 开始日期 (格式: YYYYMMDD)
|
||
end_date: 结束日期 (格式: YYYYMMDD)
|
||
|
||
Returns:
|
||
pandas.DataFrame: 包含指数数据的DataFrame,列包括:
|
||
- date: 交易日期
|
||
- open: 开盘价
|
||
- high: 最高价
|
||
- low: 最低价
|
||
- close: 收盘价
|
||
- volume: 成交量
|
||
- amount: 成交额
|
||
"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
def get_stock_basic_info(self, stock_code):
|
||
"""
|
||
获取股票基本信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
dict: 股票基本信息
|
||
"""
|
||
pass
|
||
|
||
def _format_data(self, df):
|
||
"""
|
||
格式化数据为标准格式
|
||
|
||
Args:
|
||
df: 原始数据DataFrame
|
||
|
||
Returns:
|
||
pandas.DataFrame: 格式化后的数据
|
||
"""
|
||
# 转换日期列为日期类型
|
||
if 'date' in df.columns:
|
||
df['date'] = pd.to_datetime(df['date'])
|
||
elif 'trade_date' in df.columns:
|
||
# 尝试多种日期格式解析
|
||
if df['trade_date'].dtype == 'object':
|
||
# 首先尝试YYYYMMDD格式
|
||
try:
|
||
df['date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')
|
||
except ValueError:
|
||
# 如果失败,尝试其他格式
|
||
df['date'] = pd.to_datetime(df['trade_date'])
|
||
else:
|
||
# 处理数值类型的trade_date,先转为字符串避免被解析为时间戳
|
||
trade_date_str = df['trade_date'].astype(str)
|
||
try:
|
||
df['date'] = pd.to_datetime(trade_date_str, format='%Y%m%d')
|
||
except ValueError:
|
||
# 如果失败,尝试其他格式
|
||
df['date'] = pd.to_datetime(trade_date_str)
|
||
df = df.drop('trade_date', axis=1)
|
||
|
||
# 确保列名标准化
|
||
if 'vol' in df.columns:
|
||
df = df.rename(columns={'vol': 'volume'})
|
||
|
||
# 按日期排序
|
||
if 'date' in df.columns:
|
||
df = df.sort_values('date')
|
||
|
||
# 设置日期为索引
|
||
df = df.set_index('date')
|
||
|
||
return df
|
||
|
||
def _cache_data(self, key, data):
|
||
"""
|
||
缓存获取的数据
|
||
|
||
Args:
|
||
key: 缓存键
|
||
data: 要缓存的数据
|
||
"""
|
||
self.data_cache[key] = data
|
||
|
||
def _get_cached_data(self, key):
|
||
"""
|
||
获取缓存的数据
|
||
|
||
Args:
|
||
key: 缓存键
|
||
|
||
Returns:
|
||
缓存的数据,如果不存在返回None
|
||
"""
|
||
return self.data_cache.get(key)
|