249 lines
10 KiB
Python
249 lines
10 KiB
Python
"""
|
||
行情数据管理模块 - 支持多数据源(Tushare/AKShare)
|
||
"""
|
||
import pandas as pd
|
||
import tushare as ts
|
||
import akshare as ak
|
||
from typing import List, Dict, Optional # 添加 Optional 导入
|
||
import logging
|
||
import time
|
||
from enum import Enum, auto
|
||
import threading
|
||
|
||
# 添加 Tushare Token 设置
|
||
ts.set_token('9343e641869058684afeadfcfe7fd6684160852e52e85332a7734c8d')
|
||
|
||
# 通用股票排除规则
|
||
STOCK_EXCLUSION_RULES = {
|
||
'exclude_st': True, # 排除ST/*ST股票
|
||
'exclude_b_share': True, # 排除B股
|
||
'exclude_star_market': True, # 排除科创板(688开头)
|
||
'exclude_gem': False, # 排除创业板(300开头) - 默认不排除
|
||
'exclude_bj': True, # 排除北交所股票
|
||
'custom_exclusions': [] # 自定义排除列表
|
||
}
|
||
|
||
class DataSource(Enum):
|
||
TUSHARE = "tushare"
|
||
AKSHARE = "akshare"
|
||
LOCAL = "local"
|
||
|
||
class QuoteManager:
|
||
_instance = None
|
||
_lock = threading.Lock()
|
||
|
||
def __new__(cls):
|
||
with cls._lock:
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
cls._instance._init_manager()
|
||
return cls._instance
|
||
|
||
def _init_manager(self):
|
||
self._cache = {}
|
||
self._cache_ttl = 60
|
||
self._data_source = DataSource.TUSHARE
|
||
self.max_retries = 3 # 默认最大重试次数
|
||
self.retry_interval = 2 # 默认重试间隔(秒)
|
||
|
||
def set_retry_policy(self, max_retries: int, retry_interval: float = 2):
|
||
"""设置重试策略"""
|
||
self.max_retries = max_retries
|
||
self.retry_interval = retry_interval
|
||
|
||
def set_data_source(self, source: DataSource):
|
||
"""设置数据源"""
|
||
self._data_source = source
|
||
|
||
def get_realtime_quotes(self, codes: List[str]) -> Dict[str, pd.DataFrame]:
|
||
"""获取实时行情(带重试机制)"""
|
||
last_error = None
|
||
for attempt in range(self.max_retries):
|
||
try:
|
||
if self._data_source == DataSource.TUSHARE:
|
||
return self._get_tushare_quotes(codes)
|
||
elif self._data_source == DataSource.AKSHARE:
|
||
return self._get_akshare_quotes(codes)
|
||
else:
|
||
raise ValueError("不支持的数据源")
|
||
except Exception as e:
|
||
last_error = e
|
||
if attempt < self.max_retries - 1: # 不是最后一次尝试
|
||
time.sleep(self.retry_interval)
|
||
continue
|
||
raise Exception(f"获取行情失败(尝试{self.max_retries}次): {str(last_error)}")
|
||
|
||
def _get_tushare_quotes(self, codes: List[str], max_retries: int = 3, retry_interval: float = 2) -> Dict[str, pd.DataFrame]:
|
||
"""使用 Tushare 获取实时行情(带重试机制)"""
|
||
for attempt in range(max_retries):
|
||
try:
|
||
df = ts.realtime_quote(ts_code=','.join(codes))
|
||
if df is None or df.empty:
|
||
raise Exception("返回数据为空")
|
||
return {row['TS_CODE']: row for _, row in df.iterrows()}
|
||
except Exception as e:
|
||
if attempt < max_retries - 1:
|
||
time.sleep(retry_interval)
|
||
continue
|
||
raise Exception(f"Tushare 行情获取失败(尝试{max_retries}次): {str(e)}")
|
||
|
||
def _get_akshare_quotes(self, codes: List[str]) -> Dict[str, pd.DataFrame]:
|
||
"""使用 AKShare 获取实时行情"""
|
||
# 这里需要实现 AKShare 的获取逻辑
|
||
raise NotImplementedError("AKShare 实现待完成")
|
||
|
||
def _convert_akshare_format(self, row) -> Dict:
|
||
"""将AKShare数据格式转换为统一格式"""
|
||
return {
|
||
'TS_CODE': row['代码'],
|
||
'PRICE': row['最新价'],
|
||
'OPEN': row['今开'],
|
||
'PRE_CLOSE': row['昨收'],
|
||
'HIGH': row['最高'],
|
||
'LOW': row['最低'],
|
||
'VOLUME': row['成交量']
|
||
}
|
||
|
||
# ... existing code ...
|
||
def _get_tushare_all_stocks(self) -> List[str]:
|
||
"""使用Tushare获取所有A股股票列表"""
|
||
try:
|
||
pro = ts.pro_api() # 获取Tushare专业版API接口
|
||
# 获取所有股票列表
|
||
stock_basic = pro.stock_basic(exchange='', list_status='L',
|
||
fields='ts_code,symbol,name,area,industry,list_date')
|
||
# 根据通用规则过滤股票
|
||
filtered_stocks = self._filter_stocks(stock_basic['ts_code'].tolist())
|
||
return filtered_stocks
|
||
except Exception as e:
|
||
logging.error(f"获取股票列表失败: {str(e)}")
|
||
return []
|
||
|
||
def _filter_stocks(self, stock_list: List[str]) -> List[str]:
|
||
"""
|
||
根据通用规则过滤股票列表
|
||
:param stock_list: 原始股票列表
|
||
:return: 过滤后的股票列表
|
||
"""
|
||
filtered_stocks = []
|
||
|
||
for stock in stock_list:
|
||
exclude = False
|
||
|
||
# 检查是否在自定义排除列表中
|
||
if stock in STOCK_EXCLUSION_RULES.get('custom_exclusions', []):
|
||
exclude = True
|
||
|
||
# 检查是否排除ST/*ST股票
|
||
if STOCK_EXCLUSION_RULES.get('exclude_st', True):
|
||
# 注意:这里需要获取股票名称来判断是否为ST股票
|
||
# 在实际应用中,您可能需要通过其他方式获取股票名称
|
||
pass # ST股票的判断需要额外的数据支持
|
||
|
||
# 检查是否排除B股
|
||
if STOCK_EXCLUSION_RULES.get('exclude_b_share', True) and stock.endswith('.BJ'):
|
||
exclude = True
|
||
|
||
# 检查是否排除科创板股票(688开头)
|
||
if STOCK_EXCLUSION_RULES.get('exclude_star_market', True) and stock.startswith('688'):
|
||
exclude = True
|
||
|
||
# 检查是否排除创业板股票(300开头)
|
||
if STOCK_EXCLUSION_RULES.get('exclude_gem', False) and stock.startswith('300'):
|
||
exclude = True
|
||
|
||
# 检查是否排除北交所股票
|
||
if STOCK_EXCLUSION_RULES.get('exclude_bj', True) and stock.endswith('.BJ'):
|
||
exclude = True
|
||
|
||
# 如果没有被排除,则添加到结果列表中
|
||
if not exclude:
|
||
filtered_stocks.append(stock)
|
||
|
||
return filtered_stocks
|
||
|
||
def set_exclusion_rules(self, rules: Dict):
|
||
"""
|
||
设置股票排除规则
|
||
:param rules: 排除规则字典
|
||
"""
|
||
global STOCK_EXCLUSION_RULES
|
||
STOCK_EXCLUSION_RULES.update(rules)
|
||
|
||
def get_exclusion_rules(self) -> Dict:
|
||
"""
|
||
获取当前的股票排除规则
|
||
:return: 排除规则字典
|
||
"""
|
||
global STOCK_EXCLUSION_RULES
|
||
return STOCK_EXCLUSION_RULES.copy()
|
||
|
||
def get_quote(self, code: str) -> Optional[Dict]:
|
||
"""获取单个股票行情(兼容旧接口)"""
|
||
try:
|
||
quotes = self.get_realtime_quotes([code])
|
||
if not quotes or code not in quotes:
|
||
return None
|
||
row = quotes[code]
|
||
return {
|
||
'price': row['PRICE'],
|
||
'avg_price': (row['OPEN'] + row['PRE_CLOSE']) / 2,
|
||
'volume': row['VOLUME']
|
||
}
|
||
except Exception as e:
|
||
logging.error(f"获取股票{code}行情失败: {str(e)}")
|
||
return None
|
||
|
||
def get_daily_data(self, codes: List[str], start_date: str = None, end_date: str = None,
|
||
max_retries: Optional[int] = None, retry_interval: Optional[float] = None) -> Dict[
|
||
str, pd.DataFrame]:
|
||
"""获取股票日线数据"""
|
||
# 如果没有传入参数,则使用实例的默认值
|
||
max_retries = max_retries if max_retries is not None else self.max_retries
|
||
retry_interval = retry_interval if retry_interval is not None else self.retry_interval
|
||
|
||
try:
|
||
if self._data_source == DataSource.TUSHARE:
|
||
return self._get_tushare_daily_data(codes, start_date, end_date, max_retries, retry_interval)
|
||
elif self._data_source == DataSource.AKSHARE:
|
||
return self._get_akshare_daily_data(codes, start_date, end_date, max_retries, retry_interval)
|
||
else:
|
||
raise ValueError("不支持的数据源")
|
||
except Exception as e:
|
||
logging.error(f"获取日线数据失败: {str(e)}")
|
||
return {}
|
||
|
||
def _get_tushare_daily_data(self, codes: List[str], start_date: str = None, end_date: str = None,
|
||
max_retries: int = 3, retry_interval: float = 2) -> Dict[str, pd.DataFrame]:
|
||
"""使用Tushare获取日线数据"""
|
||
daily_data = {}
|
||
pro = ts.pro_api() # 获取Tushare专业版API接口
|
||
for code in codes:
|
||
for attempt in range(max_retries):
|
||
try:
|
||
# 如果没有指定日期范围,默认获取最近30天的数据
|
||
if not end_date:
|
||
end_date = pd.Timestamp.now().strftime('%Y%m%d')
|
||
if not start_date:
|
||
start_date = (pd.Timestamp.now() - pd.Timedelta(days=30)).strftime('%Y%m%d')
|
||
|
||
df = pro.daily(ts_code=code, start_date=start_date, end_date=end_date)
|
||
if df is not None and not df.empty:
|
||
df = df.sort_values('trade_date')
|
||
daily_data[code] = df
|
||
break # 成功获取数据,跳出重试循环
|
||
except Exception as e:
|
||
if attempt < max_retries - 1:
|
||
time.sleep(retry_interval)
|
||
continue
|
||
logging.error(f"获取{code}日线数据失败: {str(e)}")
|
||
break
|
||
return daily_data
|
||
|
||
def _get_akshare_daily_data(self, codes: List[str], start_date: str = None, end_date: str = None,
|
||
max_retries: int = 3, retry_interval: float = 2) -> Dict[str, pd.DataFrame]:
|
||
"""使用AKShare获取日线数据"""
|
||
# AKShare实现待完成
|
||
raise NotImplementedError("AKShare日线数据获取待实现")
|
||
|