#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import asyncio
import base64
import json
from enum import Enum
import traceback
import time
import numpy as np
from datetime import datetime
from typing import Iterable

from ten import (
    AudioFrame,
    AsyncTenEnv,
    Cmd,
    StatusCode,
    CmdResult,
    Data,
)
from ten.audio_frame import AudioFrameDataFmt
from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_TOOL_CALL
from dataclasses import dataclass
from ten_ai_base.config import BaseConfig
from ten_ai_base.chat_memory import (
    ChatMemory,
    EVENT_MEMORY_EXPIRED,
    EVENT_MEMORY_APPENDED,
)
from ten_ai_base.usage import (
    LLMUsage,
    LLMCompletionTokensDetails,
    LLMPromptTokensDetails,
)
from ten_ai_base.types import (
    LLMToolMetadata,
    LLMToolResult,
    LLMChatCompletionContentPartParam,
)
from ten_ai_base.llm import AsyncLLMBaseExtension
from .realtime.connection import RealtimeApiConnection
from .realtime.struct import (
    ItemCreate,
    SessionCreated,
    ItemCreated,
    UserMessageItemParam,
    AssistantMessageItemParam,
    ItemInputAudioTranscriptionCompleted,
    ItemInputAudioTranscriptionFailed,
    ResponseCreated,
    ResponseDone,
    ResponseAudioTranscriptDelta,
    ResponseTextDelta,
    ResponseAudioTranscriptDone,
    ResponseTextDone,
    ResponseOutputItemDone,
    ResponseOutputItemAdded,
    ResponseAudioDelta,
    ResponseAudioDone,
    InputAudioBufferSpeechStarted,
    InputAudioBufferSpeechStopped,
    ResponseFunctionCallArgumentsDone,
    ErrorMessage,
    ItemDelete,
    ItemTruncate,
    SessionUpdate,
    SessionUpdateParams,
    InputAudioTranscription,
    ContentType,
    FunctionCallOutputItemParam,
    ResponseCreate,
)

CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"


class Role(str, Enum):
    User = "user"
    Assistant = "assistant"


@dataclass
class OpenAIRealtimeConfig(BaseConfig):
    base_uri: str = "wss://api.openai.com"
    api_key: str = ""
    path: str = "/v1/realtime"
    model: str = "gpt-4o-realtime-preview"
    language: str = "en-US"
    prompt: str = ""
    temperature: float = 0.5
    max_tokens: int = 1024
    voice: str = "alloy"
    server_vad: bool = True
    audio_out: bool = True
    input_transcript: bool = True
    sample_rate: int = 24000

    vendor: str = ""
    stream_id: int = 0
    dump: bool = False
    greeting: str = ""
    max_history: int = 20
    enable_storage: bool = False

    def build_ctx(self) -> dict:
        return {
            "language": self.language,
            "model": self.model,
        }
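
# A minimal sketch of how these fields might be populated from a graph
# property block (illustrative values only; the exact property format is
# defined by the TEN runtime manifest, not by this file):
#
#   "property": {
#     "api_key": "<your OpenAI API key>",
#     "model": "gpt-4o-realtime-preview",
#     "language": "en-US",
#     "voice": "alloy",
#     "server_vad": true,
#     "sample_rate": 24000
#   }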
class OpenAIRealtimeExtension(AsyncLLMBaseExtension):
    def __init__(self, name: str):
        super().__init__(name)
        self.ten_env: AsyncTenEnv = None
        self.conn = None
        self.session = None
        self.session_id = None
        self.config: OpenAIRealtimeConfig = None
        self.stopped: bool = False
        self.connected: bool = False
        self.buffer: bytearray = bytearray()
        self.memory: ChatMemory = None
        self.total_usage: LLMUsage = LLMUsage()
        self.users_count = 0

        self.stream_id: int = 0
        self.remote_stream_id: int = 0
        self.channel_name: str = ""
        self.audio_len_threshold: int = 5120

        self.completion_times = []
        self.connect_times = []
        self.first_token_times = []

        self.buff: bytearray = bytearray()
        self.transcript: str = ""
        self.ctx: dict = {}
        self.input_end = time.time()

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        await super().on_init(ten_env)
        ten_env.log_debug("on_init")

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        await super().on_start(ten_env)
        ten_env.log_debug("on_start")

        self.ten_env = ten_env
        self.loop = asyncio.get_event_loop()

        self.config = await OpenAIRealtimeConfig.create_async(ten_env=ten_env)
        ten_env.log_info(f"config: {self.config}")

        if not self.config.api_key:
            ten_env.log_error("api_key is required")
            return

        try:
            self.memory = ChatMemory(self.config.max_history)

            if self.config.enable_storage:
                [result, _] = await ten_env.send_cmd(Cmd.create("retrieve"))
                if result.get_status_code() == StatusCode.OK:
                    try:
                        history = json.loads(result.get_property_string("response"))
                        for i in history:
                            self.memory.put(i)
                        ten_env.log_info(f"on retrieve context {history}")
                    except Exception as e:
                        ten_env.log_error(f"Failed to handle retrieve result {e}")
                else:
                    ten_env.log_warn("Failed to retrieve content")

            self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired)
            self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)

            self.ctx = self.config.build_ctx()
            self.ctx["greeting"] = self.config.greeting

            self.conn = RealtimeApiConnection(
                ten_env=ten_env,
                base_uri=self.config.base_uri,
                path=self.config.path,
                api_key=self.config.api_key,
                model=self.config.model,
                vendor=self.config.vendor,
            )
            ten_env.log_info("Finish init client")

            self.loop.create_task(self._loop())
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"Failed to init client {e}")

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        await super().on_stop(ten_env)
        ten_env.log_info("on_stop")

        self.stopped = True

    async def on_audio_frame(self, _: AsyncTenEnv, audio_frame: AudioFrame) -> None:
        try:
            stream_id = audio_frame.get_property_int("stream_id")
            if self.channel_name == "":
                self.channel_name = audio_frame.get_property_string("channel")

            if self.remote_stream_id == 0:
                self.remote_stream_id = stream_id

            frame_buf = audio_frame.get_buf()
            self._dump_audio_if_need(frame_buf, Role.User)

            await self._on_audio(frame_buf)
            if not self.config.server_vad:
                self.input_end = time.time()
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"OpenAIV2VExtension on audio frame failed {e}")

    async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        ten_env.log_debug("on_cmd name {}".format(cmd_name))

        status = StatusCode.OK
        detail = "success"

        if cmd_name == CMD_IN_FLUSH:
            # Only flushes when client-side VAD is used
            await self._flush()
            await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
            ten_env.log_info("on flush")
        elif cmd_name == CMD_IN_ON_USER_JOINED:
            self.users_count += 1
            # Send greeting when the first user joins
            if self.users_count == 1:
                await self._greeting()
        elif cmd_name == CMD_IN_ON_USER_LEFT:
            self.users_count -= 1
        else:
            # Register tool
            await super().on_cmd(ten_env, cmd)
            return

        cmd_result = CmdResult.create(status)
        cmd_result.set_property_string("detail", detail)
        await ten_env.return_result(cmd_result, cmd)

    # Not supported for now
    async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
        pass
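    # The client loop below is the heart of the extension: it opens the
    # realtime connection, replays any stored chat history as conversation
    # items, then dispatches every inbound server event through a single
    # match statement. Responses cancelled by a barge-in are remembered in the
    # local `flushed` set so that late deltas for them can be dropped, and on
    # any failure the loop closes the connection and reschedules itself to
    # reconnect.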
    async def _loop(self):
        def get_time_ms() -> int:
            # Full wall-clock milliseconds; the previous
            # `datetime.now().microsecond // 1000` only gave the millisecond
            # within the current second, which wraps every second and corrupts
            # the relative offsets computed below.
            return int(datetime.now().timestamp() * 1000)

        try:
            start_time = time.time()
            await self.conn.connect()
            self.connect_times.append(time.time() - start_time)
            item_id = ""  # For truncate
            response_id = ""
            content_index = 0
            relative_start_ms = get_time_ms()
            flushed = set()
            self.ten_env.log_info("Client loop started")
            async for message in self.conn.listen():
                try:
                    # self.ten_env.log_info(f"Received message: {message.type}")
                    match message:
                        case SessionCreated():
                            self.ten_env.log_info(
                                f"Session is created: {message.session}"
                            )
                            self.session_id = message.session.id
                            self.session = message.session
                            await self._update_session()

                            history = self.memory.get()
                            for h in history:
                                if h["role"] == "user":
                                    await self.conn.send_request(
                                        ItemCreate(
                                            item=UserMessageItemParam(
                                                content=[
                                                    {
                                                        "type": ContentType.InputText,
                                                        "text": h["content"],
                                                    }
                                                ]
                                            )
                                        )
                                    )
                                elif h["role"] == "assistant":
                                    await self.conn.send_request(
                                        ItemCreate(
                                            item=AssistantMessageItemParam(
                                                content=[
                                                    {
                                                        "type": ContentType.InputText,
                                                        "text": h["content"],
                                                    }
                                                ]
                                            )
                                        )
                                    )
                            self.ten_env.log_info(f"Finish send history {history}")
                            self.memory.clear()

                            if not self.connected:
                                self.connected = True
                                await self._greeting()
                        case ItemInputAudioTranscriptionCompleted():
                            self.ten_env.log_info(
                                f"On request transcript {message.transcript}"
                            )
                            self._send_transcript(message.transcript, Role.User, True)
                            self.memory.put(
                                {
                                    "role": "user",
                                    "content": message.transcript,
                                    "id": message.item_id,
                                }
                            )
                        case ItemInputAudioTranscriptionFailed():
                            self.ten_env.log_warn(
                                f"On request transcript failed {message.item_id} {message.error}"
                            )
                        case ItemCreated():
                            self.ten_env.log_info(f"On item created {message.item}")
                        case ResponseCreated():
                            response_id = message.response.id
                            self.ten_env.log_info(f"On response created {response_id}")
                        case ResponseDone():
                            msg_resp_id = message.response.id
                            status = message.response.status
                            if msg_resp_id == response_id:
                                response_id = ""
                            self.ten_env.log_info(
                                f"On response done {msg_resp_id} {status} {message.response.usage}"
                            )
                            if message.response.usage:
                                pass
                                # await self._update_usage(message.response.usage)
                        case ResponseAudioTranscriptDelta():
                            self.ten_env.log_info(
                                f"On response transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}"
                            )
                            if message.response_id in flushed:
                                self.ten_env.log_warn(
                                    f"On flushed transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}"
                                )
                                continue
                            self._send_transcript(message.delta, Role.Assistant, False)
                        case ResponseTextDelta():
                            self.ten_env.log_info(
                                f"On response text delta {message.response_id} {message.output_index} {message.content_index} {message.delta}"
                            )
                            if message.response_id in flushed:
                                self.ten_env.log_warn(
                                    f"On flushed text delta {message.response_id} {message.output_index} {message.content_index} {message.delta}"
                                )
                                continue
                            if item_id != message.item_id:
                                item_id = message.item_id
                                self.first_token_times.append(
                                    time.time() - self.input_end
                                )
                            self._send_transcript(message.delta, Role.Assistant, False)
                        case ResponseAudioTranscriptDone():
                            self.ten_env.log_info(
                                f"On response transcript done {message.output_index} {message.content_index} {message.transcript}"
                            )
                            if message.response_id in flushed:
                                self.ten_env.log_warn(
                                    f"On flushed transcript done {message.response_id}"
                                )
                                continue
                            self.memory.put(
                                {
                                    "role": "assistant",
                                    "content": message.transcript,
                                    "id": message.item_id,
                                }
                            )
                            self.transcript = ""
                            self._send_transcript("", Role.Assistant, True)
                        case ResponseTextDone():
                            self.ten_env.log_info(
                                f"On response text done {message.output_index} {message.content_index} {message.text}"
                            )
                            if message.response_id in flushed:
                                self.ten_env.log_warn(
                                    f"On flushed text done {message.response_id}"
                                )
                                continue
                            self.completion_times.append(time.time() - self.input_end)
                            self.transcript = ""
                            self._send_transcript("", Role.Assistant, True)
                        case ResponseOutputItemDone():
                            self.ten_env.log_info(f"Output item done {message.item}")
                        case ResponseOutputItemAdded():
                            self.ten_env.log_info(
                                f"Output item added {message.output_index} {message.item}"
                            )
                        case ResponseAudioDelta():
                            if message.response_id in flushed:
                                self.ten_env.log_warn(
                                    f"On flushed audio delta {message.response_id} {message.item_id} {message.content_index}"
                                )
                                continue
                            if item_id != message.item_id:
                                item_id = message.item_id
                                self.first_token_times.append(
                                    time.time() - self.input_end
                                )
                            content_index = message.content_index
                            await self._on_audio_delta(message.delta)
                        case ResponseAudioDone():
                            self.completion_times.append(time.time() - self.input_end)
                        case InputAudioBufferSpeechStarted():
                            self.ten_env.log_info(
                                f"On server listening, in response {response_id}, last item {item_id}"
                            )
                            # Truncate the ongoing audio stream
                            end_ms = get_time_ms() - relative_start_ms
                            if item_id:
                                truncate = ItemTruncate(
                                    item_id=item_id,
                                    content_index=content_index,
                                    audio_end_ms=end_ms,
                                )
                                await self.conn.send_request(truncate)
                            if self.config.server_vad:
                                await self._flush()
                            if response_id and self.transcript:
                                transcript = self.transcript + "[interrupted]"
                                self._send_transcript(transcript, Role.Assistant, True)
                                self.transcript = ""
                                # memory leak, change to lru later
                                flushed.add(response_id)
                            item_id = ""
                        case InputAudioBufferSpeechStopped():
                            # Only for server VAD
                            self.input_end = time.time()
                            relative_start_ms = get_time_ms() - message.audio_end_ms
                            self.ten_env.log_info(
                                f"On server stop listening, {message.audio_end_ms}, relative {relative_start_ms}"
                            )
                        case ResponseFunctionCallArgumentsDone():
                            tool_call_id = message.call_id
                            name = message.name
                            arguments = message.arguments
                            self.ten_env.log_info(f"need to call func {name}")
                            self.loop.create_task(
                                self._handle_tool_call(tool_call_id, name, arguments)
                            )
                        case ErrorMessage():
                            self.ten_env.log_error(
                                f"Error message received: {message.error}"
                            )
                        case _:
                            self.ten_env.log_debug(f"Not handled message {message}")
                except Exception as e:
                    traceback.print_exc()
                    self.ten_env.log_error(f"Error processing message: {message} {e}")
            self.ten_env.log_info("Client loop finished")
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"Failed to handle loop {e}")

        # Clear state so that a new session can be triggered
        self.connected = False
        self.remote_stream_id = 0

        if not self.stopped:
            await self.conn.close()
            await asyncio.sleep(0.5)
            self.ten_env.log_info("Reconnect")

            self.conn = RealtimeApiConnection(
                ten_env=self.ten_env,
                base_uri=self.config.base_uri,
                path=self.config.path,
                api_key=self.config.api_key,
                model=self.config.model,
                vendor=self.config.vendor,
            )
            self.loop.create_task(self._loop())

    async def _on_memory_expired(self, message: dict) -> None:
        self.ten_env.log_info(f"Memory expired: {message}")
        item_id = message.get("item_id")
        if item_id:
            await self.conn.send_request(ItemDelete(item_id=item_id))

    async def _on_memory_appended(self, message: dict) -> None:
        self.ten_env.log_info(f"Memory appended: {message}")
        if not self.config.enable_storage:
            return

        role = message.get("role")
        stream_id = self.remote_stream_id if role == Role.User else 0
        try:
            d = Data.create("append")
            d.set_property_string("text", message.get("content"))
            d.set_property_string("role", role)
            d.set_property_int("stream_id", stream_id)
            asyncio.create_task(self.ten_env.send_data(d))
        except Exception as e:
            self.ten_env.log_error(f"Error send append_context data {message} {e}")
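    # Inbound mic audio is batched until audio_len_threshold bytes have been
    # buffered. With the default 24 kHz, 16-bit mono PCM, 5120 bytes is
    # 2560 samples, i.e. roughly 107 ms of audio per upstream send.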
    # Direction: IN
    async def _on_audio(self, buff: bytearray):
        # Buffer audio until the threshold is reached
        self.buff += buff
        if self.connected and len(self.buff) >= self.audio_len_threshold:
            await self.conn.send_audio_data(self.buff)
            self.buff = bytearray()

    async def _update_session(self) -> None:
        tools = []

        def tool_dict(tool: LLMToolMetadata):
            t = {
                "type": "function",
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                    "additionalProperties": False,
                },
            }

            for param in tool.parameters:
                t["parameters"]["properties"][param.name] = {
                    "type": param.type,
                    "description": param.description,
                }
                if param.required:
                    t["parameters"]["required"].append(param.name)

            return t

        if self.available_tools:
            tool_prompt = "You have several tools that you can get help from:\n"
            for t in self.available_tools:
                tool_prompt += f"- ***{t.name}***: {t.description}\n"
            self.ctx["tools"] = tool_prompt
            tools = [tool_dict(t) for t in self.available_tools]
        prompt = self._replace(self.config.prompt)

        self.ten_env.log_info(f"update session {prompt} {tools}")
        su = SessionUpdate(
            session=SessionUpdateParams(
                instructions=prompt,
                model=self.config.model,
                tool_choice="auto" if self.available_tools else "none",
                tools=tools,
            )
        )
        if self.config.audio_out:
            su.session.voice = self.config.voice
        else:
            su.session.modalities = ["text"]

        if self.config.input_transcript:
            su.session.input_audio_transcription = InputAudioTranscription(
                model="whisper-1"
            )
        await self.conn.send_request(su)

    async def on_tools_update(self, _: AsyncTenEnv, tool: LLMToolMetadata) -> None:
        """Called when a new tool is registered. Implement this method to
        process the new tool."""
        self.ten_env.log_info(f"on tools update {tool}")
        # await self._update_session()

    def _replace(self, prompt: str) -> str:
        """Fill {token} placeholders in the prompt from self.ctx, e.g.
        "Speak {language}" -> "Speak en-US"; unknown tokens are left intact."""
        result = prompt
        for token, value in self.ctx.items():
            result = result.replace("{" + token + "}", value)
        return result
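    # Each ResponseAudioDelta payload is base64-encoded 16-bit mono PCM at the
    # session sample rate (24 kHz by default here), so a decoded payload of
    # N bytes holds N // 2 samples; a 4800-byte delta, for instance, is
    # 2400 samples, or 100 ms of audio.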
    # Direction: OUT
    async def _on_audio_delta(self, delta: bytes) -> None:
        audio_data = base64.b64decode(delta)
        self.ten_env.log_debug(
            f"on_audio_delta audio_data len {len(audio_data)} samples {len(audio_data) // 2}"
        )
        self._dump_audio_if_need(audio_data, Role.Assistant)

        # Wrap the decoded PCM in a pcm_frame: two bytes per sample, one channel
        f = AudioFrame.create("pcm_frame")
        f.set_sample_rate(self.config.sample_rate)
        f.set_bytes_per_sample(2)
        f.set_number_of_channels(1)
        f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE)
        f.set_samples_per_channel(len(audio_data) // 2)
        f.alloc_buf(len(audio_data))
        buff = f.lock_buf()
        buff[:] = audio_data
        f.unlock_buf(buff)
        await self.ten_env.send_audio_frame(f)
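    # Assistant deltas are re-segmented into whole sentences before being sent
    # downstream as text_data, so subtitle consumers see sentences rather than
    # raw token fragments. For example, parse_sentences("Hel", "lo world. How")
    # returns (["Hello world."], " How"): the fragment carried over from the
    # previous delta is completed, and the unterminated tail is kept for the
    # next call.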
    def _send_transcript(self, content: str, role: Role, is_final: bool) -> None:
        def is_punctuation(char):
            # ASCII and full-width sentence terminators
            if char in [",", "，", ".", "。", "?", "？", "!", "！"]:
                return True
            return False

        def parse_sentences(sentence_fragment, content):
            sentences = []
            current_sentence = sentence_fragment
            for char in content:
                current_sentence += char
                if is_punctuation(char):
                    # Check if the current sentence contains non-punctuation characters
                    stripped_sentence = current_sentence
                    if any(c.isalnum() for c in stripped_sentence):
                        sentences.append(stripped_sentence)
                    current_sentence = ""  # Reset for the next sentence

            # Any remaining characters form the incomplete sentence
            remain = current_sentence
            return sentences, remain

        def send_data(
            ten_env: AsyncTenEnv,
            sentence: str,
            stream_id: int,
            role: str,
            is_final: bool,
        ):
            try:
                d = Data.create("text_data")
                d.set_property_string("text", sentence)
                d.set_property_bool("end_of_segment", is_final)
                d.set_property_string("role", role)
                d.set_property_int("stream_id", stream_id)
                ten_env.log_info(
                    f"send transcript text [{sentence}] stream_id {stream_id} is_final {is_final} end_of_segment {is_final} role {role}"
                )
                asyncio.create_task(ten_env.send_data(d))
            except Exception as e:
                ten_env.log_error(
                    f"Error send text data {role}: {sentence} {is_final} {e}"
                )

        stream_id = self.remote_stream_id if role == Role.User else 0
        try:
            if role == Role.Assistant and not is_final:
                sentences, self.transcript = parse_sentences(self.transcript, content)
                for s in sentences:
                    send_data(self.ten_env, s, stream_id, role, is_final)
            else:
                send_data(self.ten_env, content, stream_id, role, is_final)
        except Exception as e:
            self.ten_env.log_error(
                f"Error send text data {role}: {content} {is_final} {e}"
            )

    def _dump_audio_if_need(self, buf: bytearray, role: Role) -> None:
        if not self.config.dump:
            return

        with open("{}_{}.pcm".format(role, self.channel_name), "ab") as dump_file:
            dump_file.write(buf)

    async def _handle_tool_call(
        self, tool_call_id: str, name: str, arguments: str
    ) -> None:
        self.ten_env.log_info(f"_handle_tool_call {tool_call_id} {name} {arguments}")
        # Forward the tool call to the registered tool extension, then report
        # the result (or a failure payload) back as a function_call_output item
        # and ask the model to continue with a new response.
        cmd: Cmd = Cmd.create(CMD_TOOL_CALL)
        cmd.set_property_string("name", name)
        cmd.set_property_from_json("arguments", arguments)
        [result, _] = await self.ten_env.send_cmd(cmd)

        tool_response = ItemCreate(
            item=FunctionCallOutputItemParam(
                call_id=tool_call_id,
                output='{"success":false}',
            )
        )
        if result.get_status_code() == StatusCode.OK:
            tool_result: LLMToolResult = json.loads(
                result.get_property_to_json(CMD_PROPERTY_RESULT)
            )

            result_content = tool_result["content"]
            tool_response.item.output = json.dumps(
                self._convert_to_content_parts(result_content)
            )
            self.ten_env.log_info(f"tool_result: {tool_call_id} {tool_result}")
        else:
            self.ten_env.log_error("Tool call failed")

        await self.conn.send_request(tool_response)
        await self.conn.send_request(ResponseCreate())
        self.ten_env.log_info(f"_handle_tool_call finish {name} {arguments}")

    def _greeting_text(self) -> str:
        text = "Hi, there."
        if self.config.language == "zh-CN":
            text = "你好。"
        elif self.config.language == "ja-JP":
            text = "こんにちは"
        elif self.config.language == "ko-KR":
            text = "안녕하세요"
        return text

    def _convert_tool_params_to_dict(self, tool: LLMToolMetadata):
        json_dict = {"type": "object", "properties": {}, "required": []}
        for param in tool.parameters:
            json_dict["properties"][param.name] = {
                "type": param.type,
                "description": param.description,
            }
            if param.required:
                json_dict["required"].append(param.name)
        return json_dict

    def _convert_to_content_parts(
        self, content: Iterable[LLMChatCompletionContentPartParam]
    ):
        content_parts = []

        if isinstance(content, str):
            content_parts.append({"type": "text", "text": content})
        else:
            for part in content:
                # Only text content is supported currently for the v2v model
                if part["type"] == "text":
                    content_parts.append(part)
        return content_parts

    async def _greeting(self) -> None:
        if self.connected and self.users_count == 1:
            text = self._greeting_text()
            if self.config.greeting:
                text = "Say '" + self.config.greeting + "' to me."
            self.ten_env.log_info(f"send greeting {text}")
            await self.conn.send_request(
                ItemCreate(
                    item=UserMessageItemParam(
                        content=[{"type": ContentType.InputText, "text": text}]
                    )
                )
            )
            await self.conn.send_request(ResponseCreate())

    async def _flush(self) -> None:
        try:
            c = Cmd.create("flush")
            await self.ten_env.send_cmd(c)
        except Exception:
            self.ten_env.log_error("Error flush")
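    # Usage accounting: token counts accumulate into self.total_usage, and
    # p95/p99 latency percentiles for connect, first-token, and completion
    # times are published on the "llm_stat" data channel. The key names read
    # from `usage` below mirror the OpenAI Realtime response.usage payload.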
self.ten_env.log_info(f"send greeting {text}") await self.conn.send_request( ItemCreate( item=UserMessageItemParam( content=[{"type": ContentType.InputText, "text": text}] ) ) ) await self.conn.send_request(ResponseCreate()) async def _flush(self) -> None: try: c = Cmd.create("flush") await self.ten_env.send_cmd(c) except Exception: self.ten_env.log_error("Error flush") async def _update_usage(self, usage: dict) -> None: self.total_usage.completion_tokens += usage.get("output_tokens") or 0 self.total_usage.prompt_tokens += usage.get("input_tokens") or 0 self.total_usage.total_tokens += usage.get("total_tokens") or 0 if not self.total_usage.completion_tokens_details: self.total_usage.completion_tokens_details = LLMCompletionTokensDetails() if not self.total_usage.prompt_tokens_details: self.total_usage.prompt_tokens_details = LLMPromptTokensDetails() if usage.get("output_token_details"): self.total_usage.completion_tokens_details.accepted_prediction_tokens += ( usage["output_token_details"].get("text_tokens") ) self.total_usage.completion_tokens_details.audio_tokens += usage[ "output_token_details" ].get("audio_tokens") if usage.get("input_token_details:"): self.total_usage.prompt_tokens_details.audio_tokens += usage[ "input_token_details" ].get("audio_tokens") self.total_usage.prompt_tokens_details.cached_tokens += usage[ "input_token_details" ].get("cached_tokens") self.total_usage.prompt_tokens_details.text_tokens += usage[ "input_token_details" ].get("text_tokens") self.ten_env.log_info(f"total usage: {self.total_usage}") data = Data.create("llm_stat") data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump())) if self.connect_times and self.completion_times and self.first_token_times: data.set_property_from_json( "latency", json.dumps( { "connection_latency_95": np.percentile(self.connect_times, 95), "completion_latency_95": np.percentile( self.completion_times, 95 ), "first_token_latency_95": np.percentile( self.first_token_times, 95 ), "connection_latency_99": np.percentile(self.connect_times, 99), "completion_latency_99": np.percentile( self.completion_times, 99 ), "first_token_latency_99": np.percentile( self.first_token_times, 99 ), } ), ) asyncio.create_task(self.ten_env.send_data(data)) async def on_call_chat_completion(self, async_ten_env, **kargs): raise NotImplementedError async def on_data_chat_completion(self, async_ten_env, **kargs): raise NotImplementedError