# backtrader/data/tushare_data.py
# Snapshot: 2026-01-17 21:21:30 +08:00, 145 lines, 4.7 KiB, Python

import tushare as ts
import pandas as pd
import numpy as np
import time
import os
from datetime import datetime, timedelta
from utils.logger import setup_logger
# Module-level logger, created by the project's setup_logger() helper
# (imported from utils.logger); the fetcher class uses it for status and
# error messages.
logger = setup_logger()
class TushareDataFetcher:
    """TuShare Pro data fetcher with simple call-rate limiting.

    Wraps ``tushare.pro_api()`` and spaces successive API calls at least
    ``call_interval`` seconds apart. Stock and index daily bars are returned
    as DataFrames indexed by ``trade_date`` in ascending (chronological)
    order, and are also pickled to a local cache directory (best effort).
    """

    # Fallback cache location; ``__init__`` sets a per-instance value.
    cache_dir = "./data/cache"

    def __init__(self, token, call_interval=0.1, cache_dir="./data/cache"):
        """Initialise the TuShare Pro client and verify connectivity.

        Args:
            token: TuShare Pro API token.
            call_interval: Minimum number of seconds between two API calls.
            cache_dir: Directory where fetched frames are pickled.

        Raises:
            Exception: Whatever the connectivity probe in
                ``_check_connection`` raised (it is logged first).
        """
        ts.set_token(token)
        self.pro = ts.pro_api()
        # -inf means "no call made yet", so the first call never sleeps.
        # (A literal 0 is not safe with time.monotonic(), whose reference
        # point is undefined.)
        self.last_call_time = float("-inf")
        self.call_interval = call_interval
        self.cache_dir = cache_dir
        self._check_connection()

    def _rate_limit(self):
        """Sleep as needed so API calls are at least ``call_interval`` s apart."""
        # Use a monotonic clock: time.time() follows NTP or manual clock
        # changes, and a backwards jump would otherwise produce a huge sleep.
        elapsed = time.monotonic() - self.last_call_time
        if elapsed < self.call_interval:
            time.sleep(self.call_interval - elapsed)
        self.last_call_time = time.monotonic()

    def _check_connection(self):
        """Probe the API with a small trade-calendar request.

        Raises:
            Exception: Whatever the TuShare client raised; it is logged first.
        """
        try:
            self._rate_limit()
            self.pro.trade_cal(exchange='SSE', start_date='20230101', end_date='20230105')
            logger.info("TuShare Pro连接成功")
        except Exception as e:
            logger.error(f"TuShare Pro连接失败: {e}")
            raise

    @staticmethod
    def _normalize_trade_date_index(df):
        """Return *df* indexed by its parsed ``trade_date`` column, oldest first.

        TuShare's daily endpoints list bars newest-first. Rolling windows,
        ``pct_change`` and backtrader feeds all require chronological order,
        so every fetched frame is sorted ascending here. The input frame is
        not modified.
        """
        df = df.assign(trade_date=pd.to_datetime(df['trade_date']))
        return df.set_index('trade_date').sort_index()

    def get_daily_data(self, ts_code, start_date, end_date, adjust=None):
        """Fetch unadjusted daily bars for one stock and add indicators.

        Args:
            ts_code: TuShare security code, e.g. ``'000001.SZ'``.
            start_date: First trade date as ``'YYYYMMDD'``.
            end_date: Last trade date as ``'YYYYMMDD'``.
            adjust: Deprecated and ignored; unadjusted prices are always used.

        Returns:
            DataFrame indexed by ``trade_date`` (ascending), holding the raw
            columns plus the indicator columns from
            ``_calculate_technical_indicators``. An empty DataFrame is
            returned when there is no data or the request fails (the failure
            is logged).
        """
        try:
            self._rate_limit()
            # Raw (unadjusted) daily bars.
            df = self.pro.daily(ts_code=ts_code,
                                start_date=start_date,
                                end_date=end_date)
            if df is None or df.empty:
                return pd.DataFrame()
            df = self._normalize_trade_date_index(df)
            df = self._calculate_technical_indicators(df)
            # Best effort: _save_to_cache logs filesystem errors and does not
            # raise them, so a cache problem never discards fetched data.
            self._save_to_cache(ts_code, df)
            return df
        except Exception as e:
            logger.error(f"获取{ts_code}数据失败: {e}")
            return pd.DataFrame()

    def _calculate_technical_indicators(self, df):
        """Add moving-average, volume and candle-shape columns to *df*.

        Added columns:
          * ``ma5``, ``ma10``, ``ma20``, ``ma60``: simple moving averages
            of ``close``;
          * ``volume_ma5``: 5-bar mean of ``vol``;
          * ``volume_ratio``: ``vol / volume_ma5``;
          * ``pct_change``: close-to-close change, in percent;
          * ``entity_ratio``: ``|close - open| / (high - low)``.

        Windows that are not yet full yield NaN. Every indicator looks
        backwards in time, so rows are sorted chronologically first when the
        index is not already ascending.
        """
        if df.empty:
            return df
        if not df.index.is_monotonic_increasing:
            df = df.sort_index()
        close = df['close']
        for window in (5, 10, 20, 60):
            df[f'ma{window}'] = close.rolling(window).mean()
        df['volume_ma5'] = df['vol'].rolling(5).mean()
        df['volume_ratio'] = df['vol'] / df['volume_ma5']
        df['pct_change'] = close.pct_change() * 100
        df['entity_ratio'] = (df['close'] - df['open']).abs() / (df['high'] - df['low'])
        return df

    def _save_to_cache(self, ts_code, df):
        """Pickle *df* to ``<cache_dir>/<ts_code>.pkl``, on a best-effort basis.

        Empty frames are skipped. Filesystem errors are logged as warnings
        and not raised, so a caching problem never hides freshly fetched data.
        """
        if df.empty:
            return
        cache_file = os.path.join(self.cache_dir, f"{ts_code}.pkl")
        try:
            os.makedirs(self.cache_dir, exist_ok=True)
            df.to_pickle(cache_file)
        except OSError as e:
            logger.warning(f"缓存{ts_code}数据失败: {e}")
            return
        logger.debug(f"已缓存{ts_code}数据到{cache_file}")

    def get_index_data(self, index_code='000001.SH', start_date=None, end_date=None):
        """Fetch daily bars for a market index.

        Args:
            index_code: TuShare index code; defaults to ``'000001.SH'``.
            start_date: First trade date as ``'YYYYMMDD'``, or None.
            end_date: Last trade date as ``'YYYYMMDD'``, or None.

        Returns:
            DataFrame indexed by ``trade_date`` (ascending). An empty
            DataFrame is returned when there is no data or the request fails
            (the failure is logged).
        """
        try:
            self._rate_limit()
            df = self.pro.index_daily(ts_code=index_code,
                                      start_date=start_date,
                                      end_date=end_date)
            if df is None or df.empty:
                return pd.DataFrame()
            df = self._normalize_trade_date_index(df)
            self._save_to_cache(index_code, df)
            return df
        except Exception as e:
            logger.error(f"获取{index_code}指数数据失败: {e}")
            return pd.DataFrame()

    def get_stock_basic(self):
        """Return basic information for all listed stocks (``list_status='L'``).

        Unlike the bar fetchers above, API errors propagate to the caller.
        """
        self._rate_limit()
        return self.pro.stock_basic(exchange='', list_status='L')