104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
"""
|
||
# 作 者:84028
|
||
# 时 间:2024/2/28 20:51
|
||
# tsdpsdk
|
||
"""
|
||
import time
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from typing import Union
|
||
|
||
import tgw
|
||
from ..model.tick import TsTick, TsTickIdx
|
||
from .convert import convert_tick_index, convert_tick_stock, get_tgw_type_and_code
|
||
from .login import login
|
||
from ... import get_token, set_token
|
||
from ...util.verify_token import verify_token
|
||
|
||
threadpool = ThreadPoolExecutor(5)
|
||
|
||
|
||
class TgwSubscribe(tgw.IPushSpi):
|
||
def __init__(self, username, password, host, port=8600) -> None:
|
||
super().__init__()
|
||
self.callback_map = {}
|
||
self.username = username
|
||
self.password = password
|
||
self.host = host
|
||
self.port = port
|
||
|
||
def process_data(self, data):
|
||
_type = data.get('variety_category')
|
||
item = None
|
||
if _type == tgw.VarietyCategory.kIndex:
|
||
item = convert_tick_index(data)
|
||
elif _type == tgw.VarietyCategory.kStock:
|
||
item = convert_tick_stock(data)
|
||
if item:
|
||
func = self.callback_map.get(item.ts_code)
|
||
if func and callable(func):
|
||
func(item)
|
||
|
||
def OnMDSnapshot(self, data, err):
|
||
if data:
|
||
threadpool.submit(self.process_data, data[0])
|
||
pass
|
||
else:
|
||
print(err)
|
||
|
||
def OnMDIndexSnapshot(self, data, err):
|
||
if data:
|
||
threadpool.submit(self.process_data, data[0])
|
||
else:
|
||
print(err)
|
||
|
||
def register(self, ts_codes):
|
||
def decorator(func):
|
||
for ts_code in ts_codes:
|
||
self.callback_map[ts_code] = func
|
||
return func
|
||
return decorator
|
||
|
||
def run(self):
|
||
token = get_token()
|
||
r = verify_token(token, 'tgw_subscribe', self.username)
|
||
if r.status_code == 200:
|
||
resp_data = r.json()
|
||
if resp_data.get("message") != "success":
|
||
print(f"验证token出错,{r.text}")
|
||
raise PermissionError(f"请先设置Tushare的token,{r['msg']}")
|
||
|
||
login(username=self.username, password=self.password, host=self.host, port=self.port)
|
||
items = []
|
||
for ts_code in self.callback_map:
|
||
t, c = get_tgw_type_and_code(ts_code)
|
||
item = tgw.SubscribeItem()
|
||
item.flag = tgw.SubscribeDataType.kSnapshot
|
||
item.category_type = tgw.VarietyCategory.kStock
|
||
item.security_code = c
|
||
item.market = t
|
||
items.append(item)
|
||
item = tgw.SubscribeItem()
|
||
item.flag = tgw.SubscribeDataType.kIndexSnapshot
|
||
item.category_type = tgw.VarietyCategory.kIndex
|
||
item.security_code = c
|
||
item.market = t
|
||
items.append(item)
|
||
|
||
self.SetDfFormat(False)
|
||
success = tgw.Subscribe(items, self)
|
||
if success != tgw.ErrorCode.kSuccess:
|
||
print(tgw.GetErrorMsg(success))
|
||
while True:
|
||
time.sleep(10)
|
||
|
||
|
||
def demo():
|
||
set_token('xxxxxx')
|
||
sub = TgwSubscribe(username='username', password='password', host='host', port=8600)
|
||
|
||
@sub.register(ts_codes=['600000.SH', '000300.SH'])
|
||
def just_print_index(data: Union[TsTick, TsTickIdx]):
|
||
print(data)
|
||
|
||
sub.run()
|