Update app.py
app.py CHANGED
@@ -3,61 +3,62 @@
 ##########################################
 import streamlit as st  # For web interface
 from transformers import (
-    pipeline,
-    SpeechT5Processor,
-    SpeechT5ForTextToSpeech,
-    SpeechT5HifiGan,
-    AutoModelForCausalLM,
-    AutoTokenizer
+    pipeline,  # For loading pre-trained models
+    SpeechT5Processor,  # For text-to-speech processing
+    SpeechT5ForTextToSpeech,  # TTS model
+    SpeechT5HifiGan,  # Vocoder for generating audio waveforms
+    AutoModelForCausalLM,  # For text generation
+    AutoTokenizer  # For tokenizing input text
 )  # AI model components
-
-import
-import
-import
+
+from datasets import load_dataset  # To load voice embeddings
+import torch  # For tensor computations
+import soundfile as sf  # For handling audio files
+import re  # For regular expressions in text processing
 
 ##########################################
 # Initial configuration (MUST be first)
 ##########################################
 st.set_page_config(
-    page_title="Just Comment",
-    page_icon="💬",
-    layout="centered",
-    initial_sidebar_state="collapsed"
+    page_title="Just Comment",  # Title of the web app
+    page_icon="💬",  # Icon displayed in the browser tab
+    layout="centered",  # Center the layout of the app
+    initial_sidebar_state="collapsed"  # Start with sidebar collapsed
 )
 
 ##########################################
 # Global model loading with caching
 ##########################################
-@st.cache_resource(show_spinner=False)
+@st.cache_resource(show_spinner=False)  # Cache the models for performance
 def _load_models():
     """Load and cache all ML models with optimized settings"""
     return {
         # Emotion classification pipeline
         'emotion': pipeline(
-            "text-classification",
-            model="Thea231/jhartmann_emotion_finetuning",
+            "text-classification",  # Specify task type
+            model="Thea231/jhartmann_emotion_finetuning",  # Load the model
             truncation=True  # Enable text truncation for long inputs
         ),
 
         # Text generation components
         'textgen_tokenizer': AutoTokenizer.from_pretrained(
-            "Qwen/Qwen1.5-0.5B",
+            "Qwen/Qwen1.5-0.5B",  # Load tokenizer
             use_fast=True  # Enable fast tokenization
         ),
         'textgen_model': AutoModelForCausalLM.from_pretrained(
-            "Qwen/Qwen1.5-0.5B",
+            "Qwen/Qwen1.5-0.5B",  # Load text generation model
             torch_dtype=torch.float16  # Use half-precision for faster inference
         ),
 
         # Text-to-speech components
-        'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"),
-        'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"),
-        'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),
+        'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"),  # Load TTS processor
+        'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"),  # Load TTS model
+        'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),  # Load vocoder
 
         # Preloaded speaker embeddings
         'speaker_embeddings': torch.tensor(
-            load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
-        ).unsqueeze(0)
+            load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]  # Load speaker embeddings
+        ).unsqueeze(0)  # Add an additional dimension for batch processing
     }
 
 ##########################################
@@ -65,14 +66,14 @@ def _load_models():
 ##########################################
 def _display_interface():
     """Render user interface elements"""
-    st.title("Just Comment")
-    st.markdown("### I'm listening to you, my friend～")
+    st.title("Just Comment")  # Set the main title of the app
+    st.markdown("### I'm listening to you, my friend～")  # Subheading for user interaction
 
     return st.text_area(
-        "📝 Enter your comment:",
-        placeholder="Type your message here...",
-        height=150,
-        key="user_input"
+        "📝 Enter your comment:",  # Label for the text area
+        placeholder="Type your message here...",  # Placeholder text
+        height=150,  # Height of the text area
+        key="user_input"  # Unique key for the text area
     )
 
 ##########################################
@@ -80,10 +81,10 @@ def _display_interface():
 ##########################################
 def _analyze_emotion(text, classifier):
     """Identify dominant emotion with confidence threshold"""
-    results = classifier(text, return_all_scores=True)[0]
-    valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
-    filtered = [e for e in results if e['label'].lower() in valid_emotions]
-    return max(filtered, key=lambda x: x['score'])
+    results = classifier(text, return_all_scores=True)[0]  # Get emotion scores
+    valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}  # Define valid emotions
+    filtered = [e for e in results if e['label'].lower() in valid_emotions]  # Filter results by valid emotions
+    return max(filtered, key=lambda x: x['score'])  # Return the emotion with the highest score
 
 def _generate_prompt(text, emotion):
     """Create structured prompts for all emotion types"""
@@ -125,16 +126,16 @@ def _generate_prompt(text, emotion):
         "Response:"
     )
     }
-    return prompt_templates.get(emotion.lower(), "").format(input=text)
+    return prompt_templates.get(emotion.lower(), "").format(input=text)  # Format and return the appropriate prompt
 
 def _process_response(raw_text):
-    """Clean and format generated response"""
+    """Clean and format the generated response"""
     # Extract text after last "Response:" marker
     processed = raw_text.split("Response:")[-1].strip()
 
     # Remove incomplete sentences
     if '.' in processed:
-        processed = processed.rsplit('.', 1)[0] + '.'
+        processed = processed.rsplit('.', 1)[0] + '.'  # Ensure the response ends with a period
 
     # Ensure length between 50-200 characters
     return processed[:200].strip() if len(processed) > 50 else "Thank you for your feedback. We value your input and will respond shortly."
@@ -142,44 +143,44 @@ def _process_response(raw_text):
 def _generate_text_response(input_text, models):
     """Generate optimized text response with timing controls"""
     # Emotion analysis
-    emotion = _analyze_emotion(input_text, models['emotion'])
+    emotion = _analyze_emotion(input_text, models['emotion'])  # Analyze the emotion of user input
 
     # Prompt engineering
-    prompt = _generate_prompt(input_text, emotion['label'])
+    prompt = _generate_prompt(input_text, emotion['label'])  # Generate prompt based on detected emotion
 
     # Text generation with optimized parameters
-    inputs = models['textgen_tokenizer'](prompt, return_tensors="pt").to('cpu')
+    inputs = models['textgen_tokenizer'](prompt, return_tensors="pt").to('cpu')  # Tokenize the prompt
     outputs = models['textgen_model'].generate(
-        inputs.input_ids,
-        max_new_tokens=100,  # Strict token limit
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True,
-        pad_token_id=models['textgen_tokenizer'].eos_token_id
+        inputs.input_ids,  # Input token IDs
+        max_new_tokens=100,  # Strict token limit for response length
+        temperature=0.7,  # Control randomness in text generation
+        top_p=0.9,  # Control diversity in sampling
+        do_sample=True,  # Enable sampling to generate varied responses
+        pad_token_id=models['textgen_tokenizer'].eos_token_id  # Use end-of-sequence token for padding
    )
 
     return _process_response(
-        models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True)
+        models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True)  # Decode and process the response
    )
 
 def _generate_audio_response(text, models):
     """Convert text to speech with performance optimizations"""
-    # Process text input
-    inputs = models['tts_processor'](text=text, return_tensors="pt")
+    # Process text input for TTS
+    inputs = models['tts_processor'](text=text, return_tensors="pt")  # Tokenize input text for TTS
 
     # Generate spectrogram
     spectrogram = models['tts_model'].generate_speech(
-        inputs["input_ids"],
-        models['speaker_embeddings']
+        inputs["input_ids"],  # Input token IDs for TTS
+        models['speaker_embeddings']  # Use preloaded speaker embeddings
    )
 
     # Generate waveform with optimizations
-    with torch.no_grad():  # Disable gradient calculation
-        waveform = models['tts_vocoder'](spectrogram)
+    with torch.no_grad():  # Disable gradient calculation for inference
+        waveform = models['tts_vocoder'](spectrogram)  # Generate audio waveform from spectrogram
 
     # Save audio file
-    sf.write("response.wav", waveform.numpy(), samplerate=16000)
-    return "response.wav"
+    sf.write("response.wav", waveform.numpy(), samplerate=16000)  # Save waveform as a WAV file
+    return "response.wav"  # Return the path to the saved audio file
 
 ##########################################
 # Main Application Flow
@@ -187,24 +188,24 @@ def _generate_audio_response(text, models):
 def main():
     """Primary execution flow"""
     # Load models once
-    ml_models = _load_models()
+    ml_models = _load_models()  # Load all models and cache them
 
     # Display interface
-    user_input = _display_interface()
+    user_input = _display_interface()  # Show the user input interface
 
-    if user_input:
+    if user_input:  # Check if user has entered input
         # Text generation stage
-        with st.spinner("🔄 Analyzing emotions and generating response..."):
-            text_response = _generate_text_response(user_input, ml_models)
+        with st.spinner("🔄 Analyzing emotions and generating response..."):  # Show loading spinner
+            text_response = _generate_text_response(user_input, ml_models)  # Generate text response
 
         # Display results
-        st.subheader("📄 Generated Response")
-        st.markdown(f"```\n{text_response}\n```")  #
+        st.subheader("📄 Generated Response")  # Subheader for response section
+        st.markdown(f"```\n{text_response}\n```")  # Display generated response in markdown format
 
         # Audio generation stage
-        with st.spinner("🔊 Converting to speech..."):
-            audio_file = _generate_audio_response(text_response, ml_models)
-            st.audio(audio_file, format="audio/wav")
+        with st.spinner("🔊 Converting to speech..."):  # Show loading spinner
+            audio_file = _generate_audio_response(text_response, ml_models)  # Generate audio response
+            st.audio(audio_file, format="audio/wav")  # Play audio file in the app
 
 if __name__ == "__main__":
-    main()
+    main()  # Execute the main function when the script is run
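For reviewers who want to sanity-check the import fix in isolation, here is a minimal standalone sketch of the same SpeechT5 chain this commit wires up. The model IDs and the x-vector index 7306 are taken from the diff above; the sample sentence and the output filename check.wav are illustrative only.

# Minimal sketch, assuming the same models as app.py: text -> spectrogram -> waveform -> WAV.
import torch                # tensor ops (newly imported in this commit)
import soundfile as sf      # WAV output (newly imported in this commit)
from datasets import load_dataset
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# A single CMU ARCTIC x-vector provides the speaker identity (index 7306, as in the app)
xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker = torch.tensor(xvectors[7306]["xvector"]).unsqueeze(0)  # add batch dimension

inputs = processor(text="Thank you for your feedback.", return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    spectrogram = tts_model.generate_speech(inputs["input_ids"], speaker)
    waveform = vocoder(spectrogram)

sf.write("check.wav", waveform.squeeze().numpy(), samplerate=16000)  # SpeechT5 emits 16 kHz audio

If this runs end to end, the from datasets import load_dataset, import torch, and import soundfile as sf lines added by the commit are sufficient for the audio path; note that import re is added but not yet used anywhere in the diff as shown.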