import asyncio
import base64
import io
import json
import time
import traceback
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Iterable

import numpy as np
from pydub import AudioSegment

from ten import (
    AudioFrame,
    AsyncTenEnv,
    Cmd,
    StatusCode,
    CmdResult,
    Data,
)
from ten.audio_frame import AudioFrameDataFmt
from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_TOOL_CALL
from ten_ai_base.config import BaseConfig
from ten_ai_base.chat_memory import (
    ChatMemory,
    EVENT_MEMORY_EXPIRED,
    EVENT_MEMORY_APPENDED,
)
from ten_ai_base.usage import (
    LLMUsage,
    LLMCompletionTokensDetails,
    LLMPromptTokensDetails,
)
from ten_ai_base.types import (
    LLMToolMetadata,
    LLMToolResult,
    LLMChatCompletionContentPartParam,
)
from ten_ai_base.llm import AsyncLLMBaseExtension

from .realtime.connection import RealtimeApiConnection
from .realtime.struct import (
    AudioFormats,
    ItemCreate,
    SessionCreated,
    ItemCreated,
    UserMessageItemParam,
    AssistantMessageItemParam,
    ItemInputAudioTranscriptionCompleted,
    ItemInputAudioTranscriptionFailed,
    ResponseCreated,
    ResponseDone,
    ResponseAudioTranscriptDelta,
    ResponseTextDelta,
    ResponseAudioTranscriptDone,
    ResponseTextDone,
    ResponseOutputItemDone,
    ResponseOutputItemAdded,
    ResponseAudioDelta,
    ResponseAudioDone,
    InputAudioBufferSpeechStarted,
    InputAudioBufferSpeechStopped,
    ResponseFunctionCallArgumentsDone,
    ErrorMessage,
    ItemDelete,
    SessionUpdate,
    SessionUpdateParams,
    InputAudioTranscription,
    ContentType,
    FunctionCallOutputItemParam,
    ResponseCreate,
)


CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"


class Role(str, Enum):
    User = "user"
    Assistant = "assistant"


@dataclass
class GLMRealtimeConfig(BaseConfig):
    base_uri: str = "wss://open.bigmodel.cn"
    api_key: str = ""
    path: str = "/api/paas/v4/realtime"
    prompt: str = ""
    temperature: float = 0.5
    max_tokens: int = 1024
    server_vad: bool = True
    audio_out: bool = True
    input_transcript: bool = True
    sample_rate: int = 24000

    stream_id: int = 0
    dump: bool = False
    max_history: int = 20
    enable_storage: bool = False
    greeting: str = ""
    language: str = "en-US"

    def build_ctx(self) -> dict:
        # Context tokens substituted into the prompt by _replace();
        # entries such as "tools" are filled in at runtime.
        return {}
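

# An illustrative (hypothetical) property.json fragment for this extension;
# key names mirror the GLMRealtimeConfig fields above:
# {
#   "api_key": "<your zhipu api key>",
#   "prompt": "You are a voice assistant...",
#   "server_vad": true,
#   "language": "en-US"
# }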


class GLMRealtimeExtension(AsyncLLMBaseExtension):
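    """TEN extension that bridges audio frames to the GLM realtime API.

    It forwards user audio to the realtime connection, relays transcripts
    and synthesized audio back into the graph, keeps a bounded chat memory,
    and dispatches tool calls raised by the model.
    """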

    def __init__(self, name: str):
        super().__init__(name)
        self.ten_env: AsyncTenEnv = None
        self.conn = None
        self.session = None
        self.session_id = None

        self.config: GLMRealtimeConfig = None
        self.stopped: bool = False
        self.connected: bool = False
        self.buffer: bytearray = bytearray()
        self.memory: ChatMemory = None
        self.total_usage: LLMUsage = LLMUsage()
        self.users_count = 0

        self.stream_id: int = 0
        self.remote_stream_id: int = 0
        self.channel_name: str = ""
        self.audio_len_threshold: int = 5120

        # Latency bookkeeping, reported by _update_usage().
        self.completion_times = []
        self.connect_times = []
        self.first_token_times = []

        self.transcript: str = ""
        self.ctx: dict = {}
        self.input_end = time.time()
        self.input_audio_queue = asyncio.Queue()

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        await super().on_init(ten_env)
        ten_env.log_debug("on_init")

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        await super().on_start(ten_env)
        ten_env.log_debug("on_start")
        self.ten_env = ten_env

        self.loop = asyncio.get_event_loop()
        self.loop.create_task(self._on_process_audio())

        self.config = await GLMRealtimeConfig.create_async(ten_env=ten_env)
        ten_env.log_info(f"config: {self.config}")

        if not self.config.api_key:
            ten_env.log_error("api_key is required")
            return

        try:
            self.memory = ChatMemory(self.config.max_history)

            if self.config.enable_storage:
                [result, _] = await ten_env.send_cmd(Cmd.create("retrieve"))
                if result.get_status_code() == StatusCode.OK:
                    try:
                        history = json.loads(result.get_property_string("response"))
                        for i in history:
                            self.memory.put(i)
                        ten_env.log_info(f"on retrieve context {history}")
                    except Exception as e:
                        ten_env.log_error(f"Failed to handle retrieve result {e}")
                else:
                    ten_env.log_warn("Failed to retrieve content")

            self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired)
            self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)

            self.ctx = self.config.build_ctx()

            self.conn = RealtimeApiConnection(
                ten_env=ten_env,
                base_uri=self.config.base_uri,
                path=self.config.path,
                api_key=self.config.api_key,
            )
            ten_env.log_info("Finish init client")

            self.loop.create_task(self._loop())
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"Failed to init client {e}")

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        await super().on_stop(ten_env)
        ten_env.log_info("on_stop")

        # A None sentinel unblocks _on_process_audio() so it can exit.
        self.input_audio_queue.put_nowait(None)
        self.stopped = True

    async def on_audio_frame(self, _: AsyncTenEnv, audio_frame: AudioFrame) -> None:
        try:
            stream_id = audio_frame.get_property_int("stream_id")
            if self.channel_name == "":
                self.channel_name = audio_frame.get_property_string("channel")

            if self.remote_stream_id == 0:
                self.remote_stream_id = stream_id

            frame_buf = audio_frame.get_buf()
            self.input_audio_queue.put_nowait(frame_buf)

            if not self.config.server_vad:
                self.input_end = time.time()
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"GLMRealtimeExtension on audio frame failed {e}")

    async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        ten_env.log_debug("on_cmd name {}".format(cmd_name))

        status = StatusCode.OK
        detail = "success"

        if cmd_name == CMD_IN_FLUSH:
            await self._flush()
            await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
            ten_env.log_info("on flush")
        elif cmd_name == CMD_IN_ON_USER_JOINED:
            self.users_count += 1
            # Greet when the first user joins.
            if self.users_count == 1:
                await self._greeting()
        elif cmd_name == CMD_IN_ON_USER_LEFT:
            self.users_count -= 1
        else:
            # Other commands (e.g. tool registration) go to the base class.
            await super().on_cmd(ten_env, cmd)
            return

        cmd_result = CmdResult.create(status)
        cmd_result.set_property_string("detail", detail)
        await ten_env.return_result(cmd_result, cmd)

    async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
        pass

    async def _on_process_audio(self) -> None:
        while True:
            try:
                audio_frame = await self.input_audio_queue.get()

                # None is the shutdown sentinel pushed by on_stop().
                if audio_frame is None:
                    break

                self._dump_audio_if_need(audio_frame, Role.User)
                if self.connected:
                    wav_buff = self.convert_to_wav_in_memory(audio_frame)
                    await self.conn.send_audio_data(wav_buff)
            except Exception as e:
                traceback.print_exc()
                self.ten_env.log_error(f"Error processing audio frame {e}")

    async def _loop(self):
        def get_time_ms() -> int:
            # Wall-clock time in whole milliseconds.
            return int(datetime.now().timestamp() * 1000)

        try:
            start_time = time.time()
            await self.conn.connect()
            self.connect_times.append(time.time() - start_time)
            item_id = ""
            response_id = ""

            relative_start_ms = get_time_ms()
            flushed = set()

            self.ten_env.log_info("Client loop started")
            async for message in self.conn.listen():
                try:
                    match message:
                        case SessionCreated():
                            self.ten_env.log_info(
                                f"Session is created: {message.session}"
                            )
                            self.session_id = message.session.id
                            self.session = message.session
                            await self._update_session()

                            # Replay buffered history into the new session.
                            history = self.memory.get()
                            for h in history:
                                if h["role"] == "user":
                                    await self.conn.send_request(
                                        ItemCreate(
                                            item=UserMessageItemParam(
                                                content=[
                                                    {
                                                        "type": ContentType.InputText,
                                                        "text": h["content"],
                                                    }
                                                ]
                                            )
                                        )
                                    )
                                elif h["role"] == "assistant":
                                    await self.conn.send_request(
                                        ItemCreate(
                                            item=AssistantMessageItemParam(
                                                content=[
                                                    {
                                                        "type": ContentType.InputText,
                                                        "text": h["content"],
                                                    }
                                                ]
                                            )
                                        )
                                    )
                            self.ten_env.log_info(f"Finish send history {history}")
                            self.memory.clear()

                            if not self.connected:
                                self.connected = True
                                await self._greeting()
                        case ItemInputAudioTranscriptionCompleted():
                            self.ten_env.log_info(
                                f"On request transcript {message.transcript}"
                            )
                            self._send_transcript(message.transcript, Role.User, True)
                            self.memory.put(
                                {
                                    "role": "user",
                                    "content": message.transcript,
                                }
                            )
                        case ItemInputAudioTranscriptionFailed():
                            self.ten_env.log_warn(
                                f"On request transcript failed {message.item_id} {message.error}"
                            )
                        case ItemCreated():
                            self.ten_env.log_info(f"On item created {message.item}")
                        case ResponseCreated():
                            response_id = message.response.id
                            self.ten_env.log_info(f"On response created {response_id}")
                        case ResponseDone():
                            msg_resp_id = message.response.id
                            status = message.response.status
                            if msg_resp_id == response_id:
                                response_id = ""
                            self.ten_env.log_info(
                                f"On response done {msg_resp_id} {status} {message.response.usage}"
                            )

                            self.transcript = ""
                            self._send_transcript("", Role.Assistant, True)

                            if message.response.usage:
                                # Token accounting could be forwarded to
                                # _update_usage() here.
                                pass
                        case ResponseAudioTranscriptDelta():
                            self.ten_env.log_info(
                                f"On response transcript delta {message.output_index} {message.content_index} {message.delta}"
                            )
                            if message.response_id in flushed:
                                self.ten_env.log_warn(
                                    f"On flushed transcript delta {message.output_index} {message.content_index} {message.delta}"
                                )
                                continue
                            self._send_transcript(message.delta, Role.Assistant, False)
                        case ResponseTextDelta():
                            self.ten_env.log_info(
                                f"On response text delta {message.output_index} {message.content_index} {message.delta}"
                            )
                            self._send_transcript(message.delta, Role.Assistant, False)
                        case ResponseAudioTranscriptDone():
                            self.ten_env.log_info(
                                f"On response transcript done {message.output_index} {message.content_index} {message.transcript}"
                            )
                            if message.response_id in flushed:
                                self.ten_env.log_warn("On flushed transcript done")
                                continue
                            self.memory.put(
                                {
                                    "role": "assistant",
                                    "content": message.transcript,
                                }
                            )
                            self.transcript = ""
                            self._send_transcript("", Role.Assistant, True)
                        case ResponseTextDone():
                            self.ten_env.log_info(
                                f"On response text done {message.output_index} {message.content_index} {message.text}"
                            )
                            self.completion_times.append(time.time() - self.input_end)
                            self.transcript = ""
                            self._send_transcript("", Role.Assistant, True)
                        case ResponseOutputItemDone():
                            self.ten_env.log_info(f"Output item done {message.item}")
                        case ResponseOutputItemAdded():
                            self.ten_env.log_info(
                                f"Output item added {message.output_index} {message.item}"
                            )
                        case ResponseAudioDelta():
                            await self._on_audio_delta(message.delta)
                        case ResponseAudioDone():
                            self.completion_times.append(time.time() - self.input_end)
                        case InputAudioBufferSpeechStarted():
                            self.ten_env.log_info(
                                f"On server listening, in response {response_id}, last item {item_id}"
                            )
                            # Interrupt any in-flight response when the user
                            # starts speaking.
                            if self.config.server_vad:
                                await self._flush()
                            if response_id and self.transcript:
                                transcript = self.transcript + "[interrupted]"
                                self._send_transcript(transcript, Role.Assistant, True)
                                self.transcript = ""
                            flushed.add(response_id)
                            item_id = ""
                        case InputAudioBufferSpeechStopped():
                            self.input_end = time.time()
                            relative_start_ms = get_time_ms() - message.audio_end_ms
                            self.ten_env.log_info(
                                f"On server stop listening, {message.audio_end_ms}, relative {relative_start_ms}"
                            )
                        case ResponseFunctionCallArgumentsDone():
                            name = message.name
                            arguments = message.arguments
                            self.ten_env.log_info(f"need to call func {name}")
                            # Tool calls may block; run them off the listen loop.
                            self.loop.create_task(
                                self._handle_tool_call(name, arguments)
                            )
                        case ErrorMessage():
                            self.ten_env.log_error(
                                f"Error message received: {message.error}"
                            )
                        case _:
                            self.ten_env.log_debug(f"Not handled message {message}")
                except Exception as e:
                    traceback.print_exc()
                    self.ten_env.log_error(f"Error processing message: {message} {e}")

            self.ten_env.log_info("Client loop finished")
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"Failed to handle loop {e}")

        self.connected = False
        self.remote_stream_id = 0

        # Reconnect unless the extension is shutting down.
        if not self.stopped:
            await self.conn.close()
            await asyncio.sleep(0.5)
            self.ten_env.log_info("Reconnect")

            self.conn = RealtimeApiConnection(
                ten_env=self.ten_env,
                base_uri=self.config.base_uri,
                path=self.config.path,
                api_key=self.config.api_key,
            )

            self.loop.create_task(self._loop())

    async def _on_memory_expired(self, message: dict) -> None:
        self.ten_env.log_info(f"Memory expired: {message}")
        item_id = message.get("item_id")
        if item_id:
            await self.conn.send_request(ItemDelete(item_id=item_id))

    async def _on_memory_appended(self, message: dict) -> None:
        self.ten_env.log_info(f"Memory appended: {message}")
        if not self.config.enable_storage:
            return

        role = message.get("role")
        stream_id = self.remote_stream_id if role == Role.User else 0
        try:
            d = Data.create("append")
            d.set_property_string("text", message.get("content"))
            d.set_property_string("role", role)
            d.set_property_int("stream_id", stream_id)
            asyncio.create_task(self.ten_env.send_data(d))
        except Exception as e:
            self.ten_env.log_error(f"Error send append_context data {message} {e}")

    def convert_to_wav_in_memory(self, buff: bytearray) -> bytes:
        """
        Converts the accumulated PCM data to WAV format in-memory.
        Returns the WAV data as bytes.
        """
        pcm_data = np.frombuffer(buff, dtype=np.int16)

        # 24 kHz, 16-bit, mono: matches the AudioFormats.WAV24 input
        # format negotiated in _update_session().
        audio_segment = AudioSegment(
            pcm_data.tobytes(),
            frame_rate=24000,
            sample_width=2,
            channels=1,
        )

        memory_stream = io.BytesIO()
        audio_segment.export(memory_stream, format="wav")
        wav_bytes = memory_stream.getvalue()
        return wav_bytes
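
    # A stdlib-only sketch of the same in-memory PCM-to-WAV conversion, kept
    # for reference only (the extension uses pydub above). The parameters
    # mirror convert_to_wav_in_memory(); the method name is hypothetical.
    @staticmethod
    def _pcm_to_wav_stdlib_sketch(pcm: bytes, sample_rate: int = 24000) -> bytes:
        import wave

        stream = io.BytesIO()
        with wave.open(stream, "wb") as wf:
            wf.setnchannels(1)  # mono
            wf.setsampwidth(2)  # 16-bit samples
            wf.setframerate(sample_rate)
            wf.writeframes(pcm)
        return stream.getvalue()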

    async def _update_session(self) -> None:
        tools = []

        def tool_dict(tool: LLMToolMetadata):
            t = {
                "type": "function",
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                    "additionalProperties": False,
                },
            }

            for param in tool.parameters:
                t["parameters"]["properties"][param.name] = {
                    "type": param.type,
                    "description": param.description,
                }
                if param.required:
                    t["parameters"]["required"].append(param.name)

            return t

        if self.available_tools:
            tool_prompt = "You have several tools that you can get help from:\n"
            for t in self.available_tools:
                tool_prompt += f"- ***{t.name}***: {t.description}\n"
            self.ctx["tools"] = tool_prompt
            tools = [tool_dict(t) for t in self.available_tools]
        prompt = self._replace(self.config.prompt)

        self.ten_env.log_info(f"update session {prompt} {tools}")
        su = SessionUpdate(
            session=SessionUpdateParams(
                instructions=prompt,
                input_audio_format=AudioFormats.WAV24,
                output_audio_format=AudioFormats.PCM,
                tools=tools,
            )
        )
        if self.config.audio_out:
            # Audio output is the session default; nothing to change.
            pass
        else:
            su.session.modalities = ["text"]

        if self.config.input_transcript:
            su.session.input_audio_transcription = InputAudioTranscription(
                model="whisper-1"
            )
        await self.conn.send_request(su)

    async def on_tools_update(self, _: AsyncTenEnv, tool: LLMToolMetadata) -> None:
        """Called when a new tool is registered. Implement this method to process the new tool."""
        self.ten_env.log_info(f"on tools update {tool}")

    def _replace(self, prompt: str) -> str:
        result = prompt
        for token, value in self.ctx.items():
            result = result.replace("{" + token + "}", value)
        return result
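
    # Example: with ctx == {"tools": "- ***get_weather***: ..."}, a prompt of
    # "You may use these tools:\n{tools}" becomes
    # "You may use these tools:\n- ***get_weather***: ...".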

    async def _on_audio_delta(self, delta: bytes) -> None:
        audio_data = base64.b64decode(delta)
        self.ten_env.log_debug(
            f"on_audio_delta audio_data len {len(audio_data)} samples {len(audio_data) // 2}"
        )
        self._dump_audio_if_need(audio_data, Role.Assistant)

        # Wrap the raw PCM into a ten AudioFrame (16-bit mono).
        f = AudioFrame.create("pcm_frame")
        f.set_sample_rate(self.config.sample_rate)
        f.set_bytes_per_sample(2)
        f.set_number_of_channels(1)
        f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE)
        f.set_samples_per_channel(len(audio_data) // 2)
        f.alloc_buf(len(audio_data))
        buff = f.lock_buf()
        buff[:] = audio_data
        f.unlock_buf(buff)
        await self.ten_env.send_audio_frame(f)

    def _send_transcript(self, content: str, role: Role, is_final: bool) -> None:
        def is_punctuation(char):
            if char in [",", "，", ".", "。", "?", "？", "!", "！"]:
                return True
            return False

        def parse_sentences(sentence_fragment, content):
            sentences = []
            current_sentence = sentence_fragment
            for char in content:
                current_sentence += char
                if is_punctuation(char):
                    # Keep only fragments that contain real content.
                    stripped_sentence = current_sentence
                    if any(c.isalnum() for c in stripped_sentence):
                        sentences.append(stripped_sentence)
                    current_sentence = ""

            remain = current_sentence
            return sentences, remain

        def send_data(
            ten_env: AsyncTenEnv,
            sentence: str,
            stream_id: int,
            role: str,
            is_final: bool,
        ):
            try:
                d = Data.create("text_data")
                d.set_property_string("text", sentence)
                d.set_property_bool("end_of_segment", is_final)
                d.set_property_string("role", role)
                d.set_property_int("stream_id", stream_id)
                ten_env.log_info(
                    f"send transcript text [{sentence}] stream_id {stream_id} is_final {is_final} end_of_segment {is_final} role {role}"
                )
                asyncio.create_task(ten_env.send_data(d))
            except Exception as e:
                ten_env.log_error(
                    f"Error send text data {role}: {sentence} {is_final} {e}"
                )

        stream_id = self.remote_stream_id if role == Role.User else 0
        try:
            if role == Role.Assistant and not is_final:
                # Stream assistant text sentence by sentence; keep the
                # unterminated remainder in self.transcript.
                sentences, self.transcript = parse_sentences(self.transcript, content)
                for s in sentences:
                    send_data(self.ten_env, s, stream_id, role, is_final)
            else:
                send_data(self.ten_env, content, stream_id, role, is_final)
        except Exception as e:
            self.ten_env.log_error(
                f"Error send text data {role}: {content} {is_final} {e}"
            )

    def _dump_audio_if_need(self, buf: bytearray, role: Role) -> None:
        if not self.config.dump:
            return

        with open("{}_{}.pcm".format(role, self.channel_name), "ab") as dump_file:
            dump_file.write(buf)

    async def _handle_tool_call(self, name: str, arguments: str) -> None:
        self.ten_env.log_info(f"_handle_tool_call {name} {arguments}")
        cmd: Cmd = Cmd.create(CMD_TOOL_CALL)
        cmd.set_property_string("name", name)
        cmd.set_property_from_json("arguments", arguments)
        [result, _] = await self.ten_env.send_cmd(cmd)

        # Default to a failure payload; overwritten on success below.
        tool_response = ItemCreate(
            item=FunctionCallOutputItemParam(
                output='{"success":false}',
            )
        )
        if result.get_status_code() == StatusCode.OK:
            tool_result: LLMToolResult = json.loads(
                result.get_property_to_json(CMD_PROPERTY_RESULT)
            )

            result_content = tool_result["content"]
            tool_response.item.output = json.dumps(
                self._convert_to_content_parts(result_content)
            )
            self.ten_env.log_info(f"tool_result: {tool_result}")
        else:
            self.ten_env.log_error("Tool call failed")

        await self.conn.send_request(tool_response)
        await self.conn.send_request(ResponseCreate())
        self.ten_env.log_info(f"_handle_tool_call finish {name} {arguments}")

    def _greeting_text(self) -> str:
        text = "Hi, there."
        if self.config.language == "zh-CN":
            text = "你好。"
        elif self.config.language == "ja-JP":
            text = "こんにちは"
        elif self.config.language == "ko-KR":
            text = "안녕하세요"
        return text

    def _convert_tool_params_to_dict(self, tool: LLMToolMetadata):
        json_dict = {"type": "object", "properties": {}, "required": []}

        for param in tool.parameters:
            json_dict["properties"][param.name] = {
                "type": param.type,
                "description": param.description,
            }
            if param.required:
                json_dict["required"].append(param.name)

        return json_dict

    def _convert_to_content_parts(
        self, content: Iterable[LLMChatCompletionContentPartParam]
    ):
        content_parts = []

        if isinstance(content, str):
            content_parts.append({"type": "text", "text": content})
        else:
            # Only text parts are forwarded; other part types are dropped.
            for part in content:
                if part["type"] == "text":
                    content_parts.append(part)
        return content_parts
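
    # Example: _convert_to_content_parts("hi") -> [{"type": "text", "text": "hi"}],
    # while a mixed part list such as [{"type": "text", ...}, {"type": "image", ...}]
    # is filtered down to just its text parts.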

    async def _greeting(self) -> None:
        if self.connected and self.users_count == 1:
            text = self._greeting_text()
            if self.config.greeting:
                text = "Say '" + self.config.greeting + "' to me."
            self.ten_env.log_info(f"send greeting {text}")
            # Ask the model to speak the greeting.
            await self.conn.send_request(
                ItemCreate(
                    item=UserMessageItemParam(
                        content=[{"type": ContentType.InputText, "text": text}]
                    )
                )
            )
            await self.conn.send_request(ResponseCreate())

    async def _flush(self) -> None:
        try:
            c = Cmd.create("flush")
            await self.ten_env.send_cmd(c)
        except Exception as e:
            self.ten_env.log_error(f"Error flush {e}")

    async def _update_usage(self, usage: dict) -> None:
        self.total_usage.completion_tokens += usage.get("output_tokens") or 0
        self.total_usage.prompt_tokens += usage.get("input_tokens") or 0
        self.total_usage.total_tokens += usage.get("total_tokens") or 0
        if not self.total_usage.completion_tokens_details:
            self.total_usage.completion_tokens_details = LLMCompletionTokensDetails()
        if not self.total_usage.prompt_tokens_details:
            self.total_usage.prompt_tokens_details = LLMPromptTokensDetails()

        if usage.get("output_token_details"):
            self.total_usage.completion_tokens_details.accepted_prediction_tokens += (
                usage["output_token_details"].get("text_tokens") or 0
            )
            self.total_usage.completion_tokens_details.audio_tokens += (
                usage["output_token_details"].get("audio_tokens") or 0
            )

        if usage.get("input_token_details"):
            self.total_usage.prompt_tokens_details.audio_tokens += (
                usage["input_token_details"].get("audio_tokens") or 0
            )
            self.total_usage.prompt_tokens_details.cached_tokens += (
                usage["input_token_details"].get("cached_tokens") or 0
            )
            self.total_usage.prompt_tokens_details.text_tokens += (
                usage["input_token_details"].get("text_tokens") or 0
            )

        self.ten_env.log_info(f"total usage: {self.total_usage}")

        data = Data.create("llm_stat")
        data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump()))
        if self.connect_times and self.completion_times and self.first_token_times:
            data.set_property_from_json(
                "latency",
                json.dumps(
                    {
                        "connection_latency_95": np.percentile(self.connect_times, 95),
                        "completion_latency_95": np.percentile(
                            self.completion_times, 95
                        ),
                        "first_token_latency_95": np.percentile(
                            self.first_token_times, 95
                        ),
                        "connection_latency_99": np.percentile(self.connect_times, 99),
                        "completion_latency_99": np.percentile(
                            self.completion_times, 99
                        ),
                        "first_token_latency_99": np.percentile(
                            self.first_token_times, 99
                        ),
                    }
                ),
            )
        asyncio.create_task(self.ten_env.send_data(data))

    async def on_call_chat_completion(self, async_ten_env, **kwargs):
        raise NotImplementedError

    async def on_data_chat_completion(self, async_ten_env, **kwargs):
        raise NotImplementedError
|