# Gradio app: typed or spoken English prompt -> Hindi translation with spoken (TTS) output.
# Pipeline: Whisper (speech-to-text) -> mBART-50 (en->hi translation) -> gTTS (text-to-speech).
import torch
import gradio as gr
import whisper
from gtts import gTTS
from pydub import AudioSegment
import tempfile
import os
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
# Load the Whisper "base" checkpoint for speech-to-text transcription.
whisper_model = whisper.load_model("base")
# Load the mBART-50 many-to-many translation model and its tokenizer.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
# Move the model to GPU when available; model.device is read later in respond().
model = MBartForConditionalGeneration.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
# mBART language code for the translation target.
TARGET_LANG = "hi_IN" # Hindi
def respond(prompt_text, audio_file):
    """Translate an English prompt (typed or spoken) to Hindi and voice it.

    Typed text takes precedence over audio. Spoken input is transcribed
    with Whisper, then the prompt is translated to Hindi with mBART-50 and
    synthesized to speech with gTTS.

    Parameters
    ----------
    prompt_text : str | None
        Optional typed prompt; used when non-blank.
    audio_file : str | None
        Optional path to a recorded audio file (any format pydub can read).

    Returns
    -------
    tuple[str, str, str | None]
        (transcription or status message, Hindi translation, path to an MP3
        of the spoken translation). On failure: ("Error: ...", "", None).
    """
    transcription = None
    try:
        if prompt_text and prompt_text.strip():
            final_prompt = prompt_text.strip()
        elif audio_file:
            # Normalize whatever the browser recorded to WAV for Whisper.
            sound = AudioSegment.from_file(audio_file)
            # Close the handle before writing to the path so this also works
            # on Windows, where an open NamedTemporaryFile can't be reopened.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpwav:
                wav_path = tmpwav.name
            try:
                sound.export(wav_path, format="wav")
                transcription = whisper_model.transcribe(wav_path)["text"]
            finally:
                # The original leaked this delete=False temp file on every call.
                os.remove(wav_path)
            final_prompt = transcription
        else:
            return "No prompt provided", "", None

        # Translate English -> Hindi with mBART-50.
        tokenizer.src_lang = "en_XX"
        encoded = tokenizer(final_prompt, return_tensors="pt").to(model.device)
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[TARGET_LANG],
            max_new_tokens=100,
        )
        translated = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

        # Synthesize speech; reserve the path first, write after the handle
        # is closed (Windows-safe). Gradio serves the file from this path,
        # so it intentionally outlives the function (delete=False).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
            audio_path = fp.name
        tts = gTTS(translated, lang='hi')
        tts.save(audio_path)

        return transcription if transcription else "Typed input used", translated, audio_path
    except Exception as e:
        # Surface the failure in the UI instead of crashing the request.
        return f"Error: {str(e)}", "", None
# --- Gradio front-end: wire the UI widgets to respond() and launch. ---
with gr.Blocks(theme=gr.themes.Soft(), title="Chat with Vidhya") as demo:
    gr.Markdown("""
# 🧠 Chat with Vidhya
**An AI assistant that listens to your voice or reads your text, and responds in your language.**
""")
    # Input row: a typed prompt and/or a recorded one.
    with gr.Row():
        prompt_box = gr.Textbox(lines=2, label="Type your prompt (optional)")
        mic_input = gr.Audio(type="filepath", label="Or speak your prompt")
    # Output row: what was heard, the translated reply, and its audio.
    with gr.Row():
        heard_box = gr.Textbox(label="Transcribed Speech")
        reply_box = gr.Textbox(label="Model's Response")
        spoken_out = gr.Audio(type="filepath", label="Spoken Response")
    send_button = gr.Button("Submit")
    send_button.click(
        fn=respond,
        inputs=[prompt_box, mic_input],
        outputs=[heard_box, reply_box, spoken_out],
    )
demo.launch()