#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import asyncio
import base64
from enum import Enum
import json
import traceback
import time
from google import genai
import numpy as np
from typing import Iterable, cast
import websockets
from ten import (
AudioFrame,
AsyncTenEnv,
Cmd,
StatusCode,
CmdResult,
Data,
)
from ten.audio_frame import AudioFrameDataFmt
from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_TOOL_CALL
from dataclasses import dataclass
from ten_ai_base.config import BaseConfig
from ten_ai_base.chat_memory import ChatMemory
from ten_ai_base.usage import (
LLMUsage,
LLMCompletionTokensDetails,
LLMPromptTokensDetails,
)
from ten_ai_base.types import (
LLMToolMetadata,
LLMToolResult,
LLMChatCompletionContentPartParam,
TTSPcmOptions,
)
from ten_ai_base.llm import AsyncLLMBaseExtension
from google.genai.types import (
LiveServerMessage,
LiveConnectConfig,
LiveConnectConfigDict,
GenerationConfig,
Content,
Part,
Tool,
FunctionDeclaration,
Schema,
LiveClientToolResponse,
FunctionCall,
FunctionResponse,
SpeechConfig,
VoiceConfig,
PrebuiltVoiceConfig,
)
from google.genai.live import AsyncSession
from PIL import Image
from io import BytesIO
from base64 import b64encode
import urllib.parse
import google.genai._api_client
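# Workaround: point google.genai's internal client module at the full urllib
# package so urllib.parse resolves where the SDK expects it. This pokes a
# private attribute and may break with future google-genai releases.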
google.genai._api_client.urllib = urllib # pylint: disable=protected-access
CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"
class Role(str, Enum):
User = "user"
Assistant = "assistant"
def rgb2base64jpeg(rgb_data, width, height):
    # Wrap the raw RGBA buffer in a PIL Image, then drop the alpha channel
    pil_image = Image.frombytes("RGBA", (width, height), bytes(rgb_data))
    pil_image = pil_image.convert("RGB")
# Resize the image while maintaining its aspect ratio
pil_image = resize_image_keep_aspect(pil_image, 512)
# Save the image to a BytesIO object in JPEG format
buffered = BytesIO()
pil_image.save(buffered, format="JPEG")
# Get the byte data of the JPEG image
jpeg_image_data = buffered.getvalue()
# Convert the JPEG byte data to a Base64 encoded string
base64_encoded_image = b64encode(jpeg_image_data).decode("utf-8")
    # Return the Base64-encoded JPEG (no data-URL prefix)
    return base64_encoded_image
def resize_image_keep_aspect(image, max_size=512):
"""
Resize an image while maintaining its aspect ratio, ensuring the larger dimension is max_size.
If both dimensions are smaller than max_size, the image is not resized.
:param image: A PIL Image object
:param max_size: The maximum size for the larger dimension (width or height)
:return: A PIL Image object (resized or original)
"""
# Get current width and height
width, height = image.size
# If both dimensions are already smaller than max_size, return the original image
if width <= max_size and height <= max_size:
return image
# Calculate the aspect ratio
aspect_ratio = width / height
# Determine the new dimensions
if width > height:
new_width = max_size
new_height = int(max_size / aspect_ratio)
else:
new_height = max_size
new_width = int(max_size * aspect_ratio)
# Resize the image with the new dimensions
resized_image = image.resize((new_width, new_height))
return resized_image
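# e.g. with max_size=512, a 1024x768 image is resized to 512x384, while a
# 400x300 image is returned unchanged.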
@dataclass
class GeminiRealtimeConfig(BaseConfig):
base_uri: str = "generativelanguage.googleapis.com"
api_key: str = ""
api_version: str = "v1alpha"
model: str = "gemini-2.0-flash-exp"
language: str = "en-US"
prompt: str = ""
temperature: float = 0.5
max_tokens: int = 1024
voice: str = "Puck"
server_vad: bool = True
audio_out: bool = True
input_transcript: bool = True
sample_rate: int = 24000
stream_id: int = 0
dump: bool = False
greeting: str = ""
def build_ctx(self) -> dict:
return {
"language": self.language,
"model": self.model,
}
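# A sketch of the extension properties (typically property.json) this config
# maps onto; field names match the dataclass above, values are illustrative:
#   {"api_key": "<GOOGLE_API_KEY>", "model": "gemini-2.0-flash-exp",
#    "voice": "Puck", "sample_rate": 24000, "greeting": "Hello"}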
class GeminiRealtimeExtension(AsyncLLMBaseExtension):
def __init__(self, name):
super().__init__(name)
self.config: GeminiRealtimeConfig = None
self.stopped: bool = False
self.connected: bool = False
        self.buffer: bytes = b""
self.memory: ChatMemory = None
self.total_usage: LLMUsage = LLMUsage()
self.users_count = 0
self.stream_id: int = 0
self.remote_stream_id: int = 0
self.channel_name: str = ""
self.audio_len_threshold: int = 5120
self.completion_times = []
self.connect_times = []
self.first_token_times = []
        self.buff: bytes = b""
self.transcript: str = ""
self.ctx: dict = {}
self.input_end = time.time()
self.client = None
self.session: AsyncSession = None
self.leftover_bytes = b""
self.video_task = None
self.image_queue = asyncio.Queue()
self.video_buff: str = ""
self.loop = None
self.ten_env = None
async def on_init(self, ten_env: AsyncTenEnv) -> None:
await super().on_init(ten_env)
ten_env.log_debug("on_init")
async def on_start(self, ten_env: AsyncTenEnv) -> None:
await super().on_start(ten_env)
self.ten_env = ten_env
ten_env.log_debug("on_start")
self.loop = asyncio.get_event_loop()
self.config = await GeminiRealtimeConfig.create_async(ten_env=ten_env)
ten_env.log_info(f"config: {self.config}")
if not self.config.api_key:
ten_env.log_error("api_key is required")
return
try:
self.ctx = self.config.build_ctx()
self.ctx["greeting"] = self.config.greeting
self.client = genai.Client(
api_key=self.config.api_key,
http_options={
"api_version": self.config.api_version,
"url": self.config.base_uri,
},
)
self.loop.create_task(self._loop(ten_env))
self.loop.create_task(self._on_video(ten_env))
except Exception as e:
traceback.print_exc()
self.ten_env.log_error(f"Failed to init client {e}")
async def _loop(self, ten_env: AsyncTenEnv) -> None:
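        """Session loop: connect to Gemini Live, consume server events (audio
        parts, turn/setup notices, tool calls), and reconnect after transient
        drops until the extension stops."""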
while not self.stopped:
await asyncio.sleep(1)
try:
config: LiveConnectConfig = self._get_session_config()
ten_env.log_info("Start listen")
async with self.client.aio.live.connect(
model=self.config.model, config=config
) as session:
ten_env.log_info("Connected")
session = cast(AsyncSession, session)
self.session = session
self.connected = True
await self._greeting()
while True:
try:
async for response in session.receive():
response = cast(LiveServerMessage, response)
# ten_env.log_info(f"Received response")
try:
if response.server_content:
if response.server_content.interrupted:
ten_env.log_info("Interrupted")
await self._flush()
continue
elif (
not response.server_content.turn_complete
and response.server_content.model_turn
):
for (
part
) in (
response.server_content.model_turn.parts
):
await self.send_audio_out(
ten_env,
part.inline_data.data,
sample_rate=24000,
bytes_per_sample=2,
number_of_channels=1,
)
elif response.server_content.turn_complete:
ten_env.log_info("Turn complete")
elif response.setup_complete:
ten_env.log_info("Setup complete")
elif response.tool_call:
func_calls = response.tool_call.function_calls
self.loop.create_task(
self._handle_tool_call(func_calls)
)
except Exception:
traceback.print_exc()
ten_env.log_error("Failed to handle response")
await self._flush()
ten_env.log_info("Finish listen")
except websockets.exceptions.ConnectionClosedOK:
ten_env.log_info("Connection closed")
break
except Exception as e:
self.ten_env.log_error(f"Failed to handle loop {e}")
async def send_audio_out(
self, ten_env: AsyncTenEnv, audio_data: bytes, **args: TTSPcmOptions
) -> None:
"""End sending audio out."""
sample_rate = args.get("sample_rate", 24000)
bytes_per_sample = args.get("bytes_per_sample", 2)
number_of_channels = args.get("number_of_channels", 1)
try:
# Combine leftover bytes with new audio data
combined_data = self.leftover_bytes + audio_data
            # Check whether combined_data aligns to whole PCM frames
if len(combined_data) % (bytes_per_sample * number_of_channels) != 0:
# Save the last incomplete frame
valid_length = len(combined_data) - (
len(combined_data) % (bytes_per_sample * number_of_channels)
)
self.leftover_bytes = combined_data[valid_length:]
combined_data = combined_data[:valid_length]
else:
self.leftover_bytes = b""
if combined_data:
f = AudioFrame.create("pcm_frame")
f.set_sample_rate(sample_rate)
f.set_bytes_per_sample(bytes_per_sample)
f.set_number_of_channels(number_of_channels)
f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE)
f.set_samples_per_channel(
len(combined_data) // (bytes_per_sample * number_of_channels)
)
f.alloc_buf(len(combined_data))
buff = f.lock_buf()
buff[:] = combined_data
f.unlock_buf(buff)
await ten_env.send_audio_frame(f)
except Exception:
pass
# ten_env.log_error(f"error send audio frame, {traceback.format_exc()}")
async def on_stop(self, ten_env: AsyncTenEnv) -> None:
await super().on_stop(ten_env)
ten_env.log_info("on_stop")
self.stopped = True
if self.session:
await self.session.close()
async def on_audio_frame(
self, ten_env: AsyncTenEnv, audio_frame: AudioFrame
) -> None:
await super().on_audio_frame(ten_env, audio_frame)
try:
stream_id = audio_frame.get_property_int("stream_id")
if self.channel_name == "":
self.channel_name = audio_frame.get_property_string("channel")
if self.remote_stream_id == 0:
self.remote_stream_id = stream_id
frame_buf = audio_frame.get_buf()
self._dump_audio_if_need(frame_buf, Role.User)
await self._on_audio(frame_buf)
if not self.config.server_vad:
self.input_end = time.time()
except Exception as e:
traceback.print_exc()
self.ten_env.log_error(f"on audio frame failed {e}")
async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
cmd_name = cmd.get_name()
ten_env.log_debug(f"on_cmd name {cmd_name}")
status = StatusCode.OK
detail = "success"
if cmd_name == CMD_IN_FLUSH:
            # Only flushes when client-side VAD is in use
await self._flush()
await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
ten_env.log_info("on flush")
elif cmd_name == CMD_IN_ON_USER_JOINED:
self.users_count += 1
            # Send greeting when the first user joins
if self.users_count == 1:
await self._greeting()
elif cmd_name == CMD_IN_ON_USER_LEFT:
self.users_count -= 1
else:
# Register tool
await super().on_cmd(ten_env, cmd)
return
cmd_result = CmdResult.create(status)
cmd_result.set_property_string("detail", detail)
await ten_env.return_result(cmd_result, cmd)
    # Not supported for now
async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
pass
async def on_video_frame(self, async_ten_env, video_frame):
await super().on_video_frame(async_ten_env, video_frame)
image_data = video_frame.get_buf()
image_width = video_frame.get_width()
image_height = video_frame.get_height()
await self.image_queue.put([image_data, image_width, image_height])
async def _on_video(self, _: AsyncTenEnv):
while True:
# Process the first frame from the queue
[image_data, image_width, image_height] = await self.image_queue.get()
self.video_buff = rgb2base64jpeg(image_data, image_width, image_height)
media_chunks = [
{
"data": self.video_buff,
"mime_type": "image/jpeg",
}
]
try:
if self.connected:
# ten_env.log_info(f"send image")
await self.session.send(media_chunks)
except Exception as e:
self.ten_env.log_error(f"Failed to send image {e}")
            # Drop any frames that queued up during the last second
while not self.image_queue.empty():
await self.image_queue.get()
# Wait for 1 second before processing the next frame
await asyncio.sleep(1)
# Direction: IN
async def _on_audio(self, buff: bytearray):
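        """Batch mic PCM: accumulate until audio_len_threshold bytes are
        buffered, then forward one base64-encoded audio/pcm chunk to the
        live session."""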
self.buff += buff
# Buffer audio
if self.connected and len(self.buff) >= self.audio_len_threshold:
try:
media_chunks = [
{
"data": base64.b64encode(self.buff).decode(),
"mime_type": "audio/pcm",
}
]
# await self.session.send(LiveClientRealtimeInput(media_chunks=media_chunks))
await self.session.send(media_chunks)
self.buff = b""
except Exception as e:
# pass
self.ten_env.log_error(f"Failed to send audio {e}")
def _get_session_config(self) -> LiveConnectConfigDict:
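        """Build the LiveConnectConfig: audio-only responses, the system
        prompt, registered tools plus built-in google_search/code_execution,
        the configured voice, and generation limits."""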
def tool_dict(tool: LLMToolMetadata):
required = []
properties: dict[str, "Schema"] = {}
for param in tool.parameters:
properties[param.name] = Schema(
type=param.type.upper(), description=param.description
)
if param.required:
required.append(param.name)
t = Tool(
function_declarations=[
FunctionDeclaration(
name=tool.name,
description=tool.description,
parameters=Schema(
type="OBJECT", properties=properties, required=required
),
)
]
)
return t
tools = (
[tool_dict(t) for t in self.available_tools]
if len(self.available_tools) > 0
else []
)
tools.append(Tool(google_search={}))
tools.append(Tool(code_execution={}))
config = LiveConnectConfig(
response_modalities=["AUDIO"],
system_instruction=Content(parts=[Part(text=self.config.prompt)]),
tools=tools,
# voice is currently not working
speech_config=SpeechConfig(
voice_config=VoiceConfig(
prebuilt_voice_config=PrebuiltVoiceConfig(
voice_name=self.config.voice
)
)
),
generation_config=GenerationConfig(
temperature=self.config.temperature,
max_output_tokens=self.config.max_tokens,
),
)
return config
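    # For reference, a registered tool with one required string parameter maps
    # to roughly the following ("get_weather"/"location" are hypothetical):
    #   Tool(function_declarations=[FunctionDeclaration(
    #       name="get_weather",
    #       description="Look up current weather",
    #       parameters=Schema(type="OBJECT",
    #                         properties={"location": Schema(type="STRING")},
    #                         required=["location"]))])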
async def on_tools_update(
self, ten_env: AsyncTenEnv, tool: LLMToolMetadata
) -> None:
"""Called when a new tool is registered. Implement this method to process the new tool."""
ten_env.log_info(f"on tools update {tool}")
# await self._update_session()
def _replace(self, prompt: str) -> str:
result = prompt
for token, value in self.ctx.items():
result = result.replace("{" + token + "}", value)
return result
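    # e.g. with ctx = {"language": "en-US"}, "Reply in {language}." becomes
    # "Reply in en-US."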
def _send_transcript(self, content: str, role: Role, is_final: bool) -> None:
def is_punctuation(char):
            if char in [",", "，", ".", "。", "?", "？", "!", "！"]:
return True
return False
def parse_sentences(sentence_fragment, content):
sentences = []
current_sentence = sentence_fragment
for char in content:
current_sentence += char
if is_punctuation(char):
# Check if the current sentence contains non-punctuation characters
stripped_sentence = current_sentence
if any(c.isalnum() for c in stripped_sentence):
sentences.append(stripped_sentence)
current_sentence = "" # Reset for the next sentence
remain = current_sentence # Any remaining characters form the incomplete sentence
return sentences, remain
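        # e.g. parse_sentences("", "Hi there. How") -> (["Hi there."], " How")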
def send_data(
ten_env: AsyncTenEnv,
sentence: str,
stream_id: int,
role: str,
is_final: bool,
):
try:
d = Data.create("text_data")
d.set_property_string("text", sentence)
d.set_property_bool("end_of_segment", is_final)
d.set_property_string("role", role)
d.set_property_int("stream_id", stream_id)
ten_env.log_info(
f"send transcript text [{sentence}] stream_id {stream_id} is_final {is_final} end_of_segment {is_final} role {role}"
)
asyncio.create_task(ten_env.send_data(d))
except Exception as e:
ten_env.log_error(
f"Error send text data {role}: {sentence} {is_final} {e}"
)
stream_id = self.remote_stream_id if role == Role.User else 0
try:
if role == Role.Assistant and not is_final:
sentences, self.transcript = parse_sentences(self.transcript, content)
                for s in sentences:
                    # send_data is a plain (non-async) function that schedules
                    # its own task, so it must not be wrapped in create_task.
                    send_data(self.ten_env, s, stream_id, role, is_final)
            else:
                send_data(self.ten_env, content, stream_id, role, is_final)
except Exception as e:
self.ten_env.log_error(
f"Error send text data {role}: {content} {is_final} {e}"
)
def _dump_audio_if_need(self, buf: bytearray, role: Role) -> None:
if not self.config.dump:
return
with open("{}_{}.pcm".format(role, self.channel_name), "ab") as dump_file:
dump_file.write(buf)
async def _handle_tool_call(self, func_calls: list[FunctionCall]) -> None:
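        """Dispatch each FunctionCall as a CMD_TOOL_CALL command, collect the
        results (or an error payload on failure), and return them to Gemini in
        a single LiveClientToolResponse."""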
function_responses = []
for call in func_calls:
tool_call_id = call.id
name = call.name
arguments = call.args
self.ten_env.log_info(
f"_handle_tool_call {tool_call_id} {name} {arguments}"
)
cmd: Cmd = Cmd.create(CMD_TOOL_CALL)
cmd.set_property_string("name", name)
cmd.set_property_from_json("arguments", json.dumps(arguments))
[result, _] = await self.ten_env.send_cmd(cmd)
func_response = FunctionResponse(
id=tool_call_id, name=name, response={"error": "Failed to call tool"}
)
if result.get_status_code() == StatusCode.OK:
tool_result: LLMToolResult = json.loads(
result.get_property_to_json(CMD_PROPERTY_RESULT)
)
result_content = tool_result["content"]
func_response = FunctionResponse(
id=tool_call_id, name=name, response={"output": result_content}
)
self.ten_env.log_info(f"tool_result: {tool_call_id} {tool_result}")
else:
self.ten_env.log_error("Tool call failed")
function_responses.append(func_response)
self.ten_env.log_info(f"_remote_tool_call finish {name} {arguments}")
try:
self.ten_env.log_info(f"send tool response {function_responses}")
await self.session.send(
LiveClientToolResponse(function_responses=function_responses)
)
except Exception as e:
self.ten_env.log_error(f"Failed to send tool response {e}")
def _greeting_text(self) -> str:
text = "Hi, there."
if self.config.language == "zh-CN":
text = "你好。"
elif self.config.language == "ja-JP":
text = "こんにちは"
elif self.config.language == "ko-KR":
text = "안녕하세요"
return text
def _convert_tool_params_to_dict(self, tool: LLMToolMetadata):
json_dict = {"type": "object", "properties": {}, "required": []}
for param in tool.parameters:
json_dict["properties"][param.name] = {
"type": param.type,
"description": param.description,
}
if param.required:
json_dict["required"].append(param.name)
return json_dict
def _convert_to_content_parts(
self, content: Iterable[LLMChatCompletionContentPartParam]
):
content_parts = []
if isinstance(content, str):
content_parts.append({"type": "text", "text": content})
else:
for part in content:
                # Only text content is currently supported by the v2v model
if part["type"] == "text":
content_parts.append(part)
return content_parts
async def _greeting(self) -> None:
if self.connected and self.users_count == 1:
text = self._greeting_text()
if self.config.greeting:
text = "Say '" + self.config.greeting + "' to me."
self.ten_env.log_info(f"send greeting {text}")
await self.session.send(text, end_of_turn=True)
async def _flush(self) -> None:
try:
c = Cmd.create("flush")
await self.ten_env.send_cmd(c)
except Exception:
self.ten_env.log_error("Error flush")
async def _update_usage(self, usage: dict) -> None:
        self.total_usage.completion_tokens += usage.get("output_tokens", 0)
        self.total_usage.prompt_tokens += usage.get("input_tokens", 0)
        self.total_usage.total_tokens += usage.get("total_tokens", 0)
if not self.total_usage.completion_tokens_details:
self.total_usage.completion_tokens_details = LLMCompletionTokensDetails()
if not self.total_usage.prompt_tokens_details:
self.total_usage.prompt_tokens_details = LLMPromptTokensDetails()
if usage.get("output_token_details"):
self.total_usage.completion_tokens_details.accepted_prediction_tokens += (
usage["output_token_details"].get("text_tokens")
)
self.total_usage.completion_tokens_details.audio_tokens += usage[
"output_token_details"
].get("audio_tokens")
        if usage.get("input_token_details"):
self.total_usage.prompt_tokens_details.audio_tokens += usage[
"input_token_details"
].get("audio_tokens")
self.total_usage.prompt_tokens_details.cached_tokens += usage[
"input_token_details"
].get("cached_tokens")
self.total_usage.prompt_tokens_details.text_tokens += usage[
"input_token_details"
].get("text_tokens")
self.ten_env.log_info(f"total usage: {self.total_usage}")
data = Data.create("llm_stat")
data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump()))
if self.connect_times and self.completion_times and self.first_token_times:
data.set_property_from_json(
"latency",
json.dumps(
{
"connection_latency_95": np.percentile(self.connect_times, 95),
"completion_latency_95": np.percentile(
self.completion_times, 95
),
"first_token_latency_95": np.percentile(
self.first_token_times, 95
),
"connection_latency_99": np.percentile(self.connect_times, 99),
"completion_latency_99": np.percentile(
self.completion_times, 99
),
"first_token_latency_99": np.percentile(
self.first_token_times, 99
),
}
),
)
asyncio.create_task(self.ten_env.send_data(data))
async def on_call_chat_completion(self, async_ten_env, **kargs):
raise NotImplementedError
async def on_data_chat_completion(self, async_ten_env, **kargs):
raise NotImplementedError