420 lines
16 KiB
Python
420 lines
16 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
本地数据源实现
|
||
从本地文件系统读取股票数据
|
||
"""
|
||
|
||
import os
|
||
import pandas as pd
|
||
import logging
|
||
from datetime import datetime
|
||
|
||
from .data_fetcher import BaseDataFetcher
|
||
|
||
|
||
class LocalDataFetcher(BaseDataFetcher):
|
||
"""
|
||
本地数据源实现
|
||
从本地CSV或Excel文件读取股票历史数据
|
||
"""
|
||
|
||
def __init__(self, data_dir='./data'):
|
||
"""
|
||
初始化本地数据获取器
|
||
|
||
Args:
|
||
data_dir: 数据文件存储目录
|
||
"""
|
||
super().__init__()
|
||
self.data_dir = data_dir
|
||
self.logger = logging.getLogger('local_data')
|
||
|
||
# 确保数据目录存在
|
||
if not os.path.exists(self.data_dir):
|
||
os.makedirs(self.data_dir)
|
||
self.logger.info(f'创建数据目录: {self.data_dir}')
|
||
|
||
self.logger.info(f'本地数据获取器初始化完成,数据目录: {self.data_dir}')
|
||
|
||
def get_stock_data(self, stock_code, start_date, end_date):
|
||
"""
|
||
从本地文件读取股票历史数据
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
start_date: 开始日期 (格式: YYYYMMDD)
|
||
end_date: 结束日期 (格式: YYYYMMDD)
|
||
|
||
Returns:
|
||
pandas.DataFrame: 股票数据
|
||
"""
|
||
# 生成缓存键
|
||
cache_key = f"stock_{stock_code}_{start_date}_{end_date}"
|
||
|
||
# 检查缓存
|
||
cached_data = self._get_cached_data(cache_key)
|
||
if cached_data is not None:
|
||
self.logger.info(f'从缓存获取 {stock_code} 数据')
|
||
return cached_data
|
||
|
||
try:
|
||
# 构建文件路径
|
||
file_path = self._get_file_path(stock_code)
|
||
|
||
if not os.path.exists(file_path):
|
||
self.logger.warning(f'未找到 {stock_code} 的本地数据文件: {file_path}')
|
||
# 如果没有本地文件,可以返回一个示例数据用于测试
|
||
return self._get_sample_data(stock_code, start_date, end_date)
|
||
|
||
# 读取数据文件
|
||
if file_path.endswith('.csv'):
|
||
df = pd.read_csv(file_path)
|
||
elif file_path.endswith('.xlsx') or file_path.endswith('.xls'):
|
||
df = pd.read_excel(file_path)
|
||
elif file_path.endswith('.txt'):
|
||
# 读取制表符分隔的txt文件
|
||
df = pd.read_csv(
|
||
file_path,
|
||
delimiter='\t',
|
||
encoding='utf-8',
|
||
skiprows=1 # 跳过表头
|
||
)
|
||
# 设置列名(与文件实际格式匹配)
|
||
df.columns = ['ts_code', 'trade_date', 'open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'vol', 'amount']
|
||
# 重命名vol列为volume,与其他数据源保持一致
|
||
df = df.rename(columns={'vol': 'volume'})
|
||
else:
|
||
raise ValueError(f'不支持的文件格式: {file_path}')
|
||
|
||
# 格式化数据
|
||
df = self._format_data(df)
|
||
|
||
# 过滤日期范围
|
||
start_dt = pd.to_datetime(start_date, format='%Y%m%d')
|
||
end_dt = pd.to_datetime(end_date, format='%Y%m%d')
|
||
df = df[(df.index >= start_dt) & (df.index <= end_dt)]
|
||
|
||
# 计算技术指标
|
||
df = self._calculate_technical_indicators(df)
|
||
|
||
# 缓存数据
|
||
self._cache_data(cache_key, df)
|
||
|
||
self.logger.info(f'成功读取 {stock_code} 数据,{start_date} 至 {end_date},共 {len(df)} 条记录')
|
||
return df
|
||
|
||
except Exception as e:
|
||
self.logger.error(f'获取 {stock_code} 数据失败: {str(e)}')
|
||
return self._get_sample_data(stock_code, start_date, end_date)
|
||
|
||
def get_index_data(self, index_code, start_date, end_date):
|
||
"""
|
||
从本地文件读取指数历史数据
|
||
|
||
Args:
|
||
index_code: 指数代码
|
||
start_date: 开始日期 (格式: YYYYMMDD)
|
||
end_date: 结束日期 (格式: YYYYMMDD)
|
||
|
||
Returns:
|
||
pandas.DataFrame: 指数数据
|
||
"""
|
||
# 生成缓存键
|
||
cache_key = f"index_{index_code}_{start_date}_{end_date}"
|
||
|
||
# 检查缓存
|
||
cached_data = self._get_cached_data(cache_key)
|
||
if cached_data is not None:
|
||
self.logger.info(f'从缓存获取 {index_code} 数据')
|
||
return cached_data
|
||
|
||
try:
|
||
# 构建文件路径 - 指数文件放在index子目录
|
||
file_path = self._get_file_path(index_code, is_index=True)
|
||
|
||
if not os.path.exists(file_path):
|
||
self.logger.warning(f'未找到 {index_code} 的本地数据文件: {file_path}')
|
||
# 返回示例数据
|
||
return self._get_sample_data(index_code, start_date, end_date, is_index=True)
|
||
|
||
self.logger.info(f'从本地文件读取 {index_code} 数据: {file_path}')
|
||
|
||
# 读取数据文件
|
||
if file_path.endswith('.csv'):
|
||
df = pd.read_csv(file_path)
|
||
elif file_path.endswith('.xlsx') or file_path.endswith('.xls'):
|
||
df = pd.read_excel(file_path)
|
||
elif file_path.endswith('.txt'):
|
||
# 读取制表符分隔的txt文件
|
||
df = pd.read_csv(
|
||
file_path,
|
||
delimiter='\t',
|
||
encoding='utf-8',
|
||
skiprows=1 # 跳过表头
|
||
)
|
||
# 设置列名(与文件实际格式匹配)
|
||
df.columns = ['ts_code', 'trade_date', 'open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'vol', 'amount']
|
||
# 重命名vol列为volume,与其他数据源保持一致
|
||
df = df.rename(columns={'vol': 'volume'})
|
||
else:
|
||
raise ValueError(f'不支持的文件格式: {file_path}')
|
||
|
||
# 格式化数据
|
||
df = self._format_data(df)
|
||
|
||
# 过滤日期范围
|
||
start_dt = pd.to_datetime(start_date)
|
||
end_dt = pd.to_datetime(end_date)
|
||
df = df[(df.index >= start_dt) & (df.index <= end_dt)]
|
||
|
||
# 缓存数据
|
||
self._cache_data(cache_key, df)
|
||
|
||
return df
|
||
|
||
except Exception as e:
|
||
self.logger.error(f'读取 {index_code} 数据时出错: {str(e)}')
|
||
# 返回示例数据
|
||
return self._get_sample_data(index_code, start_date, end_date, is_index=True)
|
||
|
||
def get_stock_basic_info(self, stock_code):
|
||
"""
|
||
获取股票基本信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
dict: 股票基本信息
|
||
"""
|
||
# 返回基本的示例信息
|
||
return {
|
||
'ts_code': stock_code,
|
||
'symbol': stock_code.split('.')[0],
|
||
'name': f'股票{stock_code.split(".")[0]}',
|
||
'industry': '未知',
|
||
'market': stock_code.split('.')[1] if '.' in stock_code else 'SH',
|
||
'list_date': '20100101'
|
||
}
|
||
|
||
def get_stock_basic(self):
|
||
"""
|
||
获取全市场股票基本信息列表
|
||
|
||
Returns:
|
||
pandas.DataFrame: 全市场股票基本信息
|
||
"""
|
||
# 由于是本地数据源,我们可以从本地目录中读取所有股票文件,
|
||
# 然后生成股票基本信息列表
|
||
try:
|
||
# 获取目录中的所有数据文件
|
||
files = os.listdir(self.data_dir)
|
||
stock_files = [f for f in files if f.endswith('_daily_data.txt')]
|
||
|
||
# 提取股票代码
|
||
stock_codes = []
|
||
for file in stock_files:
|
||
# 文件名格式:000002.SZ_daily_data.txt
|
||
ts_code = file.replace('_daily_data.txt', '')
|
||
stock_codes.append(ts_code)
|
||
|
||
# 如果没有找到股票文件,返回示例数据
|
||
if not stock_codes:
|
||
self.logger.warning('未找到本地股票数据文件,返回示例股票列表')
|
||
return self._get_sample_stock_basic()
|
||
|
||
# 生成股票基本信息DataFrame
|
||
stock_info_list = []
|
||
for ts_code in stock_codes:
|
||
stock_info = self.get_stock_basic_info(ts_code)
|
||
stock_info_list.append(stock_info)
|
||
|
||
df = pd.DataFrame(stock_info_list)
|
||
self.logger.info(f'从本地目录读取到 {len(df)} 只股票的基本信息')
|
||
return df
|
||
|
||
except Exception as e:
|
||
self.logger.error(f'获取全市场股票信息时出错: {str(e)}')
|
||
# 返回示例数据
|
||
return self._get_sample_stock_basic()
|
||
|
||
def _get_sample_stock_basic(self):
|
||
"""
|
||
生成示例股票基本信息(用于测试)
|
||
|
||
Returns:
|
||
pandas.DataFrame: 示例股票基本信息
|
||
"""
|
||
sample_stocks = [
|
||
{'ts_code': '000001.SH', 'symbol': '000001', 'name': '平安银行', 'industry': '银行', 'market': 'SH', 'list_date': '19910403'},
|
||
{'ts_code': '000002.SZ', 'symbol': '000002', 'name': '万科A', 'industry': '房地产', 'market': 'SZ', 'list_date': '19910129'},
|
||
{'ts_code': '000004.SZ', 'symbol': '000004', 'name': '国华网安', 'industry': '计算机', 'market': 'SZ', 'list_date': '19910114'},
|
||
{'ts_code': '000005.SZ', 'symbol': '000005', 'name': '世纪星源', 'industry': '房地产', 'market': 'SZ', 'list_date': '19901210'},
|
||
{'ts_code': '000006.SZ', 'symbol': '000006', 'name': '深振业A', 'industry': '房地产', 'market': 'SZ', 'list_date': '19920427'},
|
||
{'ts_code': '600000.SH', 'symbol': '600000', 'name': '浦发银行', 'industry': '银行', 'market': 'SH', 'list_date': '19991110'},
|
||
{'ts_code': '600004.SH', 'symbol': '600004', 'name': '白云机场', 'industry': '交通运输', 'market': 'SH', 'list_date': '20030428'},
|
||
{'ts_code': '600005.SH', 'symbol': '600005', 'name': '武钢股份', 'industry': '钢铁', 'market': 'SH', 'list_date': '19990803'},
|
||
{'ts_code': '600006.SH', 'symbol': '600006', 'name': '东风汽车', 'industry': '汽车', 'market': 'SH', 'list_date': '19990727'},
|
||
{'ts_code': '600007.SH', 'symbol': '600007', 'name': '中国国贸', 'industry': '房地产', 'market': 'SH', 'list_date': '19990312'}
|
||
]
|
||
|
||
df = pd.DataFrame(sample_stocks)
|
||
return df
|
||
|
||
def get_daily_data(self, ts_code, start_date, end_date, adjust=None):
|
||
"""
|
||
获取日线数据(与TushareDataFetcher兼容)
|
||
|
||
Args:
|
||
ts_code: 股票代码
|
||
start_date: 开始日期 (格式: YYYYMMDD)
|
||
end_date: 结束日期 (格式: YYYYMMDD)
|
||
adjust: 复权类型(已废弃,默认使用除权数据)
|
||
|
||
Returns:
|
||
pandas.DataFrame: 股票数据
|
||
"""
|
||
return self.get_stock_data(ts_code, start_date, end_date)
|
||
|
||
def _calculate_technical_indicators(self, df):
|
||
"""
|
||
计算技术指标(与TushareDataFetcher兼容)
|
||
|
||
Args:
|
||
df: 原始数据DataFrame
|
||
|
||
Returns:
|
||
pandas.DataFrame: 包含技术指标的数据
|
||
"""
|
||
if df.empty:
|
||
return df
|
||
|
||
# 计算均线
|
||
df['ma5'] = df['close'].rolling(5).mean()
|
||
df['ma10'] = df['close'].rolling(10).mean()
|
||
df['ma20'] = df['close'].rolling(20).mean()
|
||
df['ma60'] = df['close'].rolling(60).mean()
|
||
|
||
# 计算成交量均线
|
||
df['volume_ma5'] = df['volume'].rolling(5).mean()
|
||
df['volume_ratio'] = df['volume'] / df['volume_ma5']
|
||
|
||
# 计算涨跌幅
|
||
df['pct_change'] = df['close'].pct_change() * 100
|
||
|
||
# 计算K线实体比例
|
||
df['entity_ratio'] = (df['close'] - df['open']).abs() / (df['high'] - df['low'])
|
||
|
||
return df
|
||
|
||
def _save_to_cache(self, ts_code, df):
|
||
"""
|
||
保存数据到本地缓存(与TushareDataFetcher兼容)
|
||
|
||
Args:
|
||
ts_code: 股票代码
|
||
df: 数据DataFrame
|
||
"""
|
||
# 本地数据已经是缓存,不需要再保存
|
||
pass
|
||
|
||
def _get_file_path(self, code, is_index=False):
|
||
"""
|
||
获取数据文件路径
|
||
|
||
Args:
|
||
code: 股票或指数代码
|
||
is_index: 是否为指数
|
||
|
||
Returns:
|
||
str: 文件路径
|
||
"""
|
||
# 根据是否为指数选择不同的存储路径和文件命名格式
|
||
if is_index:
|
||
# 指数数据存储在D:\gp_data\index目录
|
||
data_dir = "D:\gp_data\index"
|
||
# 用户的指数数据文件格式为:000001.SH.txt
|
||
file_name = f'{code}.txt'
|
||
else:
|
||
# 股票数据存储在配置的DATA_DIR目录
|
||
data_dir = self.data_dir
|
||
# 用户的股票数据文件格式为:000002.SZ_daily_data.txt
|
||
file_name = f'{code}_daily_data.txt'
|
||
|
||
return os.path.join(data_dir, file_name)
|
||
|
||
def _get_sample_data(self, code, start_date, end_date, is_index=False):
|
||
"""
|
||
生成示例数据(用于测试)
|
||
|
||
Args:
|
||
code: 股票或指数代码
|
||
start_date: 开始日期
|
||
end_date: 结束日期
|
||
is_index: 是否为指数
|
||
|
||
Returns:
|
||
pandas.DataFrame: 示例数据
|
||
"""
|
||
self.logger.info(f'生成 {code} 的示例数据')
|
||
|
||
# 生成日期范围
|
||
start_dt = pd.to_datetime(start_date)
|
||
end_dt = pd.to_datetime(end_date)
|
||
date_range = pd.date_range(start=start_dt, end=end_dt)
|
||
|
||
# 生成随机价格数据
|
||
import numpy as np
|
||
np.random.seed(42) # 设置随机种子以保证可重复性
|
||
|
||
# 初始价格
|
||
base_price = 1000 if is_index else 10
|
||
|
||
# 生成价格序列
|
||
returns = np.random.normal(0, 0.02, len(date_range))
|
||
prices = base_price * np.exp(np.cumsum(returns))
|
||
|
||
# 生成开盘价、最高价、最低价
|
||
open_prices = prices * np.random.uniform(0.99, 1.01, len(date_range))
|
||
high_prices = np.maximum(prices, open_prices) * np.random.uniform(1.0, 1.02, len(date_range))
|
||
low_prices = np.minimum(prices, open_prices) * np.random.uniform(0.98, 1.0, len(date_range))
|
||
|
||
# 生成成交量
|
||
volumes = np.random.randint(1000000, 10000000, len(date_range))
|
||
amounts = volumes * prices
|
||
|
||
# 创建DataFrame
|
||
df = pd.DataFrame({
|
||
'date': date_range,
|
||
'open': open_prices,
|
||
'high': high_prices,
|
||
'low': low_prices,
|
||
'close': prices,
|
||
'volume': volumes,
|
||
'amount': amounts
|
||
})
|
||
|
||
# 格式化数据
|
||
df = self._format_data(df)
|
||
|
||
return df
|
||
|
||
def save_data_to_local(self, code, df, is_index=False):
|
||
"""
|
||
将数据保存到本地文件
|
||
|
||
Args:
|
||
code: 股票或指数代码
|
||
df: 数据DataFrame
|
||
is_index: 是否为指数
|
||
"""
|
||
try:
|
||
file_path = self._get_file_path(code, is_index)
|
||
df.to_csv(file_path)
|
||
self.logger.info(f'数据已保存到: {file_path}')
|
||
return True
|
||
except Exception as e:
|
||
self.logger.error(f'保存数据到本地时出错: {str(e)}')
|
||
return False
|