Files
backtrader/data/local_data.py
2026-01-17 21:21:30 +08:00

420 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
本地数据源实现
从本地文件系统读取股票数据
"""
import os
import pandas as pd
import logging
from datetime import datetime
from .data_fetcher import BaseDataFetcher
class LocalDataFetcher(BaseDataFetcher):
"""
本地数据源实现
从本地CSV或Excel文件读取股票历史数据
"""
def __init__(self, data_dir='./data'):
"""
初始化本地数据获取器
Args:
data_dir: 数据文件存储目录
"""
super().__init__()
self.data_dir = data_dir
self.logger = logging.getLogger('local_data')
# 确保数据目录存在
if not os.path.exists(self.data_dir):
os.makedirs(self.data_dir)
self.logger.info(f'创建数据目录: {self.data_dir}')
self.logger.info(f'本地数据获取器初始化完成,数据目录: {self.data_dir}')
def get_stock_data(self, stock_code, start_date, end_date):
"""
从本地文件读取股票历史数据
Args:
stock_code: 股票代码
start_date: 开始日期 (格式: YYYYMMDD)
end_date: 结束日期 (格式: YYYYMMDD)
Returns:
pandas.DataFrame: 股票数据
"""
# 生成缓存键
cache_key = f"stock_{stock_code}_{start_date}_{end_date}"
# 检查缓存
cached_data = self._get_cached_data(cache_key)
if cached_data is not None:
self.logger.info(f'从缓存获取 {stock_code} 数据')
return cached_data
try:
# 构建文件路径
file_path = self._get_file_path(stock_code)
if not os.path.exists(file_path):
self.logger.warning(f'未找到 {stock_code} 的本地数据文件: {file_path}')
# 如果没有本地文件,可以返回一个示例数据用于测试
return self._get_sample_data(stock_code, start_date, end_date)
# 读取数据文件
if file_path.endswith('.csv'):
df = pd.read_csv(file_path)
elif file_path.endswith('.xlsx') or file_path.endswith('.xls'):
df = pd.read_excel(file_path)
elif file_path.endswith('.txt'):
# 读取制表符分隔的txt文件
df = pd.read_csv(
file_path,
delimiter='\t',
encoding='utf-8',
skiprows=1 # 跳过表头
)
# 设置列名(与文件实际格式匹配)
df.columns = ['ts_code', 'trade_date', 'open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'vol', 'amount']
# 重命名vol列为volume与其他数据源保持一致
df = df.rename(columns={'vol': 'volume'})
else:
raise ValueError(f'不支持的文件格式: {file_path}')
# 格式化数据
df = self._format_data(df)
# 过滤日期范围
start_dt = pd.to_datetime(start_date, format='%Y%m%d')
end_dt = pd.to_datetime(end_date, format='%Y%m%d')
df = df[(df.index >= start_dt) & (df.index <= end_dt)]
# 计算技术指标
df = self._calculate_technical_indicators(df)
# 缓存数据
self._cache_data(cache_key, df)
self.logger.info(f'成功读取 {stock_code} 数据,{start_date}{end_date},共 {len(df)} 条记录')
return df
except Exception as e:
self.logger.error(f'获取 {stock_code} 数据失败: {str(e)}')
return self._get_sample_data(stock_code, start_date, end_date)
def get_index_data(self, index_code, start_date, end_date):
"""
从本地文件读取指数历史数据
Args:
index_code: 指数代码
start_date: 开始日期 (格式: YYYYMMDD)
end_date: 结束日期 (格式: YYYYMMDD)
Returns:
pandas.DataFrame: 指数数据
"""
# 生成缓存键
cache_key = f"index_{index_code}_{start_date}_{end_date}"
# 检查缓存
cached_data = self._get_cached_data(cache_key)
if cached_data is not None:
self.logger.info(f'从缓存获取 {index_code} 数据')
return cached_data
try:
# 构建文件路径 - 指数文件放在index子目录
file_path = self._get_file_path(index_code, is_index=True)
if not os.path.exists(file_path):
self.logger.warning(f'未找到 {index_code} 的本地数据文件: {file_path}')
# 返回示例数据
return self._get_sample_data(index_code, start_date, end_date, is_index=True)
self.logger.info(f'从本地文件读取 {index_code} 数据: {file_path}')
# 读取数据文件
if file_path.endswith('.csv'):
df = pd.read_csv(file_path)
elif file_path.endswith('.xlsx') or file_path.endswith('.xls'):
df = pd.read_excel(file_path)
elif file_path.endswith('.txt'):
# 读取制表符分隔的txt文件
df = pd.read_csv(
file_path,
delimiter='\t',
encoding='utf-8',
skiprows=1 # 跳过表头
)
# 设置列名(与文件实际格式匹配)
df.columns = ['ts_code', 'trade_date', 'open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'vol', 'amount']
# 重命名vol列为volume与其他数据源保持一致
df = df.rename(columns={'vol': 'volume'})
else:
raise ValueError(f'不支持的文件格式: {file_path}')
# 格式化数据
df = self._format_data(df)
# 过滤日期范围
start_dt = pd.to_datetime(start_date)
end_dt = pd.to_datetime(end_date)
df = df[(df.index >= start_dt) & (df.index <= end_dt)]
# 缓存数据
self._cache_data(cache_key, df)
return df
except Exception as e:
self.logger.error(f'读取 {index_code} 数据时出错: {str(e)}')
# 返回示例数据
return self._get_sample_data(index_code, start_date, end_date, is_index=True)
def get_stock_basic_info(self, stock_code):
"""
获取股票基本信息
Args:
stock_code: 股票代码
Returns:
dict: 股票基本信息
"""
# 返回基本的示例信息
return {
'ts_code': stock_code,
'symbol': stock_code.split('.')[0],
'name': f'股票{stock_code.split(".")[0]}',
'industry': '未知',
'market': stock_code.split('.')[1] if '.' in stock_code else 'SH',
'list_date': '20100101'
}
def get_stock_basic(self):
"""
获取全市场股票基本信息列表
Returns:
pandas.DataFrame: 全市场股票基本信息
"""
# 由于是本地数据源,我们可以从本地目录中读取所有股票文件,
# 然后生成股票基本信息列表
try:
# 获取目录中的所有数据文件
files = os.listdir(self.data_dir)
stock_files = [f for f in files if f.endswith('_daily_data.txt')]
# 提取股票代码
stock_codes = []
for file in stock_files:
# 文件名格式000002.SZ_daily_data.txt
ts_code = file.replace('_daily_data.txt', '')
stock_codes.append(ts_code)
# 如果没有找到股票文件,返回示例数据
if not stock_codes:
self.logger.warning('未找到本地股票数据文件,返回示例股票列表')
return self._get_sample_stock_basic()
# 生成股票基本信息DataFrame
stock_info_list = []
for ts_code in stock_codes:
stock_info = self.get_stock_basic_info(ts_code)
stock_info_list.append(stock_info)
df = pd.DataFrame(stock_info_list)
self.logger.info(f'从本地目录读取到 {len(df)} 只股票的基本信息')
return df
except Exception as e:
self.logger.error(f'获取全市场股票信息时出错: {str(e)}')
# 返回示例数据
return self._get_sample_stock_basic()
def _get_sample_stock_basic(self):
"""
生成示例股票基本信息(用于测试)
Returns:
pandas.DataFrame: 示例股票基本信息
"""
sample_stocks = [
{'ts_code': '000001.SH', 'symbol': '000001', 'name': '平安银行', 'industry': '银行', 'market': 'SH', 'list_date': '19910403'},
{'ts_code': '000002.SZ', 'symbol': '000002', 'name': '万科A', 'industry': '房地产', 'market': 'SZ', 'list_date': '19910129'},
{'ts_code': '000004.SZ', 'symbol': '000004', 'name': '国华网安', 'industry': '计算机', 'market': 'SZ', 'list_date': '19910114'},
{'ts_code': '000005.SZ', 'symbol': '000005', 'name': '世纪星源', 'industry': '房地产', 'market': 'SZ', 'list_date': '19901210'},
{'ts_code': '000006.SZ', 'symbol': '000006', 'name': '深振业A', 'industry': '房地产', 'market': 'SZ', 'list_date': '19920427'},
{'ts_code': '600000.SH', 'symbol': '600000', 'name': '浦发银行', 'industry': '银行', 'market': 'SH', 'list_date': '19991110'},
{'ts_code': '600004.SH', 'symbol': '600004', 'name': '白云机场', 'industry': '交通运输', 'market': 'SH', 'list_date': '20030428'},
{'ts_code': '600005.SH', 'symbol': '600005', 'name': '武钢股份', 'industry': '钢铁', 'market': 'SH', 'list_date': '19990803'},
{'ts_code': '600006.SH', 'symbol': '600006', 'name': '东风汽车', 'industry': '汽车', 'market': 'SH', 'list_date': '19990727'},
{'ts_code': '600007.SH', 'symbol': '600007', 'name': '中国国贸', 'industry': '房地产', 'market': 'SH', 'list_date': '19990312'}
]
df = pd.DataFrame(sample_stocks)
return df
def get_daily_data(self, ts_code, start_date, end_date, adjust=None):
"""
获取日线数据与TushareDataFetcher兼容
Args:
ts_code: 股票代码
start_date: 开始日期 (格式: YYYYMMDD)
end_date: 结束日期 (格式: YYYYMMDD)
adjust: 复权类型(已废弃,默认使用除权数据)
Returns:
pandas.DataFrame: 股票数据
"""
return self.get_stock_data(ts_code, start_date, end_date)
def _calculate_technical_indicators(self, df):
"""
计算技术指标与TushareDataFetcher兼容
Args:
df: 原始数据DataFrame
Returns:
pandas.DataFrame: 包含技术指标的数据
"""
if df.empty:
return df
# 计算均线
df['ma5'] = df['close'].rolling(5).mean()
df['ma10'] = df['close'].rolling(10).mean()
df['ma20'] = df['close'].rolling(20).mean()
df['ma60'] = df['close'].rolling(60).mean()
# 计算成交量均线
df['volume_ma5'] = df['volume'].rolling(5).mean()
df['volume_ratio'] = df['volume'] / df['volume_ma5']
# 计算涨跌幅
df['pct_change'] = df['close'].pct_change() * 100
# 计算K线实体比例
df['entity_ratio'] = (df['close'] - df['open']).abs() / (df['high'] - df['low'])
return df
def _save_to_cache(self, ts_code, df):
"""
保存数据到本地缓存与TushareDataFetcher兼容
Args:
ts_code: 股票代码
df: 数据DataFrame
"""
# 本地数据已经是缓存,不需要再保存
pass
def _get_file_path(self, code, is_index=False):
"""
获取数据文件路径
Args:
code: 股票或指数代码
is_index: 是否为指数
Returns:
str: 文件路径
"""
# 根据是否为指数选择不同的存储路径和文件命名格式
if is_index:
# 指数数据存储在D:\gp_data\index目录
data_dir = "D:\gp_data\index"
# 用户的指数数据文件格式为000001.SH.txt
file_name = f'{code}.txt'
else:
# 股票数据存储在配置的DATA_DIR目录
data_dir = self.data_dir
# 用户的股票数据文件格式为000002.SZ_daily_data.txt
file_name = f'{code}_daily_data.txt'
return os.path.join(data_dir, file_name)
def _get_sample_data(self, code, start_date, end_date, is_index=False):
"""
生成示例数据(用于测试)
Args:
code: 股票或指数代码
start_date: 开始日期
end_date: 结束日期
is_index: 是否为指数
Returns:
pandas.DataFrame: 示例数据
"""
self.logger.info(f'生成 {code} 的示例数据')
# 生成日期范围
start_dt = pd.to_datetime(start_date)
end_dt = pd.to_datetime(end_date)
date_range = pd.date_range(start=start_dt, end=end_dt)
# 生成随机价格数据
import numpy as np
np.random.seed(42) # 设置随机种子以保证可重复性
# 初始价格
base_price = 1000 if is_index else 10
# 生成价格序列
returns = np.random.normal(0, 0.02, len(date_range))
prices = base_price * np.exp(np.cumsum(returns))
# 生成开盘价、最高价、最低价
open_prices = prices * np.random.uniform(0.99, 1.01, len(date_range))
high_prices = np.maximum(prices, open_prices) * np.random.uniform(1.0, 1.02, len(date_range))
low_prices = np.minimum(prices, open_prices) * np.random.uniform(0.98, 1.0, len(date_range))
# 生成成交量
volumes = np.random.randint(1000000, 10000000, len(date_range))
amounts = volumes * prices
# 创建DataFrame
df = pd.DataFrame({
'date': date_range,
'open': open_prices,
'high': high_prices,
'low': low_prices,
'close': prices,
'volume': volumes,
'amount': amounts
})
# 格式化数据
df = self._format_data(df)
return df
def save_data_to_local(self, code, df, is_index=False):
"""
将数据保存到本地文件
Args:
code: 股票或指数代码
df: 数据DataFrame
is_index: 是否为指数
"""
try:
file_path = self._get_file_path(code, is_index)
df.to_csv(file_path)
self.logger.info(f'数据已保存到: {file_path}')
return True
except Exception as e:
self.logger.error(f'保存数据到本地时出错: {str(e)}')
return False