# dingtalk / toolbox / agent_x / question_answer.py
# (page residue: author qgyd2021, commit d33b446 "update")
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import random
import sys
import time
from typing import List
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, '../../'))
import requests
from project_settings import environment
def get_args():
    """Parse the command-line options of this script.

    ``--api_key`` falls back to the ``agent_x_api_key`` entry of the project
    settings (``None`` when that entry is absent).
    """
    arg_parser = argparse.ArgumentParser()
    fallback_key = environment.get("agent_x_api_key", default=None)
    arg_parser.add_argument("--api_key", default=fallback_key, type=str)
    return arg_parser.parse_args()
class AgentX(object):
    """Small client for the AgentX REST API.

    On construction the client looks up the id of the agent called
    ``agent_name`` among the agents visible to ``api_key``. Every other public
    method wraps a single ``/api/v1/access/...`` endpoint.
    """

    def __init__(self,
                 api_key: str,
                 agent_name: str = "NXLink智能助手",
                 url_host: str = "https://api.agentx.so",
                 timeout: float = None,
                 ):
        """
        :param api_key: value sent in the ``x-api-key`` header.
        :param agent_name: name of the agent to talk to.
        :param url_host: API base URL, without a trailing slash.
        :param timeout: optional ``requests`` timeout in seconds, applied to
            every call. ``None`` (the default) waits indefinitely.
        :raises AssertionError: if listing agents fails or no agent is named
            ``agent_name``.
        """
        self.api_key = api_key
        self.agent_name = agent_name
        self.url_host = url_host
        self.timeout = timeout
        self.agent_id = self.get_agent_id()

    def __str__(self):
        # The key is masked so that printing or logging the client does not
        # leak the credential.
        result = "<{}; agent_name: {}; agent_id: {}; api_key: {}>".format(
            self.__class__.__name__, self.agent_name, self.agent_id, self._masked_api_key())
        return result

    def _masked_api_key(self) -> str:
        """Return the API key with everything but its last 4 characters hidden."""
        key = self.api_key
        if not key:
            return str(key)
        if len(key) <= 8:
            # Too short to reveal any suffix safely.
            return "****"
        return "****" + key[-4:]

    def _headers(self, with_json: bool = False) -> dict:
        """Build the headers shared by every endpoint.

        :param with_json: add ``Content-type: application/json``.
        """
        headers = {"accept": "*/*"}
        if with_json:
            headers["Content-type"] = "application/json"
        headers["x-api-key"] = self.api_key
        return headers

    def _request(self, method: str, url: str, data: dict = None,
                 with_json: bool = False, stream: bool = False):
        """Send one HTTP request; ``data``, if given, is sent JSON-encoded."""
        return requests.request(
            method,
            url=url,
            headers=self._headers(with_json=with_json),
            data=None if data is None else json.dumps(data),
            stream=stream,
            timeout=self.timeout,
        )

    @staticmethod
    def _raise_for_status(resp) -> None:
        """Raise ``AssertionError`` (with status and body) unless ``resp`` is HTTP 200."""
        if resp.status_code != 200:
            raise AssertionError("request to {} failed with status {}: {}".format(
                resp.url, resp.status_code, resp.text))

    @staticmethod
    def _parse_reference(trace):
        """Turn a trace payload into ``[(title, source), ...]``.

        The literal string ``"No trace"`` is passed through unchanged.
        """
        if trace == "No trace":
            return "No trace"
        return [(t["title"], t["source"]) for t in trace]

    @staticmethod
    def _iter_utf8_text(byte_chunks):
        """Decode an iterable of byte chunks as UTF-8 and yield non-empty text pieces.

        Multi-byte characters split across chunks are reassembled. Invalid
        bytes become U+FFFD instead of stalling the rest of the stream.
        """
        import codecs

        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        for piece in byte_chunks:
            text = decoder.decode(piece)
            if text:
                yield text
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail

    def get_agent_id(self):
        """Return the ``_id`` of the agent named ``self.agent_name``.

        If several agents share that name, the last one listed is used.

        :raises AssertionError: on a non-200 response or when no agent matches.
        """
        url = "{}/api/v1/access/agents".format(self.url_host)
        resp = self._request("GET", url)
        # Raise rather than exit the process, so callers can handle the failure.
        self._raise_for_status(resp)
        matches = [e["_id"] for e in resp.json() if e["name"] == self.agent_name]
        if not matches:
            raise AssertionError("agent not found")
        return matches[-1]

    def get_agent_config(self):
        """Return the decoded JSON configuration of this agent."""
        url = "{}/api/v1/access/agents/{}".format(self.url_host, self.agent_id)
        return self._request("GET", url).json()

    def get_conversation_list(self):
        """Return the decoded JSON list of this agent's conversations."""
        url = "{}/api/v1/access/agents/{}/conversations".format(self.url_host, self.agent_id)
        return self._request("GET", url).json()

    def post_message(self, message: str, conversation_id: str, context: int = 0):
        """Send ``message`` to a conversation and return the decoded JSON reply.

        :raises AssertionError: on a non-200 response.
        """
        url = "{}/api/v1/access/conversations/{}/message".format(self.url_host, conversation_id)
        data = {
            "message": message,
            "context": context,
        }
        resp = self._request("POST", url, data=data, with_json=True)
        self._raise_for_status(resp)
        return resp.json()

    def post_message_by_sse(self, message: str, conversation_id: str, context: int = 0):
        """Send ``message`` and stream the reply through the ``messagesse`` endpoint.

        :return: ``(generator, trace_id)``. The generator yields decoded reply
            text as it arrives and closes the HTTP response when it finishes
            or is closed early.
        :raises AssertionError: on a non-200 response.
        """
        url = "{}/api/v1/access/conversations/{}/messagesse".format(self.url_host, conversation_id)
        data = {
            "message": message,
            "context": context,
        }
        resp = self._request("POST", url, data=data, with_json=True, stream=True)
        try:
            # Check the status first, so a failed request reports its status
            # and body rather than a KeyError on the trace header.
            self._raise_for_status(resp)
            trace_id = resp.headers["x-trace-id"]
        except Exception:
            resp.close()
            raise

        def generator():
            try:
                # iter_content() uses 1-byte chunks by default, so text reaches
                # the caller as soon as it arrives.
                yield from self._iter_utf8_text(resp.iter_content())
            finally:
                resp.close()

        return generator(), trace_id

    def get_trace_by_message_id(self, message_id: str):
        """Return the decoded JSON trace recorded for ``message_id``."""
        url = "{}/api/v1/access/messages/{}/trace".format(self.url_host, message_id)
        return self._request("GET", url).json()

    def get_trace_by_trace_id(self, trace_id: str):
        """Return the decoded JSON trace identified by ``trace_id``."""
        url = "{}/api/v1/access/traces/{}".format(self.url_host, trace_id)
        return self._request("GET", url).json()

    def post_new_conversation_id(self):
        """Create a new conversation for this agent and return its ``_id``."""
        url = "{}/api/v1/access/agents/{}/conversations/new".format(self.url_host, self.agent_id)
        js = self._request("POST", url).json()
        return js["_id"]

    def delete_conversation(self, conversation_id: str):
        """Delete ``conversation_id`` and return the decoded JSON response."""
        url = "{}/api/v1/access/conversations/{}".format(self.url_host, conversation_id)
        return self._request("DELETE", url, with_json=True).json()

    def update_context(self, messages: List[dict], conversation_id: str):
        """Replace the stored context of ``conversation_id`` with ``messages``."""
        url = "{}/api/v1/access/conversations/{}/update-context".format(self.url_host, conversation_id)
        data = {
            "messages": messages,
        }
        return self._request("PUT", url, data=data, with_json=True).json()

    def question_answer(self, question: str, conversation_id: str = None, context: List[dict] = None, streaming: bool = False):
        """Ask ``question`` and return ``{"answer": str, "reference": ...}``.

        ``reference`` is a list of ``(title, source)`` tuples, or the string
        ``"No trace"``.

        :param conversation_id: conversation to use; a new one is created when
            ``None``. NOTE: the conversation is deleted afterwards either way.
        :param context: optional prior messages uploaded via ``update_context``
            before asking. When given, the message is sent with ``context=1``.
        :param streaming: use the SSE endpoint and echo the answer to stdout as
            it arrives.
        """
        if conversation_id is None:
            conversation_id = self.post_new_conversation_id()
        context_flag = 0 if context is None else 1
        result = {
            "answer": None,
            "reference": None
        }
        # Everything after creating the conversation is inside try, so the
        # conversation is deleted even if uploading the context fails.
        try:
            if context is not None:
                self.update_context(context, conversation_id)
            if streaming:
                resp_stream, trace_id = self.post_message_by_sse(question, conversation_id,
                                                                 context=context_flag)
                pieces = []
                for chunk in resp_stream:
                    # Flush so that partial lines appear immediately.
                    print(chunk, end="", flush=True)
                    pieces.append(chunk)
                print("\n")
                result["answer"] = "".join(pieces)
                trace = self.get_trace_by_trace_id(trace_id)
            else:
                js = self.post_message(question, conversation_id, context=context_flag)
                result["answer"] = js["text"]
                trace = self.get_trace_by_message_id(js["_id"])
            result["reference"] = self._parse_reference(trace)
        finally:
            self.delete_conversation(conversation_id)
        return result
def main():
    """Demo: ask the "Yutong Bus" agent a question on top of a seeded chat history.

    Takes the API key from the command line, streams the answer to stdout,
    then prints the result dict and the wall-clock time the call took.
    """
    cli = get_args()
    bot = AgentX(api_key=cli.api_key, agent_name="Yutong Bus")
    print(bot)

    # Earlier turns, passed to question_answer as the conversation context.
    history = [
        {"user": "你好"},
        {"assistant": "你好,我们是宇通客车公司,有什么可以帮到您的吗?"},
        {"user": "需要一辆55座客车。"},
        {"assistant": "Which country will the bus be used in?"},
        {"user": "你可以说中文吗。"},
        {"assistant": "可以的,请问您需要在哪个国家使用客车?"},
    ]

    started = time.time()
    answer = bot.question_answer("你好", context=history, streaming=True)
    elapsed = time.time() - started

    print(answer)
    print("time cost: {}".format(elapsed))
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()