Spaces:
Runtime error
Runtime error
"""Helper for audio loop.""" | |
import asyncio
import logging
import traceback
import wave
from typing import Any, Callable, Optional

import pyaudio
from google import genai

from models import AudioConfig, ModelConfig
from tools import FUNCTION_MAP
# Logger shared by every class in this module.
logger: logging.Logger = logging.getLogger(name=__name__)
class TextLoop:
    """Interactive text chat with a Gemini Live session.

    Lines typed on stdin are sent to the model, text replies are logged, and
    function calls requested by the model are dispatched to ``function_map``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        function_map: Optional[dict[str, Callable[..., Any]]] = FUNCTION_MAP,
    ):
        """Create the loop and its Gemini client.

        Args:
            model_config (ModelConfig): Model configuration settings.
            function_map (Optional[dict[str, Callable]]): Maps tool names the
                model may call to Python callables. ``None`` falls back to the
                module-level ``FUNCTION_MAP``.
        """
        self.model_config = model_config
        self.function_map = FUNCTION_MAP if function_map is None else function_map
        self.client = self._setup_client()
        self.session = None  # set by ``run`` once the live session is open

    def _setup_client(self) -> genai.Client:
        """Initialize the Gemini client."""
        return genai.Client(
            api_key=self.model_config.api_key,
            http_options={"api_version": "v1alpha"},
        )

    async def send_text(self) -> None:
        """Read lines from stdin and send each one to the model as a turn.

        Typing ``q`` (case-insensitive) or closing stdin (EOF, e.g. Ctrl-D)
        ends the loop. Empty lines are sent as ``"."``. Other errors are
        logged and the loop keeps going.
        """
        while True:
            try:
                text = await asyncio.to_thread(input, "message > ")
                if text.lower() == "q":
                    break
                await self.session.send(input=text or ".", end_of_turn=True)
            except EOFError:
                # input() raises again immediately once stdin is closed, so
                # retrying would spin forever; treat EOF like "q".
                logger.info("Input stream closed; stopping text input.")
                break
            except Exception as e:
                logger.error("Error sending text: %s", e)
                await asyncio.sleep(0.1)

    async def receive_text(self) -> None:
        """Log the model's text replies and run any tools it requests."""
        while True:
            try:
                async for response in self.session.receive():
                    if text := response.text:
                        logger.info("%s", text)
                    if response.tool_call is not None:
                        await self._handle_tool_call(response.tool_call)
            except Exception as e:
                logger.error("Error receiving text: %s", e)
                await asyncio.sleep(0.1)

    async def _handle_tool_call(self, tool_call) -> None:
        """Run each function the model asked for and send back its result.

        Unknown tool names and tools that raise are logged and skipped, so
        one bad call does not abort the rest of the turn.
        """
        for call in tool_call.function_calls:
            func = self.function_map.get(call.name)
            if func is None:
                logger.error("Unknown tool requested: %s", call.name)
                continue
            try:
                # ``args`` may be None for functions that take no arguments.
                result = func(**(call.args or {}))
            except Exception:
                logger.exception("Tool %s failed", call.name)
                continue
            logger.info("%s", result)
            # NOTE(review): the raw result is sent as ordinary input with
            # end_of_turn=True, not as a function response tied to call.id;
            # confirm this is what the SDK version in use expects.
            await self.session.send(input=result, end_of_turn=True)
            await asyncio.sleep(0.1)

    async def run(self) -> None:
        """Open a live session and run the text loops until the user quits.

        Returns after text input ends (``q`` or EOF). Errors other than
        cancellation are logged rather than raised.
        """
        try:
            async with (
                self.client.aio.live.connect(
                    model=self.model_config.name,
                    config={
                        "system_instruction": self.model_config.system_instruction,
                        "tools": self.model_config.tools,
                        "generation_config": self.model_config.generation_config,
                    },
                ) as session,
                asyncio.TaskGroup() as tg,
            ):
                self.session = session
                tasks = [
                    tg.create_task(self.send_text()),
                    tg.create_task(self.receive_text()),
                ]
                await tasks[0]  # Wait for send_text to complete
                # Raising inside the TaskGroup body makes it cancel the
                # remaining tasks; the error is caught below.
                raise asyncio.CancelledError("User requested exit")
        except asyncio.CancelledError:
            logger.info("Shutting down...")
        except Exception as e:
            logger.error("Error in main loop: %s", e)
            logger.debug(traceback.format_exc())
class AudioLoop:
    """Handles real-time audio streaming and processing.

    Runs concurrent tasks that capture microphone audio and stream it to a
    Gemini Live session, play back the model's audio replies, read typed
    text from stdin, and dispatch function calls requested by the model.
    """

    def __init__(
        self,
        audio_config: AudioConfig,
        model_config: ModelConfig,
        function_map: Optional[dict[str, Callable[..., Any]]] = FUNCTION_MAP,
        instruction_audio: Optional[str] = None,
    ):
        """Initialize the audio loop.

        Args:
            audio_config (AudioConfig): Audio configuration settings.
            model_config (ModelConfig): Model configuration settings.
            function_map (Optional[dict[str, Callable]]): Maps tool names the
                model may call to Python callables. ``None`` falls back to the
                module-level ``FUNCTION_MAP``.
            instruction_audio (Optional[str]): Path to a WAV file whose frames
                are queued for sending as soon as the session opens.
        """
        self.audio_config = audio_config
        self.model_config = model_config
        # Both queues are created in ``run`` once the session is open.
        self.audio_in_queue: Optional[asyncio.Queue] = None  # model audio -> speaker
        self.out_queue: Optional[asyncio.Queue] = None  # mic/file audio -> model
        self.session = None
        self.audio_stream = None  # microphone stream opened by listen_audio
        self.client = self._setup_client()
        self.instruction_audio = instruction_audio
        self.function_map = FUNCTION_MAP if function_map is None else function_map

    def _setup_client(self) -> genai.Client:
        """Initialize the Gemini client."""
        return genai.Client(
            api_key=self.model_config.api_key,
            http_options={"api_version": "v1alpha"},
        )

    async def send_text(self) -> None:
        """Read lines from stdin and send each one to the model as a turn.

        Typing ``q`` (case-insensitive) or closing stdin (EOF, e.g. Ctrl-D)
        ends the loop. Empty lines are sent as ``"."``. Other errors are
        logged and the loop keeps going.
        """
        while True:
            try:
                text = await asyncio.to_thread(input, "message > ")
                if text.lower() == "q":
                    break
                await self.session.send(input=text or ".", end_of_turn=True)
            except EOFError:
                # input() raises again immediately once stdin is closed, so
                # retrying would spin forever; treat EOF like "q".
                logger.info("Input stream closed; stopping text input.")
                break
            except Exception as e:
                logger.error("Error sending text: %s", e)
                await asyncio.sleep(0.1)

    async def send_realtime(self) -> None:
        """Forward queued audio messages from ``out_queue`` to the model."""
        while True:
            try:
                msg = await self.out_queue.get()
                await self.session.send(input=msg)
            except Exception as e:
                logger.error("Error in real-time sending: %s", e)
                await asyncio.sleep(0.1)

    def input_audio_file(self, file_path: str) -> None:
        """Queue the full contents of a WAV file for sending to the model.

        All frames are read and queued as one ``audio/pcm`` message. Errors
        (missing file, invalid WAV, full queue) are logged, not raised.
        NOTE(review): the file's sample rate/width are not checked against
        ``audio_config``; presumably they must match send_sample_rate.
        """
        try:
            with wave.open(file_path, "rb") as wave_file:
                data = wave_file.readframes(wave_file.getnframes())
            self.out_queue.put_nowait({"data": data, "mime_type": "audio/pcm"})
        except Exception as e:
            logger.error("Error reading audio file: %s", e)

    async def listen_audio(self) -> None:
        """Capture microphone audio and queue it for :meth:`send_realtime`.

        Runs until cancelled or an error occurs (errors are logged). The
        microphone stream and the PyAudio instance are released on exit.
        """
        pya = None
        try:
            pya = pyaudio.PyAudio()
            mic_info = pya.get_default_input_device_info()
            self.audio_stream = await asyncio.to_thread(
                pya.open,
                format=self.audio_config.format,
                channels=self.audio_config.channels,
                rate=self.audio_config.send_sample_rate,
                input=True,
                input_device_index=mic_info["index"],
                frames_per_buffer=self.audio_config.chunk_size,
            )
            # Don't raise on input overflow, except when running under -O.
            kwargs = {"exception_on_overflow": False} if __debug__ else {}
            while True:
                data = await asyncio.to_thread(
                    self.audio_stream.read,
                    self.audio_config.chunk_size,
                    **kwargs,
                )
                await self.out_queue.put({"data": data, "mime_type": "audio/pcm"})
        except Exception as e:
            logger.error("Error in audio listening: %s", e)
        finally:
            # Clear the attribute so run()'s cleanup does not close it twice.
            # NOTE(review): on cancellation a to_thread read may still be in
            # flight when the stream is closed; PortAudio does not promise
            # that is safe -- consider stopping reads before closing.
            stream, self.audio_stream = self.audio_stream, None
            self._release_audio(stream, pya)

    async def receive_audio(self) -> None:
        """Route model responses: audio to playback, text to the log, tools to dispatch."""
        while True:
            try:
                async for response in self.session.receive():
                    if data := response.data:
                        self.audio_in_queue.put_nowait(data)
                        continue
                    if text := response.text:
                        logger.info("%s", text)
                    if response.tool_call is not None:
                        await self._handle_tool_call(response.tool_call)
                # Clear queue on turn completion: drop audio not yet played.
                # Presumably this lets an interruption stop playback promptly;
                # NOTE(review): it can also cut the tail of a normal reply.
                while not self.audio_in_queue.empty():
                    self.audio_in_queue.get_nowait()
            except Exception as e:
                logger.error("Error receiving audio: %s", e)
                await asyncio.sleep(0.1)

    async def _handle_tool_call(self, tool_call) -> None:
        """Run each function the model asked for and send back its result.

        Unknown tool names and tools that raise are logged and skipped, so
        one bad call does not abort the rest of the turn.
        """
        for call in tool_call.function_calls:
            func = self.function_map.get(call.name)
            if func is None:
                logger.error("Unknown tool requested: %s", call.name)
                continue
            try:
                # ``args`` may be None for functions that take no arguments.
                result = func(**(call.args or {}))
            except Exception:
                logger.exception("Tool %s failed", call.name)
                continue
            logger.info("%s", result)
            # NOTE(review): the raw result is sent as ordinary input with
            # end_of_turn=True, not as a function response tied to call.id;
            # confirm this is what the SDK version in use expects.
            await self.session.send(input=result, end_of_turn=True)
            await asyncio.sleep(0.1)

    async def play_audio(self) -> None:
        """Play audio chunks from ``audio_in_queue`` on the output device.

        Runs until cancelled or an error occurs (errors are logged). The
        output stream and PyAudio instance are released on exit.
        """
        pya = None
        stream = None
        try:
            pya = pyaudio.PyAudio()
            stream = await asyncio.to_thread(
                pya.open,
                format=self.audio_config.format,
                channels=self.audio_config.channels,
                rate=self.audio_config.receive_sample_rate,
                output=True,
            )
            while True:
                bytestream = await self.audio_in_queue.get()
                await asyncio.to_thread(stream.write, bytestream)
        except Exception as e:
            logger.error("Error playing audio: %s", e)
        finally:
            self._release_audio(stream, pya)

    @staticmethod
    def _release_audio(stream, pya) -> None:
        """Close a PyAudio stream and terminate a PyAudio instance.

        Either argument may be ``None``. Errors are logged rather than raised
        so cleanup in a ``finally`` cannot mask the original exception.
        """
        if stream is not None:
            try:
                stream.close()
            except Exception as e:
                logger.error("Error closing audio stream: %s", e)
        if pya is not None:
            try:
                pya.terminate()
            except Exception as e:
                logger.error("Error terminating PyAudio: %s", e)

    async def run(self) -> None:
        """Main execution loop.

        Opens the live session, starts all streaming tasks, and returns once
        text input ends (``q`` or EOF). Errors other than cancellation are
        logged rather than raised; the microphone stream is always released.
        """
        try:
            async with (
                self.client.aio.live.connect(
                    model=self.model_config.name,
                    config={
                        "system_instruction": self.model_config.system_instruction,
                        "tools": self.model_config.tools,
                        "generation_config": self.model_config.generation_config,
                    },
                ) as session,
                asyncio.TaskGroup() as tg,
            ):
                self.session = session
                self.audio_in_queue = asyncio.Queue()
                # Bounded, so listen_audio waits (back-pressure) when sending
                # falls behind instead of buffering without limit.
                self.out_queue = asyncio.Queue(maxsize=5)
                if self.instruction_audio:
                    self.input_audio_file(file_path=self.instruction_audio)
                tasks = [
                    tg.create_task(self.send_text()),
                    tg.create_task(self.send_realtime()),
                    tg.create_task(self.listen_audio()),
                    tg.create_task(self.receive_audio()),
                    tg.create_task(self.play_audio()),
                ]
                await tasks[0]  # Wait for send_text to complete
                # Raising inside the TaskGroup body makes it cancel the
                # remaining tasks; the error is caught below.
                raise asyncio.CancelledError("User requested exit")
        except asyncio.CancelledError:
            logger.info("Shutting down...")
        except Exception as e:
            logger.error("Error in main loop: %s", e)
            logger.debug(traceback.format_exc())
        finally:
            # Safety net in case listen_audio did not release the stream.
            stream, self.audio_stream = self.audio_stream, None
            self._release_audio(stream, None)