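# Glue extension: bridges TEN text data and commands to an OpenAI-compatible
# /chat/completions endpoint, streaming sentence-sized chunks downstream
# (e.g. to TTS) and dispatching tool calls to peer extensions.
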
import asyncio
import json
import re
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, List

import aiohttp
import numpy as np
from pydantic import BaseModel

from ten import (
    AudioFrame,
    VideoFrame,
    AsyncTenEnv,
    Cmd,
    StatusCode,
    CmdResult,
    Data,
)

from ten_ai_base.config import BaseConfig
from ten_ai_base.chat_memory import (
    ChatMemory,
    EVENT_MEMORY_APPENDED,
)
from ten_ai_base.usage import (
    LLMUsage,
    LLMCompletionTokensDetails,
    LLMPromptTokensDetails,
)
from ten_ai_base.types import (
    LLMChatCompletionUserMessageParam,
    LLMToolResult,
    LLMCallCompletionArgs,
    LLMDataCompletionArgs,
    LLMToolMetadata,
)
from ten_ai_base.llm import (
    AsyncLLMBaseExtension,
)

CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"
CMD_OUT_TOOL_CALL = "tool_call"

DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"

DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment"

CMD_PROPERTY_RESULT = "tool_result"


def is_punctuation(char):
    # Treat both ASCII and full-width CJK delimiters as sentence boundaries.
    return char in [",", "，", ".", "。", "?", "？", "!", "！"]


def parse_sentences(sentence_fragment, content):
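    """Accumulate streamed text and split out complete sentences.

    Returns (sentences, remainder). Illustrative example:
    parse_sentences("Hel", "lo! How are") -> (["Hello!"], " How are")
    """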
    sentences = []
    current_sentence = sentence_fragment
    for char in content:
        current_sentence += char
        if is_punctuation(char):
            stripped_sentence = current_sentence
            if any(c.isalnum() for c in stripped_sentence):
                sentences.append(stripped_sentence)
            current_sentence = ""

    remain = current_sentence
    return sentences, remain
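

# Minimal Pydantic mirrors of the OpenAI-style streaming-chunk schema;
# only the fields this extension actually reads are modeled.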
class ToolCallFunction(BaseModel):
    name: str | None = None
    arguments: str | None = None


class ToolCall(BaseModel):
    index: int
    type: str = "function"
    id: str | None = None
    function: ToolCallFunction


class ToolCallResponse(BaseModel):
    id: str
    response: LLMToolResult | None = None  # None when the tool call failed
    error: str | None = None


class Delta(BaseModel):
    content: str | None = None
    tool_calls: List[ToolCall] | None = None


class Choice(BaseModel):
    delta: Delta | None = None
    index: int
    finish_reason: str | None


class ResponseChunk(BaseModel):
    choices: List[Choice]
    usage: LLMUsage | None = None


@dataclass
class GlueConfig(BaseConfig):
    api_url: str = "http://localhost:8000/chat/completions"
    token: str = ""  # bearer token sent in the Authorization header
    prompt: str = ""  # system prompt prepended to every request
    max_history: int = 10
    greeting: str = ""  # sent when the first user joins
    failure_info: str = ""  # sent downstream when the completion request fails
    modalities: List[str] = field(default_factory=lambda: ["text"])
    rtm_enabled: bool = True  # emit usage/latency stats as "llm_stat" data
    ssml_enabled: bool = False
    context_enabled: bool = False
    extra_context: dict = field(default_factory=dict)
    enable_storage: bool = False  # persist/restore chat history


class AsyncGlueExtension(AsyncLLMBaseExtension):
    def __init__(self, name):
        super().__init__(name)

        self.config: GlueConfig = None
        self.ten_env: AsyncTenEnv = None
        self.loop: asyncio.AbstractEventLoop = None
        self.stopped: bool = False
        self.memory: ChatMemory = None
        self.total_usage: LLMUsage = LLMUsage()
        self.users_count = 0

        # Per-request latency samples, reported as p95/p99 in _update_usage.
        self.completion_times = []
        self.connect_times = []
        self.first_token_times = []

        self.remote_stream_id: int = 999

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        await super().on_init(ten_env)
        ten_env.log_debug("on_init")

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        await super().on_start(ten_env)
        ten_env.log_debug("on_start")

        self.loop = asyncio.get_event_loop()

        self.config = await GlueConfig.create_async(ten_env=ten_env)
        ten_env.log_info(f"config: {self.config}")

        self.memory = ChatMemory(self.config.max_history)

        if self.config.enable_storage:
            # Restore any persisted history from the storage extension.
            [result, _] = await ten_env.send_cmd(Cmd.create("retrieve"))
            if result.get_status_code() == StatusCode.OK:
                try:
                    history = json.loads(result.get_property_string("response"))
                    for i in history:
                        self.memory.put(i)
                    ten_env.log_info(f"on retrieve context {history}")
                except Exception as e:
                    ten_env.log_error(f"Failed to handle retrieve result {e}")
            else:
                ten_env.log_warn("Failed to retrieve content")
        self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)

        self.ten_env = ten_env

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        await super().on_stop(ten_env)
        ten_env.log_debug("on_stop")

        self.stopped = True
        # Unblock the base class's input queue so its consumer task can exit.
        await self.queue.put(None)

    async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
        await super().on_deinit(ten_env)
        ten_env.log_debug("on_deinit")

    async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        ten_env.log_debug(f"on_cmd name {cmd_name}")

        status = StatusCode.OK
        detail = "success"

        if cmd_name == CMD_IN_FLUSH:
            await self.flush_input_items(ten_env)
            await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
            ten_env.log_info("on flush")
        elif cmd_name == CMD_IN_ON_USER_JOINED:
            self.users_count += 1
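
            # Greet only the first user to join; later joins reuse the session.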
            if self.config.greeting and self.users_count == 1:
                self.send_text_output(ten_env, self.config.greeting, True)
        elif cmd_name == CMD_IN_ON_USER_LEFT:
            self.users_count -= 1
        else:
            await super().on_cmd(ten_env, cmd)
            return

        cmd_result = CmdResult.create(status)
        cmd_result.set_property_string("detail", detail)
        await ten_env.return_result(cmd_result, cmd)

    async def on_call_chat_completion(
        self, ten_env: AsyncTenEnv, **kargs: LLMCallCompletionArgs
    ) -> Any:
        raise RuntimeError("Not implemented")

    async def on_data_chat_completion(
        self, ten_env: AsyncTenEnv, **kargs: LLMDataCompletionArgs
    ) -> None:
        input_messages: List[LLMChatCompletionUserMessageParam] = kargs.get(
            "messages", []
        )

        messages = []
        if self.config.prompt:
            messages.append({"role": "system", "content": self.config.prompt})
        history = self.memory.get()
        while history:
            if history[0].get("role") == "tool":
                history = history[1:]
                continue
            if history[0].get("role") == "assistant" and history[0].get("tool_calls"):
                history = history[1:]
                continue

            break

        messages.extend(history)

        if not input_messages:
            ten_env.log_warn("No message in data")
        else:
            messages.extend(input_messages)
            for i in input_messages:
                self.memory.put(i)
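
        # Helper: convert LLMToolMetadata into an OpenAI-style function-tool
        # schema. Parameters are emitted as one flat object; nested schemas
        # are not supported here.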
        def tool_dict(tool: LLMToolMetadata):
            json_dict = {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": {
                        "type": "object",
                        "properties": {},
                        "required": [],
                        "additionalProperties": False,
                    },
                },
                "strict": True,
            }

            for param in tool.parameters:
                json_dict["function"]["parameters"]["properties"][param.name] = {
                    "type": param.type,
                    "description": param.description,
                }
                if param.required:
                    json_dict["function"]["parameters"]["required"].append(param.name)

            return json_dict
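
        # Illustratively, a hypothetical tool "get_weather" with one required
        # string parameter "location" would become:
        #   {"type": "function",
        #    "function": {"name": "get_weather", "description": "...",
        #                 "parameters": {"type": "object",
        #                                "properties": {"location": {...}},
        #                                "required": ["location"],
        #                                "additionalProperties": False}},
        #    "strict": True}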

        def trim_xml(input_string):
            # Strip SSML/XML tags, keeping only the spoken text.
            return re.sub(r"<[^>]+>", "", input_string).strip()

        tools = []
        for tool in self.available_tools:
            tools.append(tool_dict(tool))

        total_output = ""
        sentence_fragment = ""
        calls = {}

        sentences = []
        start_time = time.time()
        first_token_time = None
        response = self._stream_chat(messages=messages, tools=tools)
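
        # Consume the stream: text deltas are split into sentences and sent
        # downstream as soon as they complete; tool-call deltas arrive in
        # fragments and are accumulated by index until the stream ends.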
        async for message in response:
            self.ten_env.log_debug(f"content: {message}")
            try:
                c = ResponseChunk(**message)
                if c.choices:
                    if c.choices[0].delta.content:
                        if first_token_time is None:
                            first_token_time = time.time()
                            self.first_token_times.append(
                                first_token_time - start_time
                            )

                        content = c.choices[0].delta.content
                        if self.config.ssml_enabled and content.startswith(
                            "<speak>"
                        ):
                            content = trim_xml(content)
                        total_output += content
                        sentences, sentence_fragment = parse_sentences(
                            sentence_fragment, content
                        )
                        for s in sentences:
                            await self._send_text(s)
                    if c.choices[0].delta.tool_calls:
                        self.ten_env.log_info(
                            f"tool_calls: {c.choices[0].delta.tool_calls}"
                        )
                        for call in c.choices[0].delta.tool_calls:
                            if call.index not in calls:
                                calls[call.index] = ToolCall(
                                    id=call.id,
                                    index=call.index,
                                    function=ToolCallFunction(
                                        name="", arguments=""
                                    ),
                                )
                            if call.function.name:
                                calls[call.index].function.name += (
                                    call.function.name
                                )
                            if call.function.arguments:
                                calls[call.index].function.arguments += (
                                    call.function.arguments
                                )
                if c.usage:
                    self.ten_env.log_info(f"usage: {c.usage}")
                    await self._update_usage(c.usage)
            except Exception as e:
                self.ten_env.log_error(f"Failed to parse response: {message} {e}")
                traceback.print_exc()

        # Flush any trailing fragment that never reached a punctuation mark.
        if sentence_fragment:
            await self._send_text(sentence_fragment)
        end_time = time.time()
        self.completion_times.append(end_time - start_time)

        if total_output:
            self.memory.put({"role": "assistant", "content": total_output})
        if calls:
            tasks = []
            tool_calls = []
            for _, call in calls.items():
                self.ten_env.log_info(f"tool call: {call}")
                tool_calls.append(call.model_dump())
                tasks.append(self.handle_tool_call(call))
            self.memory.put({"role": "assistant", "tool_calls": tool_calls})
            responses = await asyncio.gather(*tasks)
            for r in responses:
                if r.response is None:
                    self.ten_env.log_error(f"tool call {r.id} failed: {r.error}")
                    continue
                content = r.response["content"]
                self.ten_env.log_info(f"tool call response: {content} {r.id}")
                self.memory.put(
                    {
                        "role": "tool",
                        "content": json.dumps(content),
                        "tool_call_id": r.id,
                    }
                )

            # Run a follow-up completion so the model can answer using the
            # tool results just written into memory; the recursion ends once
            # a pass produces no further tool calls.
            await self.on_data_chat_completion(ten_env)

        self.ten_env.log_info(f"total_output: {total_output} {calls}")

    async def on_tools_update(
        self, ten_env: AsyncTenEnv, tool: LLMToolMetadata
    ) -> None:
        return await super().on_tools_update(ten_env, tool)

    async def handle_tool_call(self, call: ToolCall) -> ToolCallResponse:
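        """Dispatch one tool call as a `tool_call` command and await the
        result from whichever extension registered the tool."""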
        cmd: Cmd = Cmd.create(CMD_OUT_TOOL_CALL)
        cmd.set_property_string("name", call.function.name)
        cmd.set_property_from_json("arguments", call.function.arguments)

        [result, _] = await self.ten_env.send_cmd(cmd)
        if result.get_status_code() == StatusCode.OK:
            tool_result: LLMToolResult = json.loads(
                result.get_property_to_json(CMD_PROPERTY_RESULT)
            )

            self.ten_env.log_info(f"tool_result: {call} {tool_result}")
            return ToolCallResponse(id=call.id, response=tool_result)
        else:
            self.ten_env.log_error("Tool call failed")
            return ToolCallResponse(
                id=call.id,
                error=f"Tool call failed with status code {result.get_status_code()}",
            )

    async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
        data_name = data.get_name()
        ten_env.log_info(f"on_data name {data_name}")

        is_final = False
        input_text = ""
        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
        except Exception as err:
            ten_env.log_info(
                f"GetProperty optional {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {err}"
            )

        try:
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
        except Exception as err:
            ten_env.log_info(
                f"GetProperty optional {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {err}"
            )

        if not is_final:
            ten_env.log_info("ignore non-final input")
            return
        if not input_text:
            ten_env.log_info("ignore empty text")
            return

        ten_env.log_info(f"OnData input text: [{input_text}]")

        message = LLMChatCompletionUserMessageParam(role="user", content=input_text)
        await self.queue_input_item(False, messages=[message])

    async def on_audio_frame(
        self, ten_env: AsyncTenEnv, audio_frame: AudioFrame
    ) -> None:
        pass

    async def on_video_frame(
        self, ten_env: AsyncTenEnv, video_frame: VideoFrame
    ) -> None:
        pass

    async def _send_text(self, text: str) -> None:
        data = Data.create("text_data")
        data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text)
        # Each sentence is sent as a complete segment for downstream TTS.
        data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT, True)
        asyncio.create_task(self.ten_env.send_data(data))

    async def _stream_chat(
        self, messages: List[Any], tools: List[Any]
    ) -> AsyncGenerator[dict, None]:
        async with aiohttp.ClientSession() as session:
            try:
                payload = {
                    "messages": messages,
                    "tools": tools,
                    "tool_choice": "auto" if tools else "none",
                    "model": "gpt-3.5-turbo",
                    "stream": True,
                    "stream_options": {"include_usage": True},
                    "ssml_enabled": self.config.ssml_enabled,
                }
                if self.config.context_enabled:
                    payload["context"] = {**self.config.extra_context}
                self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}")
                headers = {
                    "Authorization": f"Bearer {self.config.token}",
                    "Content-Type": "application/json",
                }

                start_time = time.time()
                async with session.post(
                    self.config.api_url, json=payload, headers=headers
                ) as response:
                    if response.status != 200:
                        r = await response.json()
                        self.ten_env.log_error(
                            f"Received unexpected status {r} from the server."
                        )
                        if self.config.failure_info:
                            await self._send_text(self.config.failure_info)
                        return
                    end_time = time.time()
                    self.connect_times.append(end_time - start_time)
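
                    # The body is an SSE stream: JSON chunks arrive on
                    # "data:" lines, terminated by a literal "[DONE]".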
                    async for line in response.content:
                        if line:
                            l = line.decode("utf-8").strip()
                            if l.startswith("data:"):
                                content = l[5:].strip()
                                if content == "[DONE]":
                                    break
                                self.ten_env.log_debug(f"content: {content}")
                                yield json.loads(content)
            except Exception as e:
                traceback.print_exc()
                self.ten_env.log_error(f"Failed to handle {e}")

    async def _update_usage(self, usage: LLMUsage) -> None:
        if not self.config.rtm_enabled:
            return

        self.total_usage.completion_tokens += usage.completion_tokens
        self.total_usage.prompt_tokens += usage.prompt_tokens
        self.total_usage.total_tokens += usage.total_tokens

        if self.total_usage.completion_tokens_details is None:
            self.total_usage.completion_tokens_details = LLMCompletionTokensDetails()
        if self.total_usage.prompt_tokens_details is None:
            self.total_usage.prompt_tokens_details = LLMPromptTokensDetails()

        if usage.completion_tokens_details:
            self.total_usage.completion_tokens_details.accepted_prediction_tokens += (
                usage.completion_tokens_details.accepted_prediction_tokens
            )
            self.total_usage.completion_tokens_details.audio_tokens += (
                usage.completion_tokens_details.audio_tokens
            )
            self.total_usage.completion_tokens_details.reasoning_tokens += (
                usage.completion_tokens_details.reasoning_tokens
            )
            self.total_usage.completion_tokens_details.rejected_prediction_tokens += (
                usage.completion_tokens_details.rejected_prediction_tokens
            )

        if usage.prompt_tokens_details:
            self.total_usage.prompt_tokens_details.audio_tokens += (
                usage.prompt_tokens_details.audio_tokens
            )
            self.total_usage.prompt_tokens_details.cached_tokens += (
                usage.prompt_tokens_details.cached_tokens
            )

        self.ten_env.log_info(f"total usage: {self.total_usage}")
        data = Data.create("llm_stat")
        data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump()))
        if self.connect_times and self.completion_times and self.first_token_times:
            data.set_property_from_json(
                "latency",
                json.dumps(
                    {
                        "connection_latency_95": np.percentile(self.connect_times, 95),
                        "completion_latency_95": np.percentile(
                            self.completion_times, 95
                        ),
                        "first_token_latency_95": np.percentile(
                            self.first_token_times, 95
                        ),
                        "connection_latency_99": np.percentile(self.connect_times, 99),
                        "completion_latency_99": np.percentile(
                            self.completion_times, 99
                        ),
                        "first_token_latency_99": np.percentile(
                            self.first_token_times, 99
                        ),
                    }
                ),
            )
        asyncio.create_task(self.ten_env.send_data(data))

    async def _on_memory_appended(self, message: dict) -> None:
        self.ten_env.log_info(f"Memory appended: {message}")
        if not self.config.enable_storage:
            return

        # Attribute user messages to the remote stream; assistant output
        # uses stream id 0.
        role = message.get("role")
        stream_id = self.remote_stream_id if role == "user" else 0
        try:
            d = Data.create("append")
            d.set_property_string("text", message.get("content"))
            d.set_property_string("role", role)
            d.set_property_int("stream_id", stream_id)
            asyncio.create_task(self.ten_env.send_data(d))
        except Exception as e:
            self.ten_env.log_error(f"Error send append_context data {message} {e}")