Spaces:
Build error
Build error
import os | |
import gradio as gr | |
from fastrtc import Stream, AdditionalOutputs | |
from fastrtc_walkie_talkie import WalkieTalkie | |
# Import your custom models | |
from tts import tortoise_tts, TortoiseOptions | |
from stt import whisper_stt | |
import cohereAPI | |
# --- Configuration -----------------------------------------------------------
# Cohere API key read from the environment; None when the variable is unset.
COHERE_API_KEY = os.environ.get("COHERE_API_KEY")

# System prompt sent to Cohere with every user turn.
system_message = "You respond concisely, in about 15 words or less"

# Chat history passed to cohereAPI.send_message and replaced with the updated
# history it returns. A single module-level list, so every caller of
# response() shares it.
conversation_history = list()
# FastRTC handler: Whisper speech-to-text -> Cohere chat -> Tortoise text-to-speech.
def response(audio):
    """Answer one completed user turn with a spoken reply.

    FastRTC calls this generator (through ``WalkieTalkie``) for each user
    turn. It yields, in order:

    1. ``AdditionalOutputs(user_message)``: the Whisper transcription, which
       the Stream's ``additional_outputs_handler`` adds to the
       "Transcription" textbox;
    2. every audio chunk produced by ``tortoise_tts.stream_tts_sync`` for the
       assistant's reply.

    If the transcription is empty or whitespace-only, the generator ends
    immediately. Nothing is yielded, Cohere is not called, and the
    conversation history is left unchanged.

    Args:
        audio: The user's captured audio, passed unchanged to
            ``whisper_stt.stt``. Its exact format is whatever FastRTC hands
            the handler and is not visible in this file.

    Side effects:
        Rebinds the module-level ``conversation_history`` to the history
        returned by ``cohereAPI.send_message`` and prints the reply.
    """
    # NOTE(review): conversation_history is one module-level list, so every
    # connected session would share a single chat history. A multi-user
    # deployment needs per-session state instead.
    global conversation_history

    user_message = whisper_stt.stt(audio)
    # Silence or noise can transcribe to an empty string. Don't send an empty
    # prompt to Cohere (wasting a request and adding an empty turn to the
    # shared history), and don't synthesize speech for its reply.
    if not user_message or not user_message.strip():
        return

    # Surface the transcription in the UI before the LLM and TTS calls run.
    yield AdditionalOutputs(user_message)

    response_text, updated_history = cohereAPI.send_message(
        system_message,
        user_message,
        conversation_history,
        COHERE_API_KEY,
    )
    conversation_history = updated_history
    print(f"Assistant: {response_text}")

    # Forward the synthesized reply to the client chunk by chunk.
    tts_options = TortoiseOptions(voice_preset="random")
    yield from tortoise_tts.stream_tts_sync(response_text, tts_options)
def _append_transcript(old, new):
    """Append one transcribed user turn to the "Transcription" textbox.

    Used as the Stream's ``additional_outputs_handler``. ``old`` is the
    textbox's current value and ``new`` is the text ``response`` wrapped in
    ``AdditionalOutputs``. Every turn is labelled ``User:``. ``not old``
    treats both ``None`` and an empty textbox as "no transcript yet", so the
    transcript never starts with a blank line.
    """
    if not old:
        return f"User: {new}"
    return f"{old}\nUser: {new}"


# Create the FastRTC stream. WalkieTalkie (the user says "over" to end a turn)
# is used for turn detection instead of ReplyOnPause.
stream = Stream(
    handler=WalkieTalkie(response),
    modality="audio",
    mode="send-receive",
    additional_outputs=[gr.Textbox(label="Transcription")],
    additional_outputs_handler=_append_transcript,
    ui_args={
        "title": "Voice Assistant (Walkie-Talkie Style)",
        "subtitle": "Say 'over' to finish your turn. For example, 'What's the weather like today? over.'",
    },
)
# Entry point: serve the Stream's Gradio UI when this file is run directly.
if __name__ == "__main__":
    launch_kwargs = {
        # Bind to all network interfaces, not just localhost.
        "server_name": "0.0.0.0",
        # Do not create a public Gradio share link.
        "share": False,
        # Surface handler exceptions in the browser UI.
        "show_error": True,
    }
    stream.ui.launch(**launch_kwargs)