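# Voice agent demo: streams microphone audio over WebRTC (via fastrtc) to the
# Gemini Live API, plays back the model's audio responses, and synthesizes any
# text responses with Google Cloud Text-to-Speech. Gradio provides the UI.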
import asyncio
import base64
import json
import os
from typing import Literal

import gradio as gr
import numpy as np
from fastrtc import AsyncStreamHandler, WebRTC, wait_for_item
from google import genai
from google.cloud import texttospeech
from google.genai.types import FunctionDeclaration, LiveConnectConfig, Tool

import helpers.datastore as datastore
from helpers.prompts import load_prompt
from tools import FUNCTION_MAP, TOOLS
with open("questions.json", "r") as f:
    questions_dict = json.load(f)
datastore.DATA_STORE["questions"] = questions_dict

SYSTEM_PROMPT = load_prompt(
    "src/prompts/default_prompt.jinja2", questions=questions_dict
)
class TTSConfig:
    """Bundles the Google Cloud Text-to-Speech client, voice, and audio settings."""

    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()
        self.voice = texttospeech.VoiceSelectionParams(
            name="en-US-Chirp3-HD-Charon", language_code="en-US"
        )
        self.audio_config = texttospeech.AudioConfig(
            audio_encoding=texttospeech.AudioEncoding.LINEAR16
        )
class AsyncGeminiHandler(AsyncStreamHandler):
    """Simple async Gemini stream handler."""

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()  # audio to send to Gemini
        self.output_queue: asyncio.Queue = asyncio.Queue()  # audio to play back to the user
        self.text_queue: asyncio.Queue = asyncio.Queue()  # completed text turns awaiting TTS
        self.quit: asyncio.Event = asyncio.Event()
        self.chunk_size = 1024
        self.tts_config: TTSConfig | None = TTSConfig()
        self.text_buffer = ""
    def copy(self) -> "AsyncGeminiHandler":
        return AsyncGeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def _encode_audio(self, data: np.ndarray) -> str:
        """Encode audio data to send to the server."""
        return base64.b64encode(data.tobytes()).decode("UTF-8")
    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Receives and processes audio frames asynchronously."""
        _, array = frame
        array = array.squeeze()
        audio_message = self._encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Asynchronously emits items from the output queue."""
        return await wait_for_item(self.output_queue)
    async def start_up(self) -> None:
        """Initialize and start the voice agent application.

        This asynchronous method sets up the Gemini API client, configures the live
        connection, and starts three concurrent tasks for receiving, processing, and
        sending information.

        Returns:
            None

        Raises:
            ValueError: If GEMINI_API_KEY is not provided when required.
        """
        if os.getenv("GOOGLE_GENAI_USE_VERTEXAI") != "True":
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("API Key is required")
            client = genai.Client(
                api_key=api_key,
                http_options={"api_version": "v1alpha"},
            )
        else:
            client = genai.Client(http_options={"api_version": "v1beta1"})

        config = LiveConnectConfig(
            system_instruction={
                "parts": [{"text": SYSTEM_PROMPT}],
                "role": "user",
            },
            tools=[
                Tool(
                    function_declarations=[
                        FunctionDeclaration(**tool) for tool in TOOLS
                    ]
                )
            ],
            response_modalities=["AUDIO"],
        )

        async with (
            client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session,  # set up the live connection session (websocket)
            asyncio.TaskGroup() as tg,  # run multiple tasks concurrently
        ):
            self.session = session
            # these tasks will run concurrently and continuously
            tg.create_task(self.process())
            tg.create_task(self.send_realtime())
            tg.create_task(self.tts())
    async def process(self) -> None:
        """Process responses from the session in a continuous loop.

        This asynchronous method handles different types of responses from the session:
        - Audio data: Processes and queues audio data with the specified sample rate
        - Text data: Accumulates received text in a buffer
        - Tool calls: Executes registered functions and sends their responses back
        - Server content: Handles turn completion and stores conversation history

        The method runs indefinitely until interrupted, handling any exceptions that
        occur during processing by logging them and continuing after a brief delay.

        Returns:
            None

        Raises:
            Exception: Any exceptions during processing are caught and logged
        """
        while True:
            try:
                turn = self.session.receive()
                async for response in turn:
                    if data := response.data:
                        # audio data
                        array = np.frombuffer(data, dtype=np.int16)
                        self.output_queue.put_nowait((self.output_sample_rate, array))
                        continue
                    if text := response.text:
                        # text data
                        print(f"Received text: {text}")
                        self.text_buffer += text
                    if response.tool_call is not None:
                        # function calling
                        for tool in response.tool_call.function_calls:
                            try:
                                tool_response = FUNCTION_MAP[tool.name](**tool.args)
                                print(f"Calling tool: {tool.name}")
                                print(f"Tool response: {tool_response}")
                                await self.session.send(
                                    input=tool_response, end_of_turn=True
                                )
                                await asyncio.sleep(0.1)
                            except Exception as e:
                                print(f"Error in tool call: {e}")
                                await asyncio.sleep(0.1)
                    if sc := response.server_content:
                        # check if the bot's turn is complete
                        if sc.turn_complete and self.text_buffer:
                            self.text_queue.put_nowait(self.text_buffer)
                            FUNCTION_MAP["store_input"](
                                role="bot", input=self.text_buffer
                            )
                            self.text_buffer = ""
            except Exception as e:
                print(f"Error in processing: {e}")
                await asyncio.sleep(0.1)
    async def send_realtime(self) -> None:
        """Send real-time audio data to the model.

        This method continuously reads audio data from the input queue and sends it to
        the model session in real time. It runs in an infinite loop until interrupted.

        The audio data is sent with mime type 'audio/pcm'. If an error occurs during
        sending, it is printed and the method sleeps briefly before retrying.

        Returns:
            None

        Raises:
            Exception: Any exceptions during queue access or session sending are caught and logged.
        """
        while True:
            try:
                data = await self.input_queue.get()
                msg = {"data": data, "mime_type": "audio/pcm"}
                await self.session.send(input=msg)
            except Exception as e:
                print(f"Error in real-time sending: {e}")
                await asyncio.sleep(0.1)
    async def tts(self) -> None:
        """Synthesize queued text with Google Cloud TTS and queue the resulting audio."""
        while True:
            try:
                text = await self.text_queue.get()
                # Get the response in a single request
                if text:
                    response = self.tts_config.client.synthesize_speech(
                        input=texttospeech.SynthesisInput(text=text),
                        voice=self.tts_config.voice,
                        audio_config=self.tts_config.audio_config,
                    )
                    array = np.frombuffer(response.audio_content, dtype=np.int16)
                    self.output_queue.put_nowait((self.output_sample_rate, array))
            except Exception as e:
                print(f"Error in TTS: {e}")
                await asyncio.sleep(0.1)
    def shutdown(self) -> None:
        """Signal the handler to stop."""
        self.quit.set()
# Main Gradio Interface
def registry(*args, **kwargs):
    """Sets up and returns the Gradio interface."""
    interface = gr.Blocks()
    with interface:
        with gr.Tabs():
            with gr.TabItem("Voice Chat"):
                gr.HTML(
                    """
                    <div style='text-align: left'>
                        <h1>ML6 Voice Demo</h1>
                    </div>
                    """
                )
                gemini_handler = AsyncGeminiHandler()
                with gr.Row():
                    audio = WebRTC(
                        label="Voice Chat",
                        modality="audio",
                        mode="send-receive",
                    )
                # Add display components for questions and answers
                with gr.Row():
                    with gr.Column():
                        gr.JSON(
                            label="Questions",
                            value=datastore.DATA_STORE["questions"],
                        )
                    with gr.Column():
                        gr.JSON(
                            label="Answers",
                            value=lambda: datastore.DATA_STORE["answers"],
                            every=1,
                        )
                audio.stream(
                    gemini_handler,
                    inputs=[audio],
                    outputs=[audio],
                    time_limit=600,
                    concurrency_limit=10,
                )
    return interface
# Launch the Gradio interface
gr.load(
    name="demo",
    src=registry,
).launch()
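# Note: to run locally (assuming this module is saved as app.py), set GEMINI_API_KEY
# or GOOGLE_GENAI_USE_VERTEXAI=True plus Google Cloud credentials for the TTS client,
# then start it with `python app.py`.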