# Whshhs / app.py
# Last commit c02bb52 by Athspi: "Update app.py" (verified), 4.94 kB
import gradio as gr
import asyncio
import numpy as np
from google import genai
from google.genai import types
import soundfile as sf
import io
# Configuration
# Sample rate (Hz) attached to the audio handed back to the UI.
SAMPLE_RATE: int = 24000
# Gemini model opened for each live text-to-speech session (experimental release).
MODEL: str = "gemini-2.0-flash-exp"
class GeminiTTS:
    """Text-to-speech wrapper around a Gemini Live API session.

    Each call to :meth:`text_to_speech` opens a fresh live session, sends the
    text as a single user turn and collects the spoken reply.
    """

    def __init__(self, api_key):
        """Create the Gemini client and the live-session configuration.

        Args:
            api_key: Gemini API key; must be non-empty.

        Raises:
            ValueError: If ``api_key`` is empty or ``None``.
        """
        if not api_key:
            raise ValueError("API key cannot be empty")
        # Live sessions are requested on the v1alpha API version.
        self.client = genai.Client(http_options={"api_version": "v1alpha"}, api_key=api_key)
        self.config = types.LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Puck")
                )
            ),
            # Ask the model to read the input back verbatim rather than answer it.
            system_instruction=types.Content(
                parts=[types.Part.from_text(text="Speak exactly what the user says")],
                role="user"
            ),
        )

    async def text_to_speech(self, text):
        """Synthesize *text* through a Gemini live session.

        Args:
            text: Text to speak. An empty value is sent as a single space.

        Returns:
            ``(SAMPLE_RATE, int16 ndarray)`` when audio was received;
            the model's text when only text came back;
            ``None`` when nothing came back;
            or an ``"Error: ..."`` string if the request failed.
        """
        audio_chunks = []
        text_parts = []
        try:
            async with self.client.aio.live.connect(model=MODEL, config=self.config) as session:
                await session.send(input=text or " ", end_of_turn=True)
                # Audio arrives as many small chunks. Keep all of them instead of
                # returning after the first one (which held only a sliver of speech).
                async for response in session.receive():
                    if chunk := response.data:
                        audio_chunks.append(chunk)
                    elif reply := response.text:
                        text_parts.append(reply)
                    # Stop explicitly once the model signals the end of its turn.
                    server_content = getattr(response, "server_content", None)
                    if server_content is not None and server_content.turn_complete:
                        break
        except Exception as e:
            # UI boundary: surface the failure as a message instead of crashing.
            return f"Error: {e}"

        if audio_chunks:
            return self._create_audio_response(self._pcm16_to_array(audio_chunks))
        if text_parts:
            return "".join(text_parts)
        return None

    @staticmethod
    def _pcm16_to_array(chunks):
        """Decode raw 16-bit little-endian PCM byte chunks into an int16 array.

        Chunks are joined before decoding so a sample split across two chunks
        is reassembled correctly.
        """
        pcm = b"".join(chunks)
        # Drop a dangling odd byte; frombuffer would raise on a partial sample.
        pcm = pcm[: len(pcm) - len(pcm) % 2]
        return np.frombuffer(pcm, dtype="<i2").astype(np.int16)

    def _create_audio_response(self, audio_array):
        """Package decoded samples as the ``(sample_rate, ndarray)`` tuple gr.Audio accepts.

        Args:
            audio_array: int16 sample array at ``SAMPLE_RATE``.

        Returns:
            ``(SAMPLE_RATE, audio_array)``; 0.5 s of silence replaces an empty array.
        """
        if audio_array.size == 0:
            audio_array = np.zeros(int(SAMPLE_RATE * 0.5), dtype=np.int16)
        # gr.Audio output accepts (rate, numpy array), not (rate, WAV bytes).
        return (SAMPLE_RATE, audio_array)
def create_interface():
    """Build the Gradio Blocks UI for Gemini text-to-speech.

    The UI has two parts:
    - an API-key row that creates a ``GeminiTTS`` engine;
    - a text box plus button that synthesizes speech with that engine.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    # Engine shared by the callbacks below; None until initialization succeeds.
    tts_engine = None

    def init_engine(api_key):
        """Create the TTS engine from *api_key* and return a status message."""
        nonlocal tts_engine
        try:
            tts_engine = GeminiTTS(api_key)
            return "✅ TTS Initialized Successfully"
        except Exception as e:
            # Drop any previous engine so a reported failure really leaves the
            # app uninitialized instead of silently reusing an old key.
            tts_engine = None
            return f"❌ Initialization Failed: {e}"

    async def generate_speech(text):
        """Synthesize *text*; return ``(audio, message)`` for the two outputs.

        Raises:
            gr.Error: If the engine has not been initialized.
        """
        if tts_engine is None:
            raise gr.Error("Please initialize the TTS first")
        result = await tts_engine.text_to_speech(text)
        if isinstance(result, str):
            # Error text (or a text-only reply): show it, no audio.
            return None, result
        if result:
            return result, ""
        return None, "No response received"

    with gr.Blocks(title="Gemini TTS") as app:
        gr.Markdown("# 🎤 Gemini Text-to-Speech")

        with gr.Row():
            api_key = gr.Textbox(
                label="API Key",
                type="password",
                placeholder="Enter your Gemini API key",
            )
            init_btn = gr.Button("Initialize")
        init_status = gr.Textbox(label="Status", interactive=False)
        init_btn.click(init_engine, inputs=api_key, outputs=init_status)

        with gr.Group():
            text_input = gr.Textbox(
                label="Input Text",
                lines=3,
                placeholder="Type something to speak...",
            )
            generate_btn = gr.Button("Generate Speech")
            audio_output = gr.Audio(label="Output Audio", type="filepath")
            text_output = gr.Textbox(label="Messages", interactive=False)
        generate_btn.click(
            generate_speech,
            inputs=text_input,
            outputs=[audio_output, text_output],
        )

    return app
if __name__ == "__main__":
    # Listen on every interface at port 7860 so the UI is reachable from outside the host process.
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)