pradeepsengarr committed
Commit 1d02232 · verified · 1 Parent(s): 8ae45bd

Update app.py

Files changed (1): app.py (+61 -53)
app.py CHANGED
@@ -1,65 +1,73 @@
- import streamlit as st
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import whisper
- from streamlit_webrtc import webrtc_streamer, AudioProcessorBase
  import torch

- # ----------------------------- SETUP -----------------------------
- st.set_page_config(page_title="🧠 Talkative AI Bot", layout="centered")

- # ----------------------------- LOAD MODELS -----------------------------
- # Load Whisper model for speech-to-text
- @st.cache_resource
- def load_whisper():
-     try:
-         model = whisper.load_model("base")
-         return model
-     except Exception as e:
-         st.error(f"An error occurred while loading Whisper model: {e}")
-         return None

- # Load DistilGPT-2 model for generating responses
- @st.cache_resource
- def load_language_model():
      try:
-         tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-         model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-         return model, tokenizer
-     except Exception as e:
-         st.error(f"An error occurred while loading Language model: {e}")
-         return None, None

- # ----------------------------- FUNCTION TO HANDLE SPEECH -----------------------------
- class AudioProcessor(AudioProcessorBase):
-     def __init__(self):
-         self.whisper_model = load_whisper()

-     def transform(self, audio_frame):
-         # Convert audio frame to audio file and get text transcription
-         result = self.whisper_model.transcribe(audio_frame)
-         return result['text']

- # ----------------------------- FUNCTION TO GENERATE RESPONSE -----------------------------
- def generate_response(user_input):
-     model, tokenizer = load_language_model()
-     if model and tokenizer:
-         inputs = tokenizer(user_input, return_tensors="pt")
-         outputs = model.generate(inputs['input_ids'], max_length=100)
-         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return response
-     return "Sorry, I couldn't process that."

- # ----------------------------- STREAMLIT UI -----------------------------
- st.title("🧠 Talkative AI Bot")
- st.write("Talk to the bot using your microphone, and it will respond!")

- # Streamlit WebRTC for speech-to-text
- webrtc_streamer(key="example", audio_processor_factory=AudioProcessor)

- # Input text for chatbot
- user_input = st.text_input("Type something for the bot:")

- # Handle text input and generate response
- if user_input:
-     response = generate_response(user_input)
-     st.write(f"Bot: {response}")
+ # app.py

  import torch
+ import whisper
+ import gradio as gr
+ from gtts import gTTS
+ from pydub import AudioSegment
+ import tempfile
+ import os
+ from transformers import AutoTokenizer, AutoModelForCausalLM

+ # Load Whisper model
+ whisper_model = whisper.load_model("base")

+ # Load Qwen model
+ model_name = "Qwen/Qwen2.5-1.5B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=True).to("cuda" if torch.cuda.is_available() else "cpu")
+
+ print(f"Model loaded on: {'GPU' if next(model.parameters()).is_cuda else 'CPU'}")

+ def respond(prompt_text, audio_file):
+     transcription = None
      try:
+         if prompt_text and prompt_text.strip():
+             final_prompt = prompt_text.strip()
+         elif audio_file:
+             sound = AudioSegment.from_file(audio_file)
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpwav:
+                 sound.export(tmpwav.name, format="wav")
+             transcription = whisper_model.transcribe(tmpwav.name)["text"]
+             final_prompt = transcription
+         else:
+             return "No prompt provided", "", None
+
+         inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
+         outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
+         text_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

+         tts = gTTS(text_response)
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
+             tts.save(fp.name)
+         audio_path = fp.name

+         return transcription if transcription else "Typed input used", text_response, audio_path
+
+     except Exception as e:
+         return f"Error: {str(e)}", "", None

+ with gr.Blocks(theme=gr.themes.Soft(), title="Chat with Vidhya") as demo:
+     gr.Markdown("""
+     # 🧠 Chat with Vidhya
+     **An AI assistant that understands your voice or typed input, and responds in speech + text.**
+
+     💡 Try asking about:
+     - Technology trends
+     - Motorbikes & automobiles
+     - Finance and money tips
+     - Gaming news or strategies
+     """)

+     with gr.Row():
+         txt_input = gr.Textbox(lines=2, label="Type your prompt (optional)")
+         audio_input = gr.Audio(type="filepath", label="Or speak your prompt")

+     with gr.Row():
+         transcription_output = gr.Textbox(label="Transcribed Speech")
+         text_output = gr.Textbox(label="Model's Response")
+         audio_output = gr.Audio(type="filepath", label="Spoken Response")

+     submit_btn = gr.Button("Submit")
+     submit_btn.click(fn=respond, inputs=[txt_input, audio_input], outputs=[transcription_output, text_output, audio_output])

+ demo.launch()
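
Since respond() is a plain function returning a (transcription, text, audio path) triple, the new version can be smoke-tested without opening the Gradio UI. A minimal sketch, not part of the commit, assuming it is placed before demo.launch() and the Whisper/Qwen downloads above succeed:

# Sketch only: quick local check of respond() with a typed prompt (no audio input).
transcription, reply, audio_path = respond("Share one money tip.", None)
print(transcription)  # "Typed input used" when the text box supplies the prompt
print(reply)          # Qwen's generated text response
print(audio_path)     # path to the gTTS .mp3, or None if an error occurred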