File size: 2,681 Bytes
8ae45bd
1d02232
9d83366
1d02232
 
 
 
80ffd7b
ea38126
1d02232
 
ea38126
80ffd7b
9d83366
80ffd7b
9d83366
1d02232
80ffd7b
9d83366
1fd24ed
1d02232
 
8ae45bd
1d02232
 
 
 
 
 
 
 
 
 
 
80ffd7b
9d83366
 
 
 
148e587
80ffd7b
9d83366
1d02232
 
 
8ae45bd
9d83366
1d02232
 
 
148e587
1d02232
 
 
9d83366
1d02232
148e587
1d02232
 
 
148e587
1d02232
 
 
 
148e587
1d02232
 
148e587
1d02232
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import gradio as gr
import whisper
from gtts import gTTS
from pydub import AudioSegment
import tempfile
import os
from transformers import MBartForConditionalGeneration, MBart50Tokenizer

# Load Whisper model
# "base" is the small multilingual checkpoint; downloaded on first run.
whisper_model = whisper.load_model("base")

# Load mBART
# Many-to-many multilingual translation model; moved to GPU when available.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")

# Target language
# mBART-50 language code used as the forced BOS token during generation.
TARGET_LANG = "hi_IN"  # Hindi

def respond(prompt_text, audio_file):
    """Handle one chat turn: translate the prompt to Hindi and speak it.

    Typed text takes precedence; otherwise the audio file is transcribed
    with Whisper. The resulting English prompt is translated to Hindi by
    mBART and synthesized to an mp3 with gTTS.

    Args:
        prompt_text: Optional typed prompt string (may be None/empty).
        audio_file: Optional filesystem path to an uploaded/recorded clip.

    Returns:
        (transcription_or_status, translated_text, audio_path); on failure
        the first element carries the error message and the rest are "", None.
    """
    transcription = None
    try:
        if prompt_text and prompt_text.strip():
            final_prompt = prompt_text.strip()
        elif audio_file:
            # Normalize arbitrary input formats to WAV for Whisper.
            sound = AudioSegment.from_file(audio_file)
            # Create the temp file closed before writing: on Windows a
            # NamedTemporaryFile cannot be reopened while still held open.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpwav:
                wav_path = tmpwav.name
            try:
                sound.export(wav_path, format="wav")
                transcription = whisper_model.transcribe(wav_path)["text"]
            finally:
                # Fix: the original leaked this delete=False temp file.
                os.remove(wav_path)
            final_prompt = transcription.strip()
            if not final_prompt:
                # Silence / unintelligible audio — don't feed mBART "".
                return "No speech detected in audio", "", None
        else:
            return "No prompt provided", "", None

        # Translate English -> Hindi with mBART.
        tokenizer.src_lang = "en_XX"  # NOTE(review): assumes the prompt is English
        encoded = tokenizer(final_prompt, return_tensors="pt").to(model.device)
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[TARGET_LANG],
            max_new_tokens=100,
        )
        translated = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

        # TTS: the mp3 is intentionally NOT deleted here — Gradio serves it
        # from disk after this function returns.
        tts = gTTS(translated, lang='hi')
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
            audio_path = fp.name
        tts.save(audio_path)

        return transcription if transcription else "Typed input used", translated, audio_path

    except Exception as e:
        # UI boundary: surface any failure in the first output box rather
        # than crashing the Gradio worker.
        return f"Error: {str(e)}", "", None

# UI layout: two input widgets (typed text or microphone), three output
# widgets (transcription, translated text, spoken reply), one submit button.
with gr.Blocks(theme=gr.themes.Soft(), title="Chat with Vidhya") as demo:
    gr.Markdown("""
        # 🧠 Chat with Vidhya
        **An AI assistant that listens to your voice or reads your text, and responds in your language.**
    """)

    with gr.Row():
        typed_prompt = gr.Textbox(lines=2, label="Type your prompt (optional)")
        spoken_prompt = gr.Audio(type="filepath", label="Or speak your prompt")

    with gr.Row():
        heard_text = gr.Textbox(label="Transcribed Speech")
        reply_text = gr.Textbox(label="Model's Response")
        reply_audio = gr.Audio(type="filepath", label="Spoken Response")

    run_button = gr.Button("Submit")
    run_button.click(
        fn=respond,
        inputs=[typed_prompt, spoken_prompt],
        outputs=[heard_text, reply_text, reply_audio],
    )

demo.launch()