#
#
# Agora Real Time Engagement
# Created by Wei Hu in August 2024.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import asyncio
import json
import time
import traceback
from typing import Any, Iterable
import uuid
from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.const import (
    CMD_PROPERTY_RESULT,
    CMD_TOOL_CALL,
    CONTENT_DATA_OUT_NAME,
    DATA_OUT_PROPERTY_END_OF_SEGMENT,
    DATA_OUT_PROPERTY_TEXT,
)
from ten_ai_base.helper import (
AsyncEventEmitter,
get_property_bool,
get_property_string,
)
from ten_ai_base.types import (
LLMCallCompletionArgs,
LLMChatCompletionContentPartParam,
LLMChatCompletionUserMessageParam,
LLMChatCompletionMessageParam,
LLMDataCompletionArgs,
LLMToolMetadata,
LLMToolResult,
)
from ten_ai_base.llm import AsyncLLMBaseExtension
from .helper import parse_sentences
from .openai import OpenAIChatGPT, OpenAIChatGPTConfig
from ten import (
Cmd,
StatusCode,
CmdResult,
Data,
)
CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT = "end_of_segment"
class OpenAIChatGPTExtension(AsyncLLMBaseExtension):
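    """LLM extension that streams chat completions from an OpenAI-compatible
    endpoint, relays tool calls to other extensions, and forwards sentence and
    reasoning fragments downstream as they arrive."""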
def __init__(self, name: str):
super().__init__(name)
self.memory = []
self.memory_cache = []
self.config = None
self.client = None
self.sentence_fragment = ""
self.tool_task_future: asyncio.Future | None = None
self.users_count = 0
self.last_reasoning_ts = 0
async def on_init(self, async_ten_env: AsyncTenEnv) -> None:
async_ten_env.log_info("on_init")
await super().on_init(async_ten_env)
async def on_start(self, async_ten_env: AsyncTenEnv) -> None:
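        """Load the configuration and initialize the OpenAI client."""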
async_ten_env.log_info("on_start")
await super().on_start(async_ten_env)
self.config = await OpenAIChatGPTConfig.create_async(ten_env=async_ten_env)
# Mandatory properties
if not self.config.api_key:
async_ten_env.log_info("API key is missing, exiting on_start")
return
# Create instance
try:
self.client = OpenAIChatGPT(async_ten_env, self.config)
async_ten_env.log_info(
f"initialized with max_tokens: {self.config.max_tokens}, model: {self.config.model}, vendor: {self.config.vendor}"
)
except Exception as err:
async_ten_env.log_info(f"Failed to initialize OpenAIChatGPT: {err}")
async def on_stop(self, async_ten_env: AsyncTenEnv) -> None:
async_ten_env.log_info("on_stop")
await super().on_stop(async_ten_env)
async def on_deinit(self, async_ten_env: AsyncTenEnv) -> None:
async_ten_env.log_info("on_deinit")
await super().on_deinit(async_ten_env)
async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
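        """Dispatch incoming commands: flush the input queue, track user
        join/leave counts, and greet the first user who joins."""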
cmd_name = cmd.get_name()
async_ten_env.log_info(f"on_cmd name: {cmd_name}")
if cmd_name == CMD_IN_FLUSH:
await self.flush_input_items(async_ten_env)
await async_ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
async_ten_env.log_info("on_cmd sent flush")
status_code, detail = StatusCode.OK, "success"
cmd_result = CmdResult.create(status_code)
cmd_result.set_property_string("detail", detail)
await async_ten_env.return_result(cmd_result, cmd)
elif cmd_name == CMD_IN_ON_USER_JOINED:
self.users_count += 1
            # Send a greeting when the first user joins
if self.config.greeting and self.users_count == 1:
self.send_text_output(async_ten_env, self.config.greeting, True)
status_code, detail = StatusCode.OK, "success"
cmd_result = CmdResult.create(status_code)
cmd_result.set_property_string("detail", detail)
await async_ten_env.return_result(cmd_result, cmd)
elif cmd_name == CMD_IN_ON_USER_LEFT:
self.users_count -= 1
status_code, detail = StatusCode.OK, "success"
cmd_result = CmdResult.create(status_code)
cmd_result.set_property_string("detail", detail)
await async_ten_env.return_result(cmd_result, cmd)
else:
await super().on_cmd(async_ten_env, cmd)
async def on_data(self, async_ten_env: AsyncTenEnv, data: Data) -> None:
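        """Handle incoming transcript data; only non-empty final segments are
        queued for chat completion."""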
data_name = data.get_name()
async_ten_env.log_debug("on_data name {}".format(data_name))
# Get the necessary properties
        is_final = get_property_bool(data, DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
        input_text = get_property_string(data, DATA_IN_TEXT_DATA_PROPERTY_TEXT)
if not is_final:
async_ten_env.log_debug("ignore non-final input")
return
if not input_text:
async_ten_env.log_warn("ignore empty text")
return
async_ten_env.log_info(f"OnData input text: [{input_text}]")
        # Queue the message for asynchronous chat-completion handling
message = LLMChatCompletionUserMessageParam(role="user", content=input_text)
await self.queue_input_item(False, messages=[message])
async def on_tools_update(
self, async_ten_env: AsyncTenEnv, tool: LLMToolMetadata
) -> None:
return await super().on_tools_update(async_ten_env, tool)
async def on_call_chat_completion(
self, async_ten_env: AsyncTenEnv, **kargs: LLMCallCompletionArgs
    ) -> Any:
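        """Run a one-shot, non-streaming chat completion and return its JSON."""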
        kmessages: Iterable[LLMChatCompletionUserMessageParam] = kargs.get(
            "messages", []
        )
async_ten_env.log_info(f"on_call_chat_completion: {kmessages}")
response = await self.client.get_chat_completions(kmessages, None)
return response.to_json()
async def on_data_chat_completion(
self, async_ten_env: AsyncTenEnv, **kargs: LLMDataCompletionArgs
) -> None:
"""Run the chatflow asynchronously."""
kmessages: Iterable[LLMChatCompletionUserMessageParam] = kargs.get(
"messages", []
)
if len(kmessages) == 0:
async_ten_env.log_error("No message in data")
return
        messages = [self.message_to_dict(message) for message in kmessages]
self.memory_cache = []
memory = self.memory
try:
async_ten_env.log_info(f"for input text: [{messages}] memory: {memory}")
            no_tool = kargs.get("no_tool", False)
for message in messages:
if (
not isinstance(message.get("content"), str)
and message.get("role") == "user"
):
non_artifact_content = [
item
for item in message.get("content", [])
if item.get("type") == "text"
]
non_artifact_message = {
"role": message.get("role"),
"content": non_artifact_content,
}
                    self.memory_cache.append(non_artifact_message)
                else:
                    self.memory_cache.append(message)
            self.memory_cache.append({"role": "assistant", "content": ""})
tools = None
if not no_tool and len(self.available_tools) > 0:
tools = []
for tool in self.available_tools:
tools.append(self._convert_tools_to_dict(tool))
async_ten_env.log_info(f"tool: {tool}")
self.sentence_fragment = ""
# Create an asyncio.Event to signal when content is finished
content_finished_event = asyncio.Event()
# Create a future to track the single tool call task
self.tool_task_future = None
message_id = str(uuid.uuid4())[:8]
self.last_reasoning_ts = int(time.time() * 1000)
# Create an async listener to handle tool calls and content updates
async def handle_tool_call(tool_call):
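                """Forward the tool call to the matching extension and feed
                its result back into the completion queue."""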
self.tool_task_future = asyncio.get_event_loop().create_future()
async_ten_env.log_info(f"tool_call: {tool_call}")
for tool in self.available_tools:
if tool_call["function"]["name"] == tool.name:
cmd: Cmd = Cmd.create(CMD_TOOL_CALL)
cmd.set_property_string("name", tool.name)
cmd.set_property_from_json(
"arguments", tool_call["function"]["arguments"]
)
                        # Send the tool-call command and wait for its result
[result, _] = await async_ten_env.send_cmd(cmd)
if result.get_status_code() == StatusCode.OK:
tool_result: LLMToolResult = json.loads(
result.get_property_to_json(CMD_PROPERTY_RESULT)
)
async_ten_env.log_info(f"tool_result: {tool_result}")
if tool_result["type"] == "llmresult":
result_content = tool_result["content"]
if isinstance(result_content, str):
tool_message = {
"role": "assistant",
"tool_calls": [tool_call],
}
new_message = {
"role": "tool",
"content": result_content,
"tool_call_id": tool_call["id"],
}
await self.queue_input_item(
True, messages=[tool_message, new_message], no_tool=True
)
else:
async_ten_env.log_error(
f"Unknown tool result content: {result_content}"
)
elif tool_result["type"] == "requery":
                                # Drop the empty assistant placeholder before requerying
                                self.memory_cache.pop()
result_content = tool_result["content"]
nonlocal message
new_message = {
"role": "user",
"content": self._convert_to_content_parts(
message["content"]
),
}
new_message["content"] = new_message[
"content"
] + self._convert_to_content_parts(result_content)
await self.queue_input_item(
True, messages=[new_message], no_tool=True
)
else:
async_ten_env.log_error(
f"Unknown tool result type: {tool_result}"
)
else:
async_ten_env.log_error("Tool call failed")
self.tool_task_future.set_result(None)
async def handle_content_update(content: str):
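                """Accumulate streamed content on the cached assistant
                message and emit any completed sentences."""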
# Append the content to the last assistant message
for item in reversed(self.memory_cache):
if item.get("role") == "assistant":
item["content"] = item["content"] + content
break
sentences, self.sentence_fragment = parse_sentences(
self.sentence_fragment, content
)
for s in sentences:
self.send_text_output(async_ten_env, s, False)
async def handle_reasoning_update(think: str):
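                """Forward reasoning text downstream, throttled to one update per 200 ms."""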
ts = int(time.time() * 1000)
if ts - self.last_reasoning_ts >= 200:
self.last_reasoning_ts = ts
self.send_reasoning_text_output(async_ten_env, message_id, think, False)
async def handle_reasoning_update_finish(think: str):
self.last_reasoning_ts = int(time.time() * 1000)
self.send_reasoning_text_output(async_ten_env, message_id, think, True)
async def handle_content_finished(_: str):
# Wait for the single tool task to complete (if any)
if self.tool_task_future:
await self.tool_task_future
content_finished_event.set()
listener = AsyncEventEmitter()
listener.on("tool_call", handle_tool_call)
listener.on("content_update", handle_content_update)
listener.on("reasoning_update", handle_reasoning_update)
listener.on("reasoning_update_finish", handle_reasoning_update_finish)
listener.on("content_finished", handle_content_finished)
# Make an async API call to get chat completions
await self.client.get_chat_completions_stream(
memory + messages, tools, listener
)
# Wait for the content to be finished
await content_finished_event.wait()
async_ten_env.log_info(
f"Chat completion finished for input text: {messages}"
)
except asyncio.CancelledError:
async_ten_env.log_info(f"Task cancelled: {messages}")
except Exception:
async_ten_env.log_error(
f"Error in chat_completion: {traceback.format_exc()} for input text: {messages}"
)
finally:
self.send_text_output(async_ten_env, "", True)
            # Always flush the cached messages into memory, even after an error
for m in self.memory_cache:
self._append_memory(m)
def _convert_to_content_parts(
self, content: Iterable[LLMChatCompletionContentPartParam]
):
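        """Normalize message content into a list of content-part dicts,
        wrapping bare strings as text parts."""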
content_parts = []
if isinstance(content, str):
content_parts.append({"type": "text", "text": content})
else:
for part in content:
content_parts.append(part)
return content_parts
def _convert_tools_to_dict(self, tool: LLMToolMetadata):
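        """Convert tool metadata into the OpenAI function-calling schema."""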
json_dict = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {},
"required": [],
"additionalProperties": False,
},
},
"strict": True,
}
for param in tool.parameters:
json_dict["function"]["parameters"]["properties"][param.name] = {
"type": param.type,
"description": param.description,
}
if param.required:
json_dict["function"]["parameters"]["required"].append(param.name)
return json_dict
def message_to_dict(self, message: LLMChatCompletionMessageParam):
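        """Normalize message content to a plain str or list so it can be
        serialized and stored in memory."""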
if message.get("content") is not None:
if isinstance(message["content"], str):
message["content"] = str(message["content"])
else:
message["content"] = list(message["content"])
return message
    def _append_memory(self, message: dict):
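        """Append a message to memory, evicting the oldest entry (and its
        paired tool response) once max_memory_length is exceeded."""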
if len(self.memory) > self.config.max_memory_length:
removed_item = self.memory.pop(0)
# Remove tool calls from memory
if removed_item.get("tool_calls") and self.memory[0].get("role") == "tool":
self.memory.pop(0)
self.memory.append(message)
def send_reasoning_text_output(
        self, async_ten_env: AsyncTenEnv, msg_id: str, sentence: str, end_of_segment: bool
):
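        """Send a reasoning fragment downstream as a content-data message,
        tagged with the message id and end-of-segment flag."""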
try:
output_data = Data.create(CONTENT_DATA_OUT_NAME)
            output_data.set_property_string(DATA_OUT_PROPERTY_TEXT, json.dumps({
                "id": msg_id,
                "data": {
                    "text": sentence
                },
                "type": "reasoning"
            }))
output_data.set_property_bool(
DATA_OUT_PROPERTY_END_OF_SEGMENT, end_of_segment
)
asyncio.create_task(async_ten_env.send_data(output_data))
except Exception:
async_ten_env.log_warn(
f"send sentence [{sentence}] failed, err: {traceback.format_exc()}")