Files
backtrader/venv/Lib/site-packages/tushare/subs/ts_subs/subscribe.py
2026-01-17 21:21:30 +08:00

213 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding:utf-8 -*-
'''
@group:waditu
@author: DY
'''
import _thread as thread
import fnmatch
import json
import logging, sys
import re
import ssl
import threading
import time
from collections import defaultdict
from functools import wraps
from multiprocessing.context import Process
import websocket
from websocket import WebSocketConnectionClosedException
# Import-time side effect: configure the root logger to write DEBUG-and-above
# records to stdout using a "time - logger - level - message" layout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Redundant re-import (logging is already imported above); kept as-is.
import logging
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
class TsSubscribe(object):
    """Websocket client that subscribes to topics and dispatches pushed records.

    Callbacks are attached with the :meth:`register` decorator; :meth:`run`
    opens the connection, sends the subscription request and blocks while
    messages are delivered to the matching callbacks.

    :param token: token sent to the server in the ``listening`` request.
    :param callback_mode: how a matching callback is invoked:
        ``'single-thread'`` (called inline), ``'multi-thread'`` (one thread per
        message) or ``'multi-process'`` (one process per message).
    :param debug: when true, ``on_error`` also logs the traceback.
    :param verify_ssl: when true, TLS certificates are verified for ``wss:``
        URLs.  Defaults to ``False``, which keeps the historical behaviour of
        connecting with ``ssl.CERT_NONE``.
    """

    # Seconds between two keep-alive pings.
    PING_INTERVAL = 30
    # Values of ``callback_mode`` understood by ``_do_callback_function``.
    _CALLBACK_MODES = ('single-thread', 'multi-thread', 'multi-process')

    def __init__(self, token='', callback_mode='multi-thread', debug=False,
                 verify_ssl=False):
        self.url = 'wss://ws.tushare.pro/listening'
        self.token = token
        self.debug = debug
        self.callback_mode = callback_mode
        self.verify_ssl = verify_ssl
        # topic -> list of subscribed codes (sent to the server on open).
        self.topics = defaultdict(list)
        # topic -> list of callbacks; each carries ``codes`` / ``pcodes`` sets.
        self.callback_funcs = defaultdict(list)
        self.websocket = None
        if callback_mode not in self._CALLBACK_MODES:
            # An unknown mode would otherwise silently drop every record.
            logger.warning('unknown callback_mode(%s)', callback_mode)

    def threading_keepalive_ping(self):
        """Start a daemon thread that sends a ``ping`` action periodically.

        The thread stops as soon as the connection it was started for has been
        replaced (e.g. after a reconnect) or sending fails because the socket
        is closed.
        """
        ws = self.websocket
        payload = json.dumps({"action": "ping"})

        def ping():
            while True:
                time.sleep(self.PING_INTERVAL)
                if self.websocket is not ws:
                    # A newer connection owns its own ping thread.
                    return
                try:
                    ws.send(payload)
                except (WebSocketConnectionClosedException, OSError):
                    return
                logger.debug('send ping message')

        threading.Thread(target=ping, daemon=True).start()

    def on_open(self, *args, **kwargs):
        """Send the ``listening`` request and start the keep-alive thread."""
        req_data = {
            "action": "listening",
            "token": self.token,
            "data": self.topics
        }
        self.websocket.send(json.dumps(req_data))
        logger.info('application starting...')
        self.threading_keepalive_ping()

    def on_message(self, *args, **kwargs):
        """Decode one server message and dispatch its record to callbacks.

        Accepts both ``(app, message)`` and ``(message,)`` call conventions.
        Messages that are not valid JSON objects, that report a falsy
        ``status``, or whose ``data`` lacks ``topic``/``code``/``record`` are
        logged and ignored.
        """
        if isinstance(args[0], websocket.WebSocketApp):
            message = args[1]
        else:
            message = args[0]
        logger.debug(message)

        if not isinstance(message, (str, bytes, bytearray)):
            logger.info(message)
            return
        try:
            resp_data = json.loads(message)
        except ValueError:
            # Covers JSONDecodeError and undecodable byte payloads.
            logger.error('get invalid response-data(%s)', message)
            return
        if not isinstance(resp_data, dict):
            logger.warning('get invalid response-data(%s)', resp_data)
            return
        if not resp_data.get('status'):
            logger.error(resp_data.get('message'))
            return

        data = resp_data.get('data', {})
        if not data or not isinstance(data, dict):
            return
        topic = data.get('topic')
        code = data.get('code')
        record = data.get('record')
        if not topic or not code or not record:
            logger.warning('get invalid response-data(%s)', resp_data)
            return
        self._do_callback_function(topic, code, record)

    def on_error(self, error, *args, **kwargs):
        """Log a connection error; include the traceback when ``debug`` is set.

        Accepts both ``(error,)`` and ``(app, error)`` call conventions, the
        same way ``on_message`` does.
        """
        if args and isinstance(error, websocket.WebSocketApp):
            error = args[0]
        show_trace = self.debug and isinstance(error, BaseException)
        logger.error(str(error), exc_info=error if show_trace else None)

    def on_close(self, *args, **kwargs):
        """Log the close and reconnect after connection-level failures.

        A reconnect is attempted (after one second) only when this handler
        runs while a ``WebSocketConnectionClosedException`` or
        ``ConnectionRefusedError`` is being handled, as reported by
        ``sys.exc_info()``.  The reconnect re-enters :meth:`run`.
        """
        logger.error('close')
        _type, _value, _traceback = sys.exc_info()
        if _type in (WebSocketConnectionClosedException, ConnectionRefusedError):
            time.sleep(1)
            self.run()

    def run(self):
        """Connect and block while serving messages.

        Returns immediately (after logging) when nothing has been registered.
        """
        if not self.topics:
            logger.error('no data.')
            return
        self.websocket = websocket.WebSocketApp(
            self.url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open
        )
        if self.url.startswith('wss:') and not self.verify_ssl:
            # NOTE(security): certificate verification is disabled here, so
            # the token is exposed to man-in-the-middle attacks.  Pass
            # ``verify_ssl=True`` to the constructor to enable verification.
            self.websocket.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        else:
            self.websocket.run_forever()

    def register(self, topic, codes):
        """Return a decorator that subscribes a callback to ``topic``.

        :param topic: topic name the callback listens to.
        :param codes: a code or an iterable of codes; a code may contain
            ``*`` wildcards (matched with ``fnmatch``).
        :return: decorator; the decorated callback receives the pushed record
            as its single argument.
        :raises SystemExit: (exit status 1) when a code does not start with an
            allowed character or when codes overlap each other within the
            same registration.
        """
        if isinstance(codes, str):
            # A bare string would otherwise be split into single characters.
            codes = [codes]
        # Drop repeated codes while keeping the caller's order.
        codes = list(dict.fromkeys(codes))

        def decorator(func):
            func.codes = set()    # exact codes
            func.pcodes = set()   # wildcard patterns
            for code in codes:
                # NOTE: re.match only anchors at the start, so this validates
                # the leading characters only (codes such as "1MIN:*.SH"
                # depend on that).
                if not re.match(r'[\d\w\.\*]+', code):
                    logger.error('error code')
                    sys.exit(1)
                if '*' in code:
                    overlaps = any(
                        fnmatch.fnmatch(known, code) or fnmatch.fnmatch(code, known)
                        for known in func.pcodes
                    ) or any(
                        fnmatch.fnmatch(known, code) for known in func.codes
                    )
                    if overlaps:
                        logger.error('duplicate code')
                        sys.exit(1)
                    func.pcodes.add(code)
                else:
                    if any(fnmatch.fnmatch(code, known) for known in func.pcodes):
                        logger.error('duplicate code')
                        sys.exit(1)
                    func.codes.add(code)

            # Only add codes the topic is not already subscribed to, so the
            # subscription request carries no duplicates.
            subscribed = self.topics[topic]
            subscribed.extend([c for c in codes if c not in subscribed])
            self.callback_funcs[topic].append(func)

            @wraps(func)
            def inner(*args, **kwargs):
                """ should receive a message-value parameter """
                return func(*args, **kwargs)
            return inner
        return decorator

    def _do_callback_function(self, topic, code, value):
        """Invoke every callback of ``topic`` whose codes match ``code``.

        The invocation style follows ``self.callback_mode``.
        """
        # ``.get`` avoids creating empty entries for unknown topics.
        for func in self.callback_funcs.get(topic, ()):
            matched = code in func.codes or any(
                fnmatch.fnmatch(code, pcode) for pcode in func.pcodes
            )
            if not matched:
                continue
            if self.callback_mode == 'single-thread':
                func(value)
            elif self.callback_mode == 'multi-thread':
                # Daemon thread: like the low-level _thread API, it does not
                # keep the interpreter alive at shutdown.
                threading.Thread(target=func, args=(value,), daemon=True).start()
            elif self.callback_mode == 'multi-process':
                Process(target=func, args=(value,)).start()
def test_min():
    """Demo: subscribe to topic ``HQ_STK_MIN`` and print every matching record.

    Connects with the placeholder token ``"xxx"`` and blocks in ``run()``.
    """
    subscriber = TsSubscribe("xxx")

    # A code may contain "*" (wildcard).
    @subscriber.register(topic='HQ_STK_MIN', codes=["1MIN:*.SH"])
    def print_message(record):
        """Callback run for each pushed message of the topic that matches a code.

        Subscribing to a topic with a list of codes means this function is
        executed whenever a pushed message of that topic satisfies one of the
        codes.

        :param record: the pushed record.
        :return: None
        """
        text = '用户定义业务代码输出 print_message(%s)' % str(record)
        print(text)

    subscriber.run()
def test_tick():
    """Demo: subscribe to topic ``HQ_STK_TICK`` and print every matching record.

    Connects with the placeholder token ``'xxx'`` and blocks in ``run()``.
    """
    subscriber = TsSubscribe(token='xxx')

    # A code may contain "*" (wildcard).
    @subscriber.register(topic='HQ_STK_TICK', codes=["*.SH"])
    def print_message(record):
        """Callback run for each pushed message of the topic that matches a code.

        Subscribing to a topic with a list of codes means this function is
        executed whenever a pushed message of that topic satisfies one of the
        codes.

        :param record: the pushed record.
        :return: None
        """
        text = '用户定义业务代码输出 print_message(%s)' % str(record)
        print(text)

    subscriber.run()
# Script entry point: when executed directly, run the HQ_STK_MIN demo.
if __name__ == '__main__':
    test_min()