#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
import asyncio
import traceback
import aiohttp
import json
import time
import re
import numpy as np
from typing import List, Any, AsyncGenerator
from dataclasses import dataclass, field
from pydantic import BaseModel
from ten import (
AudioFrame,
VideoFrame,
AsyncTenEnv,
Cmd,
StatusCode,
CmdResult,
Data,
)
from ten_ai_base.config import BaseConfig
from ten_ai_base.chat_memory import (
ChatMemory,
EVENT_MEMORY_APPENDED,
)
from ten_ai_base.usage import (
LLMUsage,
LLMCompletionTokensDetails,
LLMPromptTokensDetails,
)
from ten_ai_base.types import (
LLMChatCompletionUserMessageParam,
LLMToolResult,
LLMCallCompletionArgs,
LLMDataCompletionArgs,
LLMToolMetadata,
)
from ten_ai_base.llm import (
AsyncLLMBaseExtension,
)
CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"
CMD_OUT_TOOL_CALL = "tool_call"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment"
CMD_PROPERTY_RESULT = "tool_result"
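# Sentence-splitting helpers: streamed completion deltas are buffered and cut
# at punctuation boundaries so each finished sentence can be forwarded
# downstream (e.g. to TTS) without waiting for the full response.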
def is_punctuation(char):
    if char in [",", "，", ".", "。", "?", "？", "!", "！"]:
return True
return False
def parse_sentences(sentence_fragment, content):
sentences = []
current_sentence = sentence_fragment
for char in content:
current_sentence += char
if is_punctuation(char):
stripped_sentence = current_sentence
if any(c.isalnum() for c in stripped_sentence):
sentences.append(stripped_sentence)
current_sentence = ""
remain = current_sentence
return sentences, remain
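# Example: parse_sentences("Hel", "lo. Wor") returns (["Hello."], " Wor") --
# the completed sentence plus the fragment carried into the next chunk.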
class ToolCallFunction(BaseModel):
name: str | None = None
arguments: str | None = None
class ToolCall(BaseModel):
index: int
type: str = "function"
id: str | None = None
function: ToolCallFunction
class ToolCallResponse(BaseModel):
id: str
    response: LLMToolResult | None = None
error: str | None = None
class Delta(BaseModel):
content: str | None = None
    tool_calls: List[ToolCall] | None = None
class Choice(BaseModel):
    delta: Delta | None = None
index: int
    finish_reason: str | None = None
class ResponseChunk(BaseModel):
choices: List[Choice]
usage: LLMUsage | None = None
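# The pydantic models above mirror the OpenAI-style streaming chunk layout
# (choices[].delta carrying incremental content/tool_calls, plus an optional
# trailing usage chunk) that the glue server is assumed to emit.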
@dataclass
class GlueConfig(BaseConfig):
api_url: str = "http://localhost:8000/chat/completions"
token: str = ""
prompt: str = ""
max_history: int = 10
greeting: str = ""
failure_info: str = ""
modalities: List[str] = field(default_factory=lambda: ["text"])
rtm_enabled: bool = True
ssml_enabled: bool = False
context_enabled: bool = False
extra_context: dict = field(default_factory=dict)
enable_storage: bool = False
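# Illustrative property values for this config (assumed, not normative):
#   {"api_url": "http://localhost:8000/chat/completions", "token": "<bearer>",
#    "max_history": 10, "greeting": "Hello!", "rtm_enabled": true}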
class AsyncGlueExtension(AsyncLLMBaseExtension):
def __init__(self, name):
super().__init__(name)
self.config: GlueConfig = None
self.ten_env: AsyncTenEnv = None
self.loop: asyncio.AbstractEventLoop = None
self.stopped: bool = False
self.memory: ChatMemory = None
self.total_usage: LLMUsage = LLMUsage()
self.users_count = 0
self.completion_times = []
self.connect_times = []
self.first_token_times = []
self.remote_stream_id: int = 999
async def on_init(self, ten_env: AsyncTenEnv) -> None:
await super().on_init(ten_env)
ten_env.log_debug("on_init")
async def on_start(self, ten_env: AsyncTenEnv) -> None:
await super().on_start(ten_env)
ten_env.log_debug("on_start")
self.loop = asyncio.get_event_loop()
self.config = await GlueConfig.create_async(ten_env=ten_env)
ten_env.log_info(f"config: {self.config}")
self.memory = ChatMemory(self.config.max_history)
if self.config.enable_storage:
[result, _] = await ten_env.send_cmd(Cmd.create("retrieve"))
if result.get_status_code() == StatusCode.OK:
try:
history = json.loads(result.get_property_string("response"))
for i in history:
self.memory.put(i)
ten_env.log_info(f"on retrieve context {history}")
                except Exception as e:
                    ten_env.log_error(f"Failed to handle retrieve result: {e}")
else:
ten_env.log_warn("Failed to retrieve content")
self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)
self.ten_env = ten_env
async def on_stop(self, ten_env: AsyncTenEnv) -> None:
await super().on_stop(ten_env)
ten_env.log_debug("on_stop")
self.stopped = True
await self.queue.put(None)
async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
await super().on_deinit(ten_env)
ten_env.log_debug("on_deinit")
async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
cmd_name = cmd.get_name()
ten_env.log_debug("on_cmd name {}".format(cmd_name))
status = StatusCode.OK
detail = "success"
if cmd_name == CMD_IN_FLUSH:
await self.flush_input_items(ten_env)
await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
ten_env.log_info("on flush")
elif cmd_name == CMD_IN_ON_USER_JOINED:
self.users_count += 1
# Send greeting when first user joined
if self.config.greeting and self.users_count == 1:
self.send_text_output(ten_env, self.config.greeting, True)
elif cmd_name == CMD_IN_ON_USER_LEFT:
self.users_count -= 1
else:
await super().on_cmd(ten_env, cmd)
return
cmd_result = CmdResult.create(status)
cmd_result.set_property_string("detail", detail)
await ten_env.return_result(cmd_result, cmd)
async def on_call_chat_completion(
        self, ten_env: AsyncTenEnv, **kwargs: LLMCallCompletionArgs
    ) -> Any:
raise RuntimeError("Not implemented")
async def on_data_chat_completion(
        self, ten_env: AsyncTenEnv, **kwargs: LLMDataCompletionArgs
) -> None:
        input_messages: List[LLMChatCompletionUserMessageParam] = kwargs.get(
            "messages", []
        )
messages = []
if self.config.prompt:
messages.append({"role": "system", "content": self.config.prompt})
history = self.memory.get()
        # Drop leading tool results and assistant tool-call stubs so the
        # history sent upstream never starts mid tool-call exchange.
        while history:
if history[0].get("role") == "tool":
history = history[1:]
continue
if history[0].get("role") == "assistant" and history[0].get("tool_calls"):
history = history[1:]
continue
            # Head of history is now a regular message; stop trimming.
            break
messages.extend(history)
if not input_messages:
ten_env.log_warn("No message in data")
else:
messages.extend(input_messages)
for i in input_messages:
self.memory.put(i)
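        # Serialize each registered LLMToolMetadata into the OpenAI-style
        # function-tool JSON schema expected by the chat completions payload.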
def tool_dict(tool: LLMToolMetadata):
json_dict = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {},
"required": [],
"additionalProperties": False,
},
},
"strict": True,
}
for param in tool.parameters:
json_dict["function"]["parameters"]["properties"][param.name] = {
"type": param.type,
"description": param.description,
}
if param.required:
json_dict["function"]["parameters"]["required"].append(param.name)
return json_dict
def trim_xml(input_string):
return re.sub(r"<[^>]+>", "", input_string).strip()
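        # Example: trim_xml("<speak>Hi there.</speak>") -> "Hi there."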
tools = []
for tool in self.available_tools:
tools.append(tool_dict(tool))
total_output = ""
sentence_fragment = ""
calls = {}
sentences = []
start_time = time.time()
first_token_time = None
response = self._stream_chat(messages=messages, tools=tools)
async for message in response:
self.ten_env.log_debug(f"content: {message}")
try:
c = ResponseChunk(**message)
if c.choices:
if c.choices[0].delta.content:
if first_token_time is None:
first_token_time = time.time()
self.first_token_times.append(first_token_time - start_time)
content = c.choices[0].delta.content
if self.config.ssml_enabled and content.startswith("<speak>"):
content = trim_xml(content)
total_output += content
sentences, sentence_fragment = parse_sentences(
sentence_fragment, content
)
for s in sentences:
await self._send_text(s)
if c.choices[0].delta.tool_calls:
self.ten_env.log_info(
f"tool_calls: {c.choices[0].delta.tool_calls}"
)
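                        # Tool-call deltas arrive fragmented across chunks:
                        # accumulate name/arguments per call index until the
                        # stream finishes, then dispatch the assembled calls.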
for call in c.choices[0].delta.tool_calls:
if call.index not in calls:
calls[call.index] = ToolCall(
id=call.id,
index=call.index,
function=ToolCallFunction(name="", arguments=""),
)
if call.function.name:
calls[call.index].function.name += call.function.name
if call.function.arguments:
calls[
call.index
].function.arguments += call.function.arguments
if c.usage:
self.ten_env.log_info(f"usage: {c.usage}")
await self._update_usage(c.usage)
except Exception as e:
self.ten_env.log_error(f"Failed to parse response: {message} {e}")
traceback.print_exc()
if sentence_fragment:
await self._send_text(sentence_fragment)
end_time = time.time()
self.completion_times.append(end_time - start_time)
if total_output:
self.memory.put({"role": "assistant", "content": total_output})
if calls:
tasks = []
tool_calls = []
for _, call in calls.items():
self.ten_env.log_info(f"tool call: {call}")
tool_calls.append(call.model_dump())
tasks.append(self.handle_tool_call(call))
self.memory.put({"role": "assistant", "tool_calls": tool_calls})
responses = await asyncio.gather(*tasks)
            for r in responses:
                if r.error or r.response is None:
                    self.ten_env.log_error(f"Tool call {r.id} failed: {r.error}")
                    continue
                content = r.response["content"]
                self.ten_env.log_info(f"tool call response: {content} {r.id}")
self.memory.put(
{
"role": "tool",
"content": json.dumps(content),
"tool_call_id": r.id,
}
)
# request again to let the model know the tool call results
await self.on_data_chat_completion(ten_env)
self.ten_env.log_info(f"total_output: {total_output} {calls}")
async def on_tools_update(
self, ten_env: AsyncTenEnv, tool: LLMToolMetadata
) -> None:
# Implement the logic for tool updates
return await super().on_tools_update(ten_env, tool)
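    # Forward one assembled tool call to the tool extension as a "tool_call"
    # cmd; the JSON result (or a failure marker) is wrapped in ToolCallResponse
    # and later written back into memory as a "tool" role message.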
async def handle_tool_call(self, call: ToolCall) -> ToolCallResponse:
cmd: Cmd = Cmd.create(CMD_OUT_TOOL_CALL)
cmd.set_property_string("name", call.function.name)
cmd.set_property_from_json("arguments", call.function.arguments)
# Send the command and handle the result through the future
[result, _] = await self.ten_env.send_cmd(cmd)
if result.get_status_code() == StatusCode.OK:
tool_result: LLMToolResult = json.loads(
result.get_property_to_json(CMD_PROPERTY_RESULT)
)
self.ten_env.log_info(f"tool_result: {call} {tool_result}")
return ToolCallResponse(id=call.id, response=tool_result)
else:
self.ten_env.log_error("Tool call failed")
return ToolCallResponse(
id=call.id,
error=f"Tool call failed with status code {result.get_status_code()}",
)
async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
data_name = data.get_name()
ten_env.log_info(f"on_data name {data_name}")
is_final = False
input_text = ""
try:
is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
except Exception as err:
ten_env.log_info(
f"GetProperty optional {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {err}"
)
try:
input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
except Exception as err:
ten_env.log_info(
f"GetProperty optional {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {err}"
)
if not is_final:
ten_env.log_info("ignore non-final input")
return
if not input_text:
ten_env.log_info("ignore empty text")
return
ten_env.log_info(f"OnData input text: [{input_text}]")
# Start an asynchronous task for handling chat completion
message = LLMChatCompletionUserMessageParam(role="user", content=input_text)
await self.queue_input_item(False, messages=[message])
async def on_audio_frame(
self, ten_env: AsyncTenEnv, audio_frame: AudioFrame
) -> None:
pass
async def on_video_frame(
self, ten_env: AsyncTenEnv, video_frame: VideoFrame
) -> None:
pass
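    # Emit one completed sentence downstream as its own text_data segment;
    # end_of_segment is always set since callers pass whole sentences.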
async def _send_text(self, text: str) -> None:
data = Data.create("text_data")
data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text)
data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT, True)
asyncio.create_task(self.ten_env.send_data(data))
async def _stream_chat(
self, messages: List[Any], tools: List[Any]
) -> AsyncGenerator[dict, None]:
async with aiohttp.ClientSession() as session:
try:
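                # The payload follows the OpenAI chat completions shape; the
                # model name is fixed here, and ssml_enabled/context are
                # glue-specific extensions (assumed server-side contract).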
payload = {
"messages": messages,
"tools": tools,
"tools_choice": "auto" if tools else "none",
"model": "gpt-3.5-turbo",
"stream": True,
"stream_options": {"include_usage": True},
"ssml_enabled": self.config.ssml_enabled,
}
if self.config.context_enabled:
payload["context"] = {**self.config.extra_context}
self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}")
headers = {
"Authorization": f"Bearer {self.config.token}",
"Content-Type": "application/json",
}
start_time = time.time()
async with session.post(
self.config.api_url, json=payload, headers=headers
) as response:
if response.status != 200:
r = await response.json()
self.ten_env.log_error(
f"Received unexpected status {r} from the server."
)
if self.config.failure_info:
await self._send_text(self.config.failure_info)
return
end_time = time.time()
self.connect_times.append(end_time - start_time)
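                    # The body is a server-sent event stream: every "data:"
                    # line carries one JSON chunk, terminated by "data: [DONE]".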
async for line in response.content:
if line:
                            decoded = line.decode("utf-8").strip()
                            if decoded.startswith("data:"):
                                content = decoded[5:].strip()
                                if content == "[DONE]":
                                    break
                                self.ten_env.log_debug(f"content: {content}")
                                yield json.loads(content)
except Exception as e:
traceback.print_exc()
self.ten_env.log_error(f"Failed to handle {e}")
async def _update_usage(self, usage: LLMUsage) -> None:
if not self.config.rtm_enabled:
return
self.total_usage.completion_tokens += usage.completion_tokens
self.total_usage.prompt_tokens += usage.prompt_tokens
self.total_usage.total_tokens += usage.total_tokens
if self.total_usage.completion_tokens_details is None:
self.total_usage.completion_tokens_details = LLMCompletionTokensDetails()
if self.total_usage.prompt_tokens_details is None:
self.total_usage.prompt_tokens_details = LLMPromptTokensDetails()
if usage.completion_tokens_details:
self.total_usage.completion_tokens_details.accepted_prediction_tokens += (
usage.completion_tokens_details.accepted_prediction_tokens
)
self.total_usage.completion_tokens_details.audio_tokens += (
usage.completion_tokens_details.audio_tokens
)
self.total_usage.completion_tokens_details.reasoning_tokens += (
usage.completion_tokens_details.reasoning_tokens
)
self.total_usage.completion_tokens_details.rejected_prediction_tokens += (
usage.completion_tokens_details.rejected_prediction_tokens
)
if usage.prompt_tokens_details:
self.total_usage.prompt_tokens_details.audio_tokens += (
usage.prompt_tokens_details.audio_tokens
)
self.total_usage.prompt_tokens_details.cached_tokens += (
usage.prompt_tokens_details.cached_tokens
)
self.ten_env.log_info(f"total usage: {self.total_usage}")
data = Data.create("llm_stat")
data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump()))
if self.connect_times and self.completion_times and self.first_token_times:
data.set_property_from_json(
"latency",
json.dumps(
{
"connection_latency_95": np.percentile(self.connect_times, 95),
"completion_latency_95": np.percentile(
self.completion_times, 95
),
"first_token_latency_95": np.percentile(
self.first_token_times, 95
),
"connection_latency_99": np.percentile(self.connect_times, 99),
"completion_latency_99": np.percentile(
self.completion_times, 99
),
"first_token_latency_99": np.percentile(
self.first_token_times, 99
),
}
),
)
asyncio.create_task(self.ten_env.send_data(data))
async def _on_memory_appended(self, message: dict) -> None:
self.ten_env.log_info(f"Memory appended: {message}")
if not self.config.enable_storage:
return
role = message.get("role")
stream_id = self.remote_stream_id if role == "user" else 0
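        # Convention: user turns carry the remote stream id, agent turns use 0.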
try:
d = Data.create("append")
d.set_property_string("text", message.get("content"))
d.set_property_string("role", role)
d.set_property_int("stream_id", stream_id)
asyncio.create_task(self.ten_env.send_data(d))
except Exception as e:
self.ten_env.log_error(f"Error send append_context data {message} {e}")