# Voice_Assistant/app.py
import torch
import gradio as gr
import whisper
from gtts import gTTS
from pydub import AudioSegment
import tempfile
import os
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
# Load Whisper model
whisper_model = whisper.load_model("base")
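# "base" is the small multilingual checkpoint; "small", "medium", or "large"
# trade latency for accuracy if the hardware has the headroom.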
# Load mBART
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
# Target language
TARGET_LANG = "hi_IN" # Hindi
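# mBART-50 uses locale-style codes; "fr_XX" (French), "ta_IN" (Tamil), or
# "bn_IN" (Bengali) should also work here. If you change TARGET_LANG, update
# the gTTS lang code below to match (gTTS language coverage may vary).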
def respond(prompt_text, audio_file):
    transcription = None
    try:
        if prompt_text and prompt_text.strip():
            final_prompt = prompt_text.strip()
        elif audio_file:
            # Normalize the recording to WAV before transcription.
            sound = AudioSegment.from_file(audio_file)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpwav:
                sound.export(tmpwav.name, format="wav")
            transcription = whisper_model.transcribe(tmpwav.name)["text"].strip()
            os.unlink(tmpwav.name)  # clean up the temporary WAV (delete=False leaks it otherwise)
            final_prompt = transcription
        else:
            return "No prompt provided", "", None
        # Translate the English prompt to Hindi with mBART-50.
        # forced_bos_token_id pins the decoder's first token to the target
        # language code; this is how mBART-50 selects the output language.
        tokenizer.src_lang = "en_XX"
        encoded = tokenizer(final_prompt, return_tensors="pt").to(model.device)
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[TARGET_LANG],
            max_new_tokens=100,
        )
        translated = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        # Synthesize the Hindi text to speech with gTTS.
        tts = gTTS(translated, lang="hi")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
            audio_path = fp.name
        tts.save(audio_path)  # save after the handle is closed (portable across OSes)
        return transcription if transcription else "Typed input used", translated, audio_path
except Exception as e:
return f"Error: {str(e)}", "", None
with gr.Blocks(theme=gr.themes.Soft(), title="Chat with Vidhya") as demo:
gr.Markdown("""
# 🧠 Chat with Vidhya
**An AI assistant that listens to your voice or reads your text, translates it to Hindi, and speaks the reply aloud.**
""")
with gr.Row():
txt_input = gr.Textbox(lines=2, label="Type your prompt (optional)")
audio_input = gr.Audio(type="filepath", label="Or speak your prompt")
with gr.Row():
transcription_output = gr.Textbox(label="Transcribed Speech")
        text_output = gr.Textbox(label="Hindi Translation")
audio_output = gr.Audio(type="filepath", label="Spoken Response")
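    # type="filepath" keeps the audio components exchanging paths on disk,
    # which matches what respond() accepts and returns.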
submit_btn = gr.Button("Submit")
submit_btn.click(fn=respond, inputs=[txt_input, audio_input], outputs=[transcription_output, text_output, audio_output])
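# Note: on a hosted Space, launch() picks the host/port automatically; when
# running in a container you may need demo.launch(server_name="0.0.0.0").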
if __name__ == "__main__":
    demo.launch()