Spaces:
Paused
Paused
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
import json | |
import os | |
import random | |
import sys | |
import time | |
from typing import List | |
pwd = os.path.abspath(os.path.dirname(__file__)) | |
sys.path.append(os.path.join(pwd, '../../')) | |
import requests | |
from project_settings import environment | |
def get_args():
    """Parse the demo's command-line options.

    ``--api_key`` falls back to the project environment's ``agent_x_api_key``
    setting, looked up via ``environment.get(..., default=None)``.
    """
    arg_parser = argparse.ArgumentParser()
    fallback_key = environment.get("agent_x_api_key", default=None)
    arg_parser.add_argument("--api_key", default=fallback_key, type=str)
    return arg_parser.parse_args()
class AgentX(object):
    """HTTP client for the AgentX ``/api/v1/access`` REST API, bound to one agent.

    The agent is resolved by name when the client is constructed, which costs
    one network call. Its id is then used by the agent-scoped endpoints.
    """

    # Per-request timeout handed to ``requests``. ``None`` waits forever, which
    # is the historical behaviour. It is a class-level default so instances
    # created without ``__init__`` still have it.
    timeout = None

    def __init__(self,
                 api_key: str,
                 agent_name: str = "NXLink智能助手",
                 url_host: str = "https://api.agentx.so",
                 timeout=None,
                 ):
        """Store connection settings and resolve ``agent_name`` to an agent id.

        :param api_key: value sent in the ``x-api-key`` header of every request.
        :param agent_name: name of the agent to bind to.
        :param url_host: API base URL, without a trailing slash (paths are
            appended as ``"/api/..."``).
        :param timeout: optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple) applied to every call; ``None`` disables it.
        :raises requests.HTTPError: if listing agents does not return HTTP 200.
        :raises AssertionError: if no agent named ``agent_name`` exists.
        """
        self.api_key = api_key
        self.agent_name = agent_name
        self.url_host = url_host
        self.timeout = timeout
        self.agent_id = self.get_agent_id()

    def __str__(self):
        # The key is masked so that printing or logging the client does not leak it.
        result = "<{}; agent_name: {}; agent_id: {}; api_key: {}>".format(
            self.__class__.__name__, self.agent_name, self.agent_id, self._masked_api_key())
        return result

    # ------------------------------------------------------------------ helpers

    def _masked_api_key(self) -> str:
        """Return the API key with everything except its last 4 characters hidden."""
        key = self.api_key
        if not key:
            # None or "" -- there is nothing secret to hide.
            return str(key)
        if len(key) <= 4:
            return "***"
        return "***" + key[-4:]

    def _headers(self, json_content: bool = False) -> dict:
        """Build the header dict shared by all endpoints.

        :param json_content: also send ``Content-type: application/json``.
        """
        headers = {"accept": "*/*"}
        if json_content:
            headers["Content-type"] = "application/json"
        headers["x-api-key"] = self.api_key
        return headers

    def _request(self, method: str, path: str, json_content: bool = False,
                 payload: dict = None, stream: bool = False):
        """Send one request to ``url_host + path`` and return the raw response.

        :param payload: if not None, JSON-encoded (``json.dumps``) into the body.
        :param stream: forwarded to ``requests`` so the body can be read lazily.
        """
        kwargs = {
            "url": "{}{}".format(self.url_host, path),
            "headers": self._headers(json_content),
            "timeout": self.timeout,
        }
        if payload is not None:
            kwargs["data"] = json.dumps(payload)
        if stream:
            kwargs["stream"] = True
        return requests.request(method, **kwargs)

    @staticmethod
    def _raise_for_non_200(resp):
        """Raise ``requests.HTTPError`` (with status and body) unless ``resp`` is HTTP 200.

        Raising instead of exiting the interpreter lets callers handle the
        failure, and lets ``finally`` clean-up (e.g. conversation deletion) run.
        """
        if resp.status_code != 200:
            raise requests.HTTPError(
                "AgentX request failed with HTTP {}: {}".format(resp.status_code, resp.text),
                response=resp,
            )

    @staticmethod
    def _parse_reference(trace):
        """Convert a trace payload into the ``reference`` value of ``question_answer``.

        The sentinel string ``"No trace"`` is passed through unchanged. Any other
        payload is iterated, and each entry's ``title`` and ``source`` become a
        ``(title, source)`` tuple.
        """
        if trace == "No trace":
            return "No trace"
        return [(t["title"], t["source"]) for t in trace]

    @staticmethod
    def _iter_utf8(byte_chunks):
        """Decode an iterable of UTF-8 byte chunks into non-empty text pieces.

        Multi-byte characters split across chunks are reassembled. Invalid or
        truncated sequences become U+FFFD rather than blocking all later output.
        """
        import codecs

        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        for raw in byte_chunks:
            text = decoder.decode(raw)
            if text:
                yield text
        # Flush whatever incomplete sequence is still buffered at end of stream.
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail

    # ---------------------------------------------------------------- endpoints

    def get_agent_id(self):
        """Return the ``_id`` of the agent whose ``name`` equals ``self.agent_name``.

        :raises requests.HTTPError: if the agent listing is not HTTP 200.
        :raises AssertionError: if no agent has that name.
        """
        resp = self._request("GET", "/api/v1/access/agents")
        self._raise_for_non_200(resp)
        result = None
        for agent in resp.json():
            if agent["name"] == self.agent_name:
                # No break: if several agents share the name, the last match wins.
                result = agent["_id"]
        if result is None:
            raise AssertionError("agent not found")
        return result

    def get_agent_config(self):
        """Return the decoded JSON body of ``GET /agents/{agent_id}``."""
        resp = self._request("GET", "/api/v1/access/agents/{}".format(self.agent_id))
        return resp.json()

    def get_conversation_list(self):
        """Return the decoded JSON body of ``GET /agents/{agent_id}/conversations``."""
        resp = self._request("GET", "/api/v1/access/agents/{}/conversations".format(self.agent_id))
        return resp.json()

    def post_message(self, message: str, conversation_id: str, context: int = 0):
        """Send ``message`` to a conversation and return the decoded JSON reply.

        :param context: flag forwarded verbatim in the request body.
        :raises requests.HTTPError: if the response is not HTTP 200.
        """
        resp = self._request(
            "POST",
            "/api/v1/access/conversations/{}/message".format(conversation_id),
            json_content=True,
            payload={"message": message, "context": context},
        )
        self._raise_for_non_200(resp)
        return resp.json()

    def post_message_by_sse(self, message: str, conversation_id: str, context: int = 0):
        """Send ``message`` on the streaming endpoint.

        :returns: ``(text_chunks, trace_id)``. ``text_chunks`` is a generator of
            decoded text pieces. It closes the HTTP response when exhausted or
            closed. ``trace_id`` comes from the ``x-trace-id`` response header.
        :raises AssertionError: if the response is not HTTP 200.
        :raises KeyError: if a 200 response lacks the ``x-trace-id`` header.
        """
        resp = self._request(
            "POST",
            "/api/v1/access/conversations/{}/messagesse".format(conversation_id),
            json_content=True,
            payload={"message": message, "context": context},
            stream=True,
        )
        # Check the status first, so error responses that lack x-trace-id are
        # reported as HTTP errors rather than as a confusing KeyError.
        if resp.status_code != 200:
            detail = "SSE request failed with HTTP {} (Content-Type: {}): {}".format(
                resp.status_code, resp.headers.get("Content-Type"), resp.text)
            resp.close()
            raise AssertionError(detail)
        try:
            trace_id = resp.headers["x-trace-id"]
        except KeyError:
            resp.close()
            raise

        def generator():
            try:
                yield from self._iter_utf8(resp.iter_content())
            finally:
                # Release the streamed connection even if the consumer stops early.
                resp.close()

        return generator(), trace_id

    def get_trace_by_message_id(self, message_id: str):
        """Return the decoded JSON body of ``GET /messages/{message_id}/trace``."""
        resp = self._request("GET", "/api/v1/access/messages/{}/trace".format(message_id))
        return resp.json()

    def get_trace_by_trace_id(self, trace_id: str):
        """Return the decoded JSON body of ``GET /traces/{trace_id}``."""
        resp = self._request("GET", "/api/v1/access/traces/{}".format(trace_id))
        return resp.json()

    def post_new_conversation_id(self):
        """Create a conversation for this agent and return its ``_id``."""
        resp = self._request("POST", "/api/v1/access/agents/{}/conversations/new".format(self.agent_id))
        return resp.json()["_id"]

    def delete_conversation(self, conversation_id: str):
        """Delete a conversation and return the decoded JSON response."""
        resp = self._request(
            "DELETE",
            "/api/v1/access/conversations/{}".format(conversation_id),
            json_content=True,
        )
        return resp.json()

    def update_context(self, messages: List[dict], conversation_id: str):
        """PUT ``messages`` to the conversation's ``update-context`` endpoint.

        ``messages`` is sent as-is under the ``"messages"`` key. Presumably it
        holds single-key ``{role: text}`` dicts -- TODO confirm against the API docs.
        Returns the decoded JSON response.
        """
        resp = self._request(
            "PUT",
            "/api/v1/access/conversations/{}/update-context".format(conversation_id),
            json_content=True,
            payload={"messages": messages},
        )
        return resp.json()

    def question_answer(self, question: str, conversation_id: str = None,
                        context: List[dict] = None, streaming: bool = False):
        """Ask ``question`` and return ``{"answer": ..., "reference": ...}``.

        A new conversation is created when ``conversation_id`` is None. If
        ``context`` is given, it is uploaded with ``update_context`` first and
        the message is sent with ``context=1``; otherwise ``context=0``. With
        ``streaming=True`` the answer is streamed and echoed to stdout as it
        arrives.

        ``reference`` is a list of ``(title, source)`` tuples, or the string
        ``"No trace"``.

        The conversation is always deleted afterwards, even on error.
        NOTE(review): this also deletes a ``conversation_id`` supplied by the
        caller -- confirm that is intended.
        """
        if conversation_id is None:
            conversation_id = self.post_new_conversation_id()
        context_flag = 0 if context is None else 1
        result = {
            "answer": None,
            "reference": None
        }
        try:
            # Inside the try so a failed context upload still deletes the conversation.
            if context is not None:
                self.update_context(context, conversation_id)

            if streaming:
                chunks, trace_id = self.post_message_by_sse(question, conversation_id, context=context_flag)
                pieces = []
                for chunk in chunks:
                    print(chunk, end="")
                    pieces.append(chunk)
                print("\n")
                result["answer"] = "".join(pieces)
                trace = self.get_trace_by_trace_id(trace_id)
            else:
                js = self.post_message(question, conversation_id, context=context_flag)
                result["answer"] = js["text"]
                trace = self.get_trace_by_message_id(js["_id"])

            result["reference"] = self._parse_reference(trace)
        finally:
            self.delete_conversation(conversation_id)
        return result
def main():
    """Demo: ask the "Yutong Bus" agent one question after uploading a canned chat history.

    Prints the client, the streamed answer, the result dict and the elapsed time.
    """
    args = get_args()
    agent = AgentX(
        api_key=args.api_key,
        agent_name="Yutong Bus",
    )
    print(agent)

    # Prior turns uploaded as conversation context before the question.
    # Each turn is a single-key dict mapping the speaker role to the utterance.
    context = [
        {
            "user": "你好"
        },
        {
            "assistant": "你好,我们是宇通客车公司,有什么可以帮到您的吗?"
        },
        {
            "user": "需要一辆55座客车。"
        },
        {
            "assistant": "Which country will the bus be used in?"
        },
        {
            "user": "你可以说中文吗。"
        },
        {
            "assistant": "可以的,请问您需要在哪个国家使用客车?"
        },
    ]
    question = "你好"

    # perf_counter is monotonic, so the measurement is immune to wall-clock adjustments.
    time_begin = time.perf_counter()
    response = agent.question_answer(question, context=context, streaming=True)
    time_cost = time.perf_counter() - time_begin

    print(response)
    print("time cost: {}".format(time_cost))
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()