import torch
import gradio as gr
import whisper
from gtts import gTTS
from pydub import AudioSegment
import tempfile
import os
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
# Load the Whisper speech-to-text model
whisper_model = whisper.load_model("base")
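# "base" keeps startup fast; larger checkpoints ("small", "medium", "large")
# transcribe more accurately at the cost of memory and latency.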
# Load the mBART-50 many-to-many translation model
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
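# Note: mbart-large-50 is a ~610M-parameter model; expect slow generation on CPU.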
# Target language for translation and speech output
TARGET_LANG = "hi_IN"  # Hindi
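# mBART-50 selects the output language via codes like this one; any of its 50
# supported codes (e.g. fr_XX, de_DE, ta_IN) could be swapped in here.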
def respond(prompt_text, audio_file):
    transcription = None
    try:
        # Prefer typed text; fall back to transcribing the recorded audio.
        if prompt_text and prompt_text.strip():
            final_prompt = prompt_text.strip()
        elif audio_file:
            # Convert whatever format Gradio hands us to WAV for Whisper.
            sound = AudioSegment.from_file(audio_file)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpwav:
                sound.export(tmpwav.name, format="wav")
            transcription = whisper_model.transcribe(tmpwav.name)["text"]
            os.remove(tmpwav.name)  # clean up the intermediate WAV
            final_prompt = transcription
        else:
            return "No prompt provided", "", None

        # Translate the prompt to the target language with mBART.
        # forced_bos_token_id makes the decoder start with the target-language
        # code, which is how mBART-50 selects the output language.
        tokenizer.src_lang = "en_XX"
        encoded = tokenizer(final_prompt, return_tensors="pt").to(model.device)
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[TARGET_LANG],
            max_new_tokens=100,
        )
        translated = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

        # Synthesize the translation as speech.
        # gTTS takes ISO 639-1 codes; keep this in sync with TARGET_LANG.
        tts = gTTS(translated, lang="hi")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
            tts.save(fp.name)
            audio_path = fp.name

        return transcription if transcription else "Typed input used", translated, audio_path
    except Exception as e:
        return f"Error: {str(e)}", "", None
with gr.Blocks(theme=gr.themes.Soft(), title="Chat with Vidhya") as demo:
    gr.Markdown("""
    # 🧠 Chat with Vidhya
    **An AI assistant that listens to your voice or reads your text, and replies in Hindi as text and speech.**
    """)
    with gr.Row():
        txt_input = gr.Textbox(lines=2, label="Type your prompt (optional)")
        audio_input = gr.Audio(type="filepath", label="Or speak your prompt")
    with gr.Row():
        transcription_output = gr.Textbox(label="Transcribed Speech")
        text_output = gr.Textbox(label="Model's Response")
        audio_output = gr.Audio(type="filepath", label="Spoken Response")
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=respond,
        inputs=[txt_input, audio_input],
        outputs=[transcription_output, text_output, audio_output],
    )
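# When running locally (outside Hugging Face Spaces), demo.launch(share=True)
# additionally serves a temporary public URL.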
demo.launch()