145 lines
4.7 KiB
Python
145 lines
4.7 KiB
Python
import tushare as ts
|
|
import pandas as pd
|
|
import numpy as np
|
|
import time
|
|
import os
|
|
from datetime import datetime, timedelta
|
|
from utils.logger import setup_logger
|
|
|
|
# Module-level logger used by TushareDataFetcher; its configuration comes
# from the project's utils.logger.setup_logger helper.
logger = setup_logger()
|
|
|
|
class TushareDataFetcher:
    """TuShare Pro data fetcher with client-side call-rate limiting.

    Every remote call made by this class goes through ``_rate_limit`` so
    consecutive API requests are spaced at least ``call_interval`` seconds
    apart.
    """

    def __init__(self, token, call_interval=0.1):
        """Initialise the TuShare Pro client and verify connectivity.

        Args:
            token: TuShare Pro API token.
            call_interval: Minimum spacing between API calls, in seconds.

        Raises:
            Exception: Whatever the connectivity probe raises; it is logged
                and re-raised by ``_check_connection``.
        """
        ts.set_token(token)
        self.pro = ts.pro_api()
        # -inf guarantees the very first call is never delayed, independent of
        # the (unspecified) reference point of time.monotonic().
        self.last_call_time = float("-inf")
        self.call_interval = call_interval
        self._check_connection()

    def _rate_limit(self):
        """Sleep just long enough to keep ``call_interval`` between calls.

        Uses ``time.monotonic`` instead of ``time.time``: a wall-clock step
        backwards (NTP correction, manual change) would otherwise make
        ``elapsed`` negative and trigger a correspondingly huge sleep.
        """
        elapsed = time.monotonic() - self.last_call_time
        if elapsed < self.call_interval:
            time.sleep(self.call_interval - elapsed)
        self.last_call_time = time.monotonic()

    def _check_connection(self):
        """Probe the API with a small trade-calendar query.

        Raises:
            Exception: Any error from the probe, after logging it.
        """
        try:
            self._rate_limit()
            # The result is discarded; the call only proves token + network work.
            self.pro.trade_cal(exchange='SSE', start_date='20230101', end_date='20230105')
            logger.info("TuShare Pro连接成功")
        except Exception as e:
            logger.error(f"TuShare Pro连接失败: {e}")
            raise

    @staticmethod
    def _prepare_frame(df):
        """Return *df* indexed by parsed ``trade_date``, oldest row first.

        TuShare's daily endpoints are observed to return rows newest-first
        (confirm against the API docs). Rolling windows and ``pct_change``
        need chronological order; otherwise moving averages would be built
        from later dates (look-ahead) and percentage changes would use the
        wrong base. Sorting is harmless if rows are already ascending.
        The input frame is not modified.
        """
        prepared = df.assign(trade_date=pd.to_datetime(df['trade_date']))
        return prepared.set_index('trade_date').sort_index()

    def get_daily_data(self, ts_code, start_date, end_date, adjust=None):
        """Fetch raw (unadjusted) daily bars for one security plus indicators.

        Args:
            ts_code: TuShare security code.
            start_date: First trade date, in the ``'YYYYMMDD'`` form used
                elsewhere in this class (e.g. ``'20230101'``).
            end_date: Last trade date, same form.
            adjust: Ignored; kept only for backward compatibility (price
                adjustment is no longer applied).

        Returns:
            DataFrame indexed by ``trade_date`` in ascending order with the API
            columns plus the columns added by
            ``_calculate_technical_indicators``. An empty DataFrame if the API
            returned no rows or any error occurred (the error is logged).
        """
        try:
            self._rate_limit()
            # pro.daily serves the raw, non-adjusted daily bars.
            df = self.pro.daily(ts_code=ts_code,
                                start_date=start_date,
                                end_date=end_date)
            if df.empty:
                return pd.DataFrame()

            df = self._prepare_frame(df)
            df = self._calculate_technical_indicators(df)
            self._save_to_cache(ts_code, df)
            return df
        except Exception as e:
            logger.error(f"获取{ts_code}数据失败: {e}")
            return pd.DataFrame()

    def _calculate_technical_indicators(self, df):
        """Add moving-average, volume, return and candle-body columns to *df*.

        Assumes rows are in ascending date order (see ``_prepare_frame``).
        Columns are added to *df* in place and the same frame is returned.

        Added columns:
            ma5, ma10, ma20, ma60: rolling mean of ``close`` over that many
                rows (NaN until the window is full).
            volume_ma5: 5-row rolling mean of ``vol``.
            volume_ratio: ``vol / volume_ma5``.
            pct_change: row-over-row change of ``close``, in percent.
            entity_ratio: ``|close - open| / (high - low)``, the candle body
                as a fraction of the bar's range (undefined when
                ``high == low``).
        """
        if df.empty:
            return df

        close = df['close']
        for window in (5, 10, 20, 60):
            df[f'ma{window}'] = close.rolling(window).mean()

        df['volume_ma5'] = df['vol'].rolling(5).mean()
        df['volume_ratio'] = df['vol'] / df['volume_ma5']

        df['pct_change'] = close.pct_change() * 100

        df['entity_ratio'] = (close - df['open']).abs() / (df['high'] - df['low'])

        return df

    def _save_to_cache(self, ts_code, df, cache_dir="./data/cache"):
        """Pickle *df* to ``<cache_dir>/<ts_code>.pkl`` on a best-effort basis.

        Empty frames are skipped. Filesystem errors are logged as a warning
        instead of raised, so a read-only or full disk does not make the
        caller discard data it has already fetched.

        Args:
            ts_code: Code used as the file stem.
            df: Frame to persist.
            cache_dir: Target directory; created if missing.
        """
        if df.empty:
            return

        cache_file = os.path.join(cache_dir, f"{ts_code}.pkl")
        try:
            os.makedirs(cache_dir, exist_ok=True)
            df.to_pickle(cache_file)
        except OSError as e:
            logger.warning(f"缓存{ts_code}数据失败: {e}")
            return
        logger.debug(f"已缓存{ts_code}数据到{cache_file}")

    def get_index_data(self, index_code='000001.SH', start_date=None, end_date=None):
        """Fetch daily bars for an index (no indicator columns are added).

        Args:
            index_code: TuShare index code; defaults to ``'000001.SH'``.
            start_date: Optional first trade date (``'YYYYMMDD'``); ``None``
                is forwarded to the API unchanged.
            end_date: Optional last trade date, same form.

        Returns:
            DataFrame indexed by ``trade_date`` in ascending order, or an empty
            DataFrame if the API returned no rows or any error occurred (the
            error is logged).
        """
        try:
            self._rate_limit()
            df = self.pro.index_daily(ts_code=index_code,
                                      start_date=start_date,
                                      end_date=end_date)
            if df.empty:
                return pd.DataFrame()

            df = self._prepare_frame(df)
            self._save_to_cache(index_code, df)
            return df
        except Exception as e:
            logger.error(f"获取{index_code}指数数据失败: {e}")
            return pd.DataFrame()

    def get_stock_basic(self):
        """Return TuShare's ``stock_basic`` table for all exchanges.

        Queries with ``exchange=''`` and ``list_status='L'``. Unlike the other
        getters, API errors are propagated to the caller rather than logged
        and swallowed.
        """
        self._rate_limit()
        return self.pro.stock_basic(exchange='', list_status='L')