Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
## Setup

The gradio-webrtc install fails unless you have ffmpeg@6. On macOS:

```
brew uninstall ffmpeg
brew install ffmpeg@6
brew link ffmpeg@6
```

Create a Python virtual environment, then install the dependencies for this script:

```
pip install websockets numpy gradio-webrtc "gradio>=5.9.1"
```

If installation fails, it may be because ffmpeg@6 is not installed and linked (see above).

Before running this script, ensure the `GOOGLE_API_KEY` environment variable is set:

```
$ export GOOGLE_API_KEY='add your key here'
```

You can get an API key from Google AI Studio (https://aistudio.google.com/apikey).

## Run

To run the script:

```
python gemini_gradio_audio.py
```

On the Gradio page (http://127.0.0.1:7860/), click record and talk; Gemini will reply. Note that interruptions
don't work.
"""
import asyncio | |
import json | |
import os | |
from typing import Literal | |
import base64 | |
import gradio as gr | |
import numpy as np | |
from fastrtc import ( | |
AsyncStreamHandler, | |
WebRTC, | |
wait_for_item, | |
) | |
from jinja2 import Template | |
from google import genai | |
from google.genai.types import LiveConnectConfig, Tool, FunctionDeclaration | |
from google.cloud import texttospeech | |
from tools import FUNCTION_MAP, TOOLS | |
# Questions are loaded once at import time; they are shown in the UI and
# embedded (as pretty-printed JSON) in the system prompt.
with open("questions.json", "r") as questions_file:
    questions_dict = json.load(questions_file)

# Raw Jinja2 source of the system prompt.
with open("src/prompts/default_prompt.jinja2") as prompt_file:
    template_str = prompt_file.read()

template = Template(template_str)
questions_json = json.dumps(questions_dict, indent=4)
system_prompt = template.render(questions=questions_json)
class TTSConfig:
    """Holds the Text-to-Speech client together with the request settings.

    Attributes:
        client: ``texttospeech.TextToSpeechClient`` used for synthesis calls.
        voice: Voice selection for the ``en-US-Chirp3-HD-Charon`` voice.
        audio_config: Audio settings requesting LINEAR16 output.
    """

    def __init__(self):
        # The client is constructed first, exactly as before, so any failure
        # while creating it surfaces before the other attributes exist.
        self.client = texttospeech.TextToSpeechClient()

        voice_name = "en-US-Chirp3-HD-Charon"
        voice_language = "en-US"
        self.voice = texttospeech.VoiceSelectionParams(
            name=voice_name,
            language_code=voice_language,
        )

        encoding = texttospeech.AudioEncoding.LINEAR16
        self.audio_config = texttospeech.AudioConfig(audio_encoding=encoding)
class AsyncGeminiHandler(AsyncStreamHandler):
    """Stream handler that relays audio between WebRTC and a Gemini Live session.

    Incoming frames are base64-encoded and queued for the live session. Audio
    returned by the session, and speech synthesised from any text the session
    returns, is queued for playback and handed out by ``emit``.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        """Create the queues, the shutdown event and the TTS configuration.

        Args:
            expected_layout: Channel layout forwarded to the base handler.
            output_sample_rate: Output sample rate forwarded to the base handler.
            output_frame_size: Output frame size forwarded to the base handler.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        # Base64-encoded microphone audio waiting to be sent to the session.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # (sample_rate, samples) tuples waiting to be returned by emit().
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Completed text turns waiting to be synthesised by tts().
        self.text_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown(); start_up() and the worker loops stop once it is set.
        self.quit: asyncio.Event = asyncio.Event()
        self.chunk_size = 1024
        self.tts_config: TTSConfig | None = TTSConfig()
        # Text received so far for the turn that is still in progress.
        self.text_buffer = ""

    def copy(self) -> "AsyncGeminiHandler":
        """Return a new handler configured with this handler's output settings."""
        return AsyncGeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def _encode_audio(self, data: np.ndarray) -> str:
        """Encode Audio data to send to the server"""
        return base64.b64encode(data.tobytes()).decode("UTF-8")

    @staticmethod
    def _decode_linear16(audio: bytes) -> tuple[int | None, np.ndarray]:
        """Convert LINEAR16 TTS output into ``(sample_rate, int16 samples)``.

        The Text-to-Speech API prepends a WAV header to LINEAR16 audio. Decoding
        the whole payload as samples would play the header bytes as noise, so
        the container is parsed and only its frames are returned.

        Args:
            audio: Bytes taken from the synthesis response.

        Returns:
            The sample rate read from the WAV header and the PCM samples. When
            the payload is not a readable WAV container the rate is ``None``
            and the bytes are interpreted as raw 16-bit samples.
        """
        import io
        import wave

        if audio[:4] == b"RIFF":
            try:
                with wave.open(io.BytesIO(audio), "rb") as wav:
                    rate = wav.getframerate()
                    pcm = wav.readframes(wav.getnframes())
            except (wave.Error, EOFError):
                pass  # Not a usable WAV container; fall back to raw decoding.
            else:
                return rate, np.frombuffer(pcm, dtype=np.int16)
        return None, np.frombuffer(audio, dtype=np.int16)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Queue one incoming audio frame for sending to the live session.

        Args:
            frame: ``(sample_rate, samples)``; only the samples are used.
        """
        _, array = frame
        array = array.squeeze()
        audio_message = self._encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next item from the output queue via ``wait_for_item``."""
        return await wait_for_item(self.output_queue)

    async def start_up(self) -> None:
        """Open the Gemini Live session and run the workers until shutdown.

        Three worker tasks are started (``process``, ``send_realtime`` and
        ``tts``). The method then waits for ``shutdown()`` to set ``self.quit``
        and cancels the workers so the task group and the session can close.
        """
        client = genai.Client(
            api_key=os.getenv("GOOGLE_API_KEY"),
            http_options={"api_version": "v1alpha"},
        )
        # NOTE(review): with "AUDIO" as the only response modality the text/TTS
        # path may rarely receive anything - confirm the intended modality.
        config = LiveConnectConfig(
            system_instruction={
                "parts": [{"text": system_prompt}],
                "role": "user",
            },
            tools=[Tool(function_declarations=[FunctionDeclaration(**tool) for tool in TOOLS])],
            response_modalities=["AUDIO"],
        )
        async with (
            client.aio.live.connect(model="gemini-2.0-flash-exp", config=config) as session,
            asyncio.TaskGroup() as tg,
        ):
            self.session = session
            workers = [
                tg.create_task(self.process()),
                tg.create_task(self.send_realtime()),
                tg.create_task(self.tts()),
            ]
            # The workers block on queue/network reads, so they cannot notice
            # the quit flag by themselves; cancel them once shutdown is asked.
            # NOTE(review): assumes shutdown() is invoked on this event loop's
            # thread, since asyncio.Event is not thread-safe - confirm.
            await self.quit.wait()
            for worker in workers:
                worker.cancel()

    async def process(self) -> None:
        """Consume responses from the live session until shutdown.

        Audio data is queued for playback, text is accumulated per turn, tool
        calls are executed through ``FUNCTION_MAP`` and their result sent back,
        and on turn completion the accumulated text is queued for TTS and
        stored via ``FUNCTION_MAP["store_input"]``.
        """
        while not self.quit.is_set():
            try:
                turn = self.session.receive()
                async for response in turn:
                    if data := response.data:
                        array = np.frombuffer(data, dtype=np.int16)
                        self.output_queue.put_nowait((self.output_sample_rate, array))
                        # Audio responses are not inspected any further.
                        continue
                    if text := response.text:
                        print(f"Received text: {text}")
                        self.text_buffer += text
                    if response.tool_call is not None:
                        for tool in response.tool_call.function_calls:
                            tool_response = FUNCTION_MAP[tool.name](**tool.args)
                            print(f"Calling tool: {tool.name}")
                            print(f"Tool response: {tool_response}")
                            await self.session.send(
                                input=tool_response, end_of_turn=True
                            )
                            await asyncio.sleep(0.1)
                    if sc := response.server_content:
                        if sc.turn_complete and self.text_buffer:
                            self.text_queue.put_nowait(self.text_buffer)
                            FUNCTION_MAP["store_input"](
                                role="bot",
                                input=self.text_buffer
                            )
                            self.text_buffer = ""
            except Exception as e:
                # Best effort: report the failure and keep the loop alive.
                print(f"Error in processing: {e}")
                await asyncio.sleep(0.1)

    async def send_realtime(self) -> None:
        """Send real-time audio data to model."""
        while not self.quit.is_set():
            try:
                data = await self.input_queue.get()
                msg = {"data": data, "mime_type": "audio/pcm"}
                await self.session.send(input=msg)
            except Exception as e:
                print(f"Error in real-time sending: {e}")
                await asyncio.sleep(0.1)

    async def tts(self) -> None:
        """Synthesise queued text turns and queue the audio for playback."""
        while not self.quit.is_set():
            try:
                text = await self.text_queue.get()
                if not text:
                    continue
                # synthesize_speech is a blocking call; run it in a worker
                # thread so audio streaming on the event loop is not stalled.
                response = await asyncio.to_thread(
                    self.tts_config.client.synthesize_speech,
                    input=texttospeech.SynthesisInput(text=text),
                    voice=self.tts_config.voice,
                    audio_config=self.tts_config.audio_config,
                )
                sample_rate, array = self._decode_linear16(response.audio_content)
                if sample_rate is None:
                    sample_rate = self.output_sample_rate
                self.output_queue.put_nowait((sample_rate, array))
            except Exception as e:
                print(f"Error in TTS: {e}")
                await asyncio.sleep(0.1)

    def shutdown(self) -> None:
        """Signal ``start_up`` and the worker loops to stop."""
        self.quit.set()
def reload_json(path):
    """Read the JSON document stored at ``path`` and return the decoded value.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The Python object produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)
# Main Gradio Interface
def registry(name: str, token: str | None = None, **kwargs):
    """Build the Gradio Blocks UI for the voice demo and return it.

    The parameters are accepted but not used by the body; the function is
    passed to ``gr.load`` as its ``src`` argument at the bottom of the file.

    Returns:
        The ``gr.Blocks`` instance containing the "Voice Chat" tab.
    """
    title_html = (
        "\n"
        "<div style='text-align: left'>\n"
        "<h1>ML6 Voice Demo - Function Calling and Custom Output Voice</h1>\n"
        "</div>\n"
    )
    answers_file = "/Users/georgeslorre/ML6/internal/gemini-voice-agents/answers.json"

    with gr.Blocks() as interface:
        with gr.Tabs():
            with gr.TabItem("Voice Chat"):
                gr.HTML(title_html)
                gemini_handler = AsyncGeminiHandler()

                with gr.Row():
                    audio = WebRTC(
                        label="Voice Chat",
                        modality="audio",
                        mode="send-receive",
                    )

                # Side-by-side panels: the static question list and the
                # answers file, which is re-read every second.
                # (A similar "Conversation" panel polling conversation.json
                # was previously present here but is disabled.)
                with gr.Row():
                    with gr.Column():
                        gr.JSON(label="Questions", value=questions_dict)
                    with gr.Column():
                        # Hidden textbox that only carries the file path into
                        # reload_json on each refresh.
                        answers_path = gr.Text(value=answers_file, visible=False)
                        gr.JSON(
                            reload_json,
                            inputs=answers_path,
                            label="Collected Answers",
                            every=1,
                        )

                # The WebRTC component is both the input and the output of
                # the stream handled by gemini_handler.
                audio.stream(
                    gemini_handler,
                    inputs=[audio],
                    outputs=[audio],
                    time_limit=600,
                    concurrency_limit=10,
                )
    return interface
# Function to clear JSON files
def clear_json_files(
    base_dir: str = "/Users/georgeslorre/ML6/internal/gemini-voice-agents",
) -> None:
    """Reset ``conversation.json`` and ``answers.json`` to empty containers.

    ``conversation.json`` is overwritten with an empty list and
    ``answers.json`` with an empty object.

    Args:
        base_dir: Directory that holds the two files. The default is the
            location that used to be hard-coded, so existing calls without
            arguments behave as before; pass another directory to run the
            script on a different machine.

    Raises:
        OSError: If a file cannot be opened for writing (for example when
            ``base_dir`` does not exist).
    """
    empty_documents = (
        ("conversation.json", []),
        ("answers.json", {}),
    )
    for filename, empty_value in empty_documents:
        with open(os.path.join(base_dir, filename), "w", encoding="utf-8") as f:
            json.dump(empty_value, f)
# Start every run from empty conversation/answers files.
clear_json_files()

# Build the interface through gr.load, with `registry` as the source, and
# serve it.
demo = gr.load(name="gemini-2.0-flash-exp", src=registry)
demo.launch()