# -*- coding: utf-8 -*- """ 本地数据源实现 从本地文件系统读取股票数据 """ import os import pandas as pd import logging from datetime import datetime from .data_fetcher import BaseDataFetcher class LocalDataFetcher(BaseDataFetcher): """ 本地数据源实现 从本地CSV或Excel文件读取股票历史数据 """ def __init__(self, data_dir='./data'): """ 初始化本地数据获取器 Args: data_dir: 数据文件存储目录 """ super().__init__() self.data_dir = data_dir self.logger = logging.getLogger('local_data') # 确保数据目录存在 if not os.path.exists(self.data_dir): os.makedirs(self.data_dir) self.logger.info(f'创建数据目录: {self.data_dir}') self.logger.info(f'本地数据获取器初始化完成,数据目录: {self.data_dir}') def get_stock_data(self, stock_code, start_date, end_date): """ 从本地文件读取股票历史数据 Args: stock_code: 股票代码 start_date: 开始日期 (格式: YYYYMMDD) end_date: 结束日期 (格式: YYYYMMDD) Returns: pandas.DataFrame: 股票数据 """ # 生成缓存键 cache_key = f"stock_{stock_code}_{start_date}_{end_date}" # 检查缓存 cached_data = self._get_cached_data(cache_key) if cached_data is not None: self.logger.info(f'从缓存获取 {stock_code} 数据') return cached_data try: # 构建文件路径 file_path = self._get_file_path(stock_code) if not os.path.exists(file_path): self.logger.warning(f'未找到 {stock_code} 的本地数据文件: {file_path}') # 如果没有本地文件,可以返回一个示例数据用于测试 return self._get_sample_data(stock_code, start_date, end_date) # 读取数据文件 if file_path.endswith('.csv'): df = pd.read_csv(file_path) elif file_path.endswith('.xlsx') or file_path.endswith('.xls'): df = pd.read_excel(file_path) elif file_path.endswith('.txt'): # 读取制表符分隔的txt文件 df = pd.read_csv( file_path, delimiter='\t', encoding='utf-8', skiprows=1 # 跳过表头 ) # 设置列名(与文件实际格式匹配) df.columns = ['ts_code', 'trade_date', 'open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'vol', 'amount'] # 重命名vol列为volume,与其他数据源保持一致 df = df.rename(columns={'vol': 'volume'}) else: raise ValueError(f'不支持的文件格式: {file_path}') # 格式化数据 df = self._format_data(df) # 过滤日期范围 start_dt = pd.to_datetime(start_date, format='%Y%m%d') end_dt = pd.to_datetime(end_date, format='%Y%m%d') df = df[(df.index >= start_dt) & (df.index <= end_dt)] # 计算技术指标 df = self._calculate_technical_indicators(df) # 缓存数据 self._cache_data(cache_key, df) self.logger.info(f'成功读取 {stock_code} 数据,{start_date} 至 {end_date},共 {len(df)} 条记录') return df except Exception as e: self.logger.error(f'获取 {stock_code} 数据失败: {str(e)}') return self._get_sample_data(stock_code, start_date, end_date) def get_index_data(self, index_code, start_date, end_date): """ 从本地文件读取指数历史数据 Args: index_code: 指数代码 start_date: 开始日期 (格式: YYYYMMDD) end_date: 结束日期 (格式: YYYYMMDD) Returns: pandas.DataFrame: 指数数据 """ # 生成缓存键 cache_key = f"index_{index_code}_{start_date}_{end_date}" # 检查缓存 cached_data = self._get_cached_data(cache_key) if cached_data is not None: self.logger.info(f'从缓存获取 {index_code} 数据') return cached_data try: # 构建文件路径 - 指数文件放在index子目录 file_path = self._get_file_path(index_code, is_index=True) if not os.path.exists(file_path): self.logger.warning(f'未找到 {index_code} 的本地数据文件: {file_path}') # 返回示例数据 return self._get_sample_data(index_code, start_date, end_date, is_index=True) self.logger.info(f'从本地文件读取 {index_code} 数据: {file_path}') # 读取数据文件 if file_path.endswith('.csv'): df = pd.read_csv(file_path) elif file_path.endswith('.xlsx') or file_path.endswith('.xls'): df = pd.read_excel(file_path) elif file_path.endswith('.txt'): # 读取制表符分隔的txt文件 df = pd.read_csv( file_path, delimiter='\t', encoding='utf-8', skiprows=1 # 跳过表头 ) # 设置列名(与文件实际格式匹配) df.columns = ['ts_code', 'trade_date', 'open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'vol', 'amount'] # 重命名vol列为volume,与其他数据源保持一致 df = df.rename(columns={'vol': 'volume'}) else: raise ValueError(f'不支持的文件格式: {file_path}') # 格式化数据 df = self._format_data(df) # 过滤日期范围 start_dt = pd.to_datetime(start_date) end_dt = pd.to_datetime(end_date) df = df[(df.index >= start_dt) & (df.index <= end_dt)] # 缓存数据 self._cache_data(cache_key, df) return df except Exception as e: self.logger.error(f'读取 {index_code} 数据时出错: {str(e)}') # 返回示例数据 return self._get_sample_data(index_code, start_date, end_date, is_index=True) def get_stock_basic_info(self, stock_code): """ 获取股票基本信息 Args: stock_code: 股票代码 Returns: dict: 股票基本信息 """ # 返回基本的示例信息 return { 'ts_code': stock_code, 'symbol': stock_code.split('.')[0], 'name': f'股票{stock_code.split(".")[0]}', 'industry': '未知', 'market': stock_code.split('.')[1] if '.' in stock_code else 'SH', 'list_date': '20100101' } def get_stock_basic(self): """ 获取全市场股票基本信息列表 Returns: pandas.DataFrame: 全市场股票基本信息 """ # 由于是本地数据源,我们可以从本地目录中读取所有股票文件, # 然后生成股票基本信息列表 try: # 获取目录中的所有数据文件 files = os.listdir(self.data_dir) stock_files = [f for f in files if f.endswith('_daily_data.txt')] # 提取股票代码 stock_codes = [] for file in stock_files: # 文件名格式:000002.SZ_daily_data.txt ts_code = file.replace('_daily_data.txt', '') stock_codes.append(ts_code) # 如果没有找到股票文件,返回示例数据 if not stock_codes: self.logger.warning('未找到本地股票数据文件,返回示例股票列表') return self._get_sample_stock_basic() # 生成股票基本信息DataFrame stock_info_list = [] for ts_code in stock_codes: stock_info = self.get_stock_basic_info(ts_code) stock_info_list.append(stock_info) df = pd.DataFrame(stock_info_list) self.logger.info(f'从本地目录读取到 {len(df)} 只股票的基本信息') return df except Exception as e: self.logger.error(f'获取全市场股票信息时出错: {str(e)}') # 返回示例数据 return self._get_sample_stock_basic() def _get_sample_stock_basic(self): """ 生成示例股票基本信息(用于测试) Returns: pandas.DataFrame: 示例股票基本信息 """ sample_stocks = [ {'ts_code': '000001.SH', 'symbol': '000001', 'name': '平安银行', 'industry': '银行', 'market': 'SH', 'list_date': '19910403'}, {'ts_code': '000002.SZ', 'symbol': '000002', 'name': '万科A', 'industry': '房地产', 'market': 'SZ', 'list_date': '19910129'}, {'ts_code': '000004.SZ', 'symbol': '000004', 'name': '国华网安', 'industry': '计算机', 'market': 'SZ', 'list_date': '19910114'}, {'ts_code': '000005.SZ', 'symbol': '000005', 'name': '世纪星源', 'industry': '房地产', 'market': 'SZ', 'list_date': '19901210'}, {'ts_code': '000006.SZ', 'symbol': '000006', 'name': '深振业A', 'industry': '房地产', 'market': 'SZ', 'list_date': '19920427'}, {'ts_code': '600000.SH', 'symbol': '600000', 'name': '浦发银行', 'industry': '银行', 'market': 'SH', 'list_date': '19991110'}, {'ts_code': '600004.SH', 'symbol': '600004', 'name': '白云机场', 'industry': '交通运输', 'market': 'SH', 'list_date': '20030428'}, {'ts_code': '600005.SH', 'symbol': '600005', 'name': '武钢股份', 'industry': '钢铁', 'market': 'SH', 'list_date': '19990803'}, {'ts_code': '600006.SH', 'symbol': '600006', 'name': '东风汽车', 'industry': '汽车', 'market': 'SH', 'list_date': '19990727'}, {'ts_code': '600007.SH', 'symbol': '600007', 'name': '中国国贸', 'industry': '房地产', 'market': 'SH', 'list_date': '19990312'} ] df = pd.DataFrame(sample_stocks) return df def get_daily_data(self, ts_code, start_date, end_date, adjust=None): """ 获取日线数据(与TushareDataFetcher兼容) Args: ts_code: 股票代码 start_date: 开始日期 (格式: YYYYMMDD) end_date: 结束日期 (格式: YYYYMMDD) adjust: 复权类型(已废弃,默认使用除权数据) Returns: pandas.DataFrame: 股票数据 """ return self.get_stock_data(ts_code, start_date, end_date) def _calculate_technical_indicators(self, df): """ 计算技术指标(与TushareDataFetcher兼容) Args: df: 原始数据DataFrame Returns: pandas.DataFrame: 包含技术指标的数据 """ if df.empty: return df # 计算均线 df['ma5'] = df['close'].rolling(5).mean() df['ma10'] = df['close'].rolling(10).mean() df['ma20'] = df['close'].rolling(20).mean() df['ma60'] = df['close'].rolling(60).mean() # 计算成交量均线 df['volume_ma5'] = df['volume'].rolling(5).mean() df['volume_ratio'] = df['volume'] / df['volume_ma5'] # 计算涨跌幅 df['pct_change'] = df['close'].pct_change() * 100 # 计算K线实体比例 df['entity_ratio'] = (df['close'] - df['open']).abs() / (df['high'] - df['low']) return df def _save_to_cache(self, ts_code, df): """ 保存数据到本地缓存(与TushareDataFetcher兼容) Args: ts_code: 股票代码 df: 数据DataFrame """ # 本地数据已经是缓存,不需要再保存 pass def _get_file_path(self, code, is_index=False): """ 获取数据文件路径 Args: code: 股票或指数代码 is_index: 是否为指数 Returns: str: 文件路径 """ # 根据是否为指数选择不同的存储路径和文件命名格式 if is_index: # 指数数据存储在D:\gp_data\index目录 data_dir = "D:\gp_data\index" # 用户的指数数据文件格式为:000001.SH.txt file_name = f'{code}.txt' else: # 股票数据存储在配置的DATA_DIR目录 data_dir = self.data_dir # 用户的股票数据文件格式为:000002.SZ_daily_data.txt file_name = f'{code}_daily_data.txt' return os.path.join(data_dir, file_name) def _get_sample_data(self, code, start_date, end_date, is_index=False): """ 生成示例数据(用于测试) Args: code: 股票或指数代码 start_date: 开始日期 end_date: 结束日期 is_index: 是否为指数 Returns: pandas.DataFrame: 示例数据 """ self.logger.info(f'生成 {code} 的示例数据') # 生成日期范围 start_dt = pd.to_datetime(start_date) end_dt = pd.to_datetime(end_date) date_range = pd.date_range(start=start_dt, end=end_dt) # 生成随机价格数据 import numpy as np np.random.seed(42) # 设置随机种子以保证可重复性 # 初始价格 base_price = 1000 if is_index else 10 # 生成价格序列 returns = np.random.normal(0, 0.02, len(date_range)) prices = base_price * np.exp(np.cumsum(returns)) # 生成开盘价、最高价、最低价 open_prices = prices * np.random.uniform(0.99, 1.01, len(date_range)) high_prices = np.maximum(prices, open_prices) * np.random.uniform(1.0, 1.02, len(date_range)) low_prices = np.minimum(prices, open_prices) * np.random.uniform(0.98, 1.0, len(date_range)) # 生成成交量 volumes = np.random.randint(1000000, 10000000, len(date_range)) amounts = volumes * prices # 创建DataFrame df = pd.DataFrame({ 'date': date_range, 'open': open_prices, 'high': high_prices, 'low': low_prices, 'close': prices, 'volume': volumes, 'amount': amounts }) # 格式化数据 df = self._format_data(df) return df def save_data_to_local(self, code, df, is_index=False): """ 将数据保存到本地文件 Args: code: 股票或指数代码 df: 数据DataFrame is_index: 是否为指数 """ try: file_path = self._get_file_path(code, is_index) df.to_csv(file_path) self.logger.info(f'数据已保存到: {file_path}') return True except Exception as e: self.logger.error(f'保存数据到本地时出错: {str(e)}') return False