3v324v23's picture
Зафиксирована рабочая версия TEN-Agent для HuggingFace Space
87337b1
import asyncio
import base64
import json
import os
import aiohttp
from ten import AsyncTenEnv
from typing import Any, AsyncGenerator
from .struct import InputAudioBufferAppend, ClientToServerMessage, ServerToClientMessage, parse_server_message, to_json
def smart_str(s: str, max_field_len: int = 128) -> str:
    """Return *s* with a long ``delta``/``audio`` JSON field shortened for logging.

    *s* is parsed as JSON. If the result is an object containing a ``delta``
    key (checked first) or an ``audio`` key, and that field is a string longer
    than *max_field_len*, the field is cut to *max_field_len* characters and
    suffixed with ``"..."``. The object is then re-serialized with
    ``json.dumps``, so formatting may differ from the input even when nothing
    was truncated.

    Args:
        s: Text to inspect, normally a JSON-encoded message.
        max_field_len: Maximum number of characters kept from the field.

    Returns:
        The re-serialized JSON when a ``delta``/``audio`` key is present;
        otherwise *s* unchanged (invalid JSON, non-object JSON, or neither key).
    """
    try:
        data = json.loads(s)
    except json.JSONDecodeError:
        return s

    # Valid JSON need not be an object (e.g. ``123`` or ``"delta"``); only
    # objects can carry the fields handled here.
    if not isinstance(data, dict):
        return s

    if "delta" in data:
        key = "delta"
    elif "audio" in data:
        key = "audio"
    else:
        return s

    value = data[key]
    # Only strings can be sliced and suffixed; other value types (null, numbers,
    # lists, nested objects) are passed through untouched.
    if isinstance(value, str) and len(value) > max_field_len:
        data[key] = value[:max_field_len] + "..."
    return json.dumps(data)
class RealtimeApiConnection:
    """WebSocket client for a realtime API endpoint.

    Owns an ``aiohttp.ClientSession`` and the WebSocket opened on it. Usable
    as an async context manager: entering calls :meth:`connect`, exiting calls
    :meth:`close`.
    """

    def __init__(
        self,
        ten_env: AsyncTenEnv,
        base_uri: str,
        api_key: str | None = None,
        path: str = "/v1/realtime",
        verbose: bool = False
    ):
        """Store connection settings; the WebSocket itself is opened by connect().

        Args:
            ten_env: Environment object used here for logging.
            base_uri: Server base URI; ``path`` is appended to it verbatim.
            api_key: Bearer token. When omitted or empty, the ``GLM_API_KEY``
                environment variable is used instead.
            path: Endpoint path appended to ``base_uri``.
            verbose: When true, every sent and received message is logged.
        """
        self.ten_env = ten_env
        self.url = f"{base_uri}{path}"
        self.api_key = api_key or os.environ.get("GLM_API_KEY")
        self.websocket: aiohttp.ClientWebSocketResponse | None = None
        self.verbose = verbose
        self.session = aiohttp.ClientSession()

    async def __aenter__(self) -> "RealtimeApiConnection":
        """Connect and return this instance."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        """Close the connection; returns False so exceptions propagate."""
        await self.close()
        return False

    async def connect(self):
        """Open the WebSocket to ``self.url`` using bearer authentication.

        Raises:
            ValueError: If no API key was passed and ``GLM_API_KEY`` is unset.
        """
        if self.api_key is None:
            raise ValueError(
                "No API key configured: pass api_key or set the GLM_API_KEY "
                "environment variable"
            )
        # close() shuts the session down, so reconnecting needs a fresh one.
        if self.session.closed:
            self.session = aiohttp.ClientSession()
        headers = {"Authorization": f"Bearer {self.api_key}"}
        self.websocket = await self.session.ws_connect(
            url=self.url,
            headers=headers,
        )

    async def send_audio_data(self, audio_data: bytes):
        """Base64-encode *audio_data* and send it as an InputAudioBufferAppend.

        audio_data is assumed to be pcm16 24kHz mono little-endian.
        """
        base64_audio_data = base64.b64encode(audio_data).decode("utf-8")
        message = InputAudioBufferAppend(audio=base64_audio_data)
        await self.send_request(message)

    async def send_request(self, message: ClientToServerMessage):
        """Serialize *message* with ``to_json`` and send it as a text frame.

        Raises:
            RuntimeError: If the WebSocket has not been connected.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.websocket is None:
            raise RuntimeError("WebSocket is not connected; call connect() first")
        message_str = to_json(message)
        if self.verbose:
            self.ten_env.log_info(f"-> {smart_str(message_str)}")
        await self.websocket.send_str(message_str)

    async def listen(self) -> AsyncGenerator[ServerToClientMessage, None]:
        """Yield parsed server messages until the socket ends or errors.

        Text frames are parsed via :meth:`handle_server_message`; frames that
        fail to parse are logged there and skipped. An ERROR frame is logged
        and ends the iteration. Other frame types are ignored.

        Raises:
            RuntimeError: If the WebSocket has not been connected.
        """
        # Bind locally so a concurrent close() resetting self.websocket cannot
        # turn the ERROR branch below into an AttributeError.
        websocket = self.websocket
        if websocket is None:
            raise RuntimeError("WebSocket is not connected; call connect() first")
        if self.verbose:
            self.ten_env.log_info("Listening for realtimeapi messages")
        try:
            async for msg in websocket:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if self.verbose:
                        self.ten_env.log_info(f"<- {smart_str(msg.data)}")
                    parsed = self.handle_server_message(msg.data)
                    if parsed is not None:
                        yield parsed
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    # Pre-format the message, as every other log call here does.
                    self.ten_env.log_error(
                        f"Error during receive: {websocket.exception()}"
                    )
                    break
        except asyncio.CancelledError:
            # NOTE(review): cancellation is logged and swallowed, as before;
            # confirm callers do not rely on it propagating.
            self.ten_env.log_info("Receive messages task cancelled")

    def handle_server_message(self, message: str) -> ServerToClientMessage | None:
        """Parse a raw server message.

        Returns:
            The parsed message, or ``None`` (after logging) if parsing raised.
        """
        try:
            return parse_server_message(message)
        except Exception as e:
            self.ten_env.log_error(f"Error handling message {message} {e}")
            return None

    async def close(self):
        """Close the WebSocket, if open, and the underlying HTTP session."""
        try:
            if self.websocket is not None:
                await self.websocket.close()
                self.websocket = None
        finally:
            # The session was previously never closed, leaking its connector.
            if not self.session.closed:
                await self.session.close()