119 lines
3.4 KiB
Python
119 lines
3.4 KiB
Python
import json
|
|
|
|
import requests
|
|
try:
|
|
import sseclient
|
|
except:
|
|
raise ImportError("sseclient not found, please install it using 'pip install sseclient-py'.")
|
|
|
|
|
|
from tushare import get_token
|
|
|
|
BASE_URL = "http://api.waditu.com/dataapi"
|
|
# BASE_URL = "http://10.255.255.205:8083/dataapi"
|
|
API_KEY_PREFIX = "tsgpt-"
|
|
|
|
|
|
class GPTClient:
|
|
def __init__(self, token=None, timetout=120):
|
|
if not token:
|
|
token = get_token()
|
|
self.token = token
|
|
self.timeout = timetout
|
|
|
|
def _request(self, model, messages, temperature=None, max_tokens=None, stream=True, pretty=False) -> requests.Response:
|
|
"""
|
|
model string 模型名称, doubao-pro-128k
|
|
messages list 消息列表
|
|
[
|
|
{
|
|
"role": "user",
|
|
"content": "Hello World"
|
|
}
|
|
]
|
|
pretty bool 是否只返回回答内容文本
|
|
"""
|
|
resp = requests.post(
|
|
f'{BASE_URL}/llm/{model}',
|
|
json={"params": {
|
|
"stream": stream,
|
|
"messages": messages,
|
|
"temperature": temperature,
|
|
"max_tokens": max_tokens
|
|
}},
|
|
headers={"Authorization": f"tstoken-{self.token}"},
|
|
timeout=self.timeout, stream=stream
|
|
)
|
|
if resp.status_code != 200:
|
|
raise Exception(f"请求出现错误,{resp.content}")
|
|
return resp
|
|
|
|
def gpt_query(self, model, messages, temperature=None, max_tokens=None, pretty=False):
|
|
resp = self._request(model, messages, temperature, max_tokens, False, pretty)
|
|
resp_data = resp.json()
|
|
if resp_data.get('code') not in (0, None):
|
|
raise Exception(resp_data.get('msg') or resp_data)
|
|
if pretty:
|
|
return resp_data['choices'][0]["message"]["content"]
|
|
else:
|
|
return resp_data
|
|
|
|
def gpt_stream(self, model, messages, temperature=None, max_tokens=None, pretty=False):
|
|
resp = self._request(model, messages, temperature, max_tokens, True, pretty)
|
|
for e in sseclient.SSEClient(resp).events():
|
|
if '[DONE]' in e.data.upper():
|
|
break
|
|
e_data = json.loads(e.data)
|
|
if pretty:
|
|
yield e_data["choices"][0]["delta"]["content"]
|
|
else:
|
|
yield e_data
|
|
|
|
def gpt(self, model, query) -> str:
|
|
messages = [{
|
|
"role": "user",
|
|
"content": query
|
|
}]
|
|
return self.gpt_query(model, messages, pretty=True)
|
|
|
|
|
|
def test_gpt_query():
|
|
c = GPTClient()
|
|
dd = c.gpt_query("doubao-pro-128k", [{
|
|
"role": "user",
|
|
"content": "你好"
|
|
}])
|
|
print(dd)
|
|
dd = c.gpt_query("doubao-pro-128k", [{
|
|
"role": "user",
|
|
"content": "你好"
|
|
}], pretty=True)
|
|
print(dd)
|
|
|
|
|
|
def test_gpt_stream():
|
|
c = GPTClient()
|
|
dd = c.gpt_stream("doubao-pro-128k", [
|
|
{
|
|
"role": "user",
|
|
"content": "你好"
|
|
}
|
|
])
|
|
for d in dd:
|
|
print(d)
|
|
|
|
dd = c.gpt_stream("doubao-pro-128k", [
|
|
{
|
|
"role": "user",
|
|
"content": "你好"
|
|
}
|
|
], pretty=True)
|
|
for d in dd:
|
|
print(d)
|
|
|
|
|
|
def test_gpt():
|
|
c = GPTClient()
|
|
dd = c.gpt("doubao-pro-128k", "你好")
|
|
print(dd)
|