joey1101 committed on
Commit
317475a
·
verified ·
1 Parent(s): d7ef86b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -66
app.py CHANGED
@@ -3,61 +3,62 @@
3
  ##########################################
4
  import streamlit as st # For web interface
5
  from transformers import (
6
- pipeline,
7
- SpeechT5Processor,
8
- SpeechT5ForTextToSpeech,
9
- SpeechT5HifiGan,
10
- AutoModelForCausalLM,
11
- AutoTokenizer
12
  ) # AI model components
13
- from datasets import load_dataset # For voice embeddings
14
- import torch # Tensor computations
15
- import soundfile as sf # Audio file handling
16
- import re # Regular expressions for text processing
 
17
 
18
  ##########################################
19
  # Initial configuration (MUST be first)
20
  ##########################################
21
  st.set_page_config(
22
- page_title="Just Comment",
23
- page_icon="💬",
24
- layout="centered",
25
- initial_sidebar_state="collapsed"
26
  )
27
 
28
  ##########################################
29
  # Global model loading with caching
30
  ##########################################
31
- @st.cache_resource(show_spinner=False)
32
  def _load_models():
33
  """Load and cache all ML models with optimized settings"""
34
  return {
35
  # Emotion classification pipeline
36
  'emotion': pipeline(
37
- "text-classification",
38
- model="Thea231/jhartmann_emotion_finetuning",
39
  truncation=True # Enable text truncation for long inputs
40
  ),
41
 
42
  # Text generation components
43
  'textgen_tokenizer': AutoTokenizer.from_pretrained(
44
- "Qwen/Qwen1.5-0.5B",
45
  use_fast=True # Enable fast tokenization
46
  ),
47
  'textgen_model': AutoModelForCausalLM.from_pretrained(
48
- "Qwen/Qwen1.5-0.5B",
49
  torch_dtype=torch.float16 # Use half-precision for faster inference
50
  ),
51
 
52
  # Text-to-speech components
53
- 'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"),
54
- 'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"),
55
- 'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),
56
 
57
  # Preloaded speaker embeddings
58
  'speaker_embeddings': torch.tensor(
59
- load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
60
- ).unsqueeze(0)
61
  }
62
 
63
  ##########################################
@@ -65,14 +66,14 @@ def _load_models():
65
  ##########################################
66
  def _display_interface():
67
  """Render user interface elements"""
68
- st.title("Just Comment")
69
- st.markdown("### I'm listening to you, my friend～")
70
 
71
  return st.text_area(
72
- "📝 Enter your comment:",
73
- placeholder="Type your message here...",
74
- height=150,
75
- key="user_input"
76
  )
77
 
78
  ##########################################
@@ -80,10 +81,10 @@ def _display_interface():
80
  ##########################################
81
  def _analyze_emotion(text, classifier):
82
  """Identify dominant emotion with confidence threshold"""
83
- results = classifier(text, return_all_scores=True)[0]
84
- valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
85
- filtered = [e for e in results if e['label'].lower() in valid_emotions]
86
- return max(filtered, key=lambda x: x['score'])
87
 
88
  def _generate_prompt(text, emotion):
89
  """Create structured prompts for all emotion types"""
@@ -125,16 +126,16 @@ def _generate_prompt(text, emotion):
125
  "Response:"
126
  )
127
  }
128
- return prompt_templates.get(emotion.lower(), "").format(input=text)
129
 
130
  def _process_response(raw_text):
131
- """Clean and format generated response"""
132
  # Extract text after last "Response:" marker
133
  processed = raw_text.split("Response:")[-1].strip()
134
 
135
  # Remove incomplete sentences
136
  if '.' in processed:
137
- processed = processed.rsplit('.', 1)[0] + '.'
138
 
139
  # Ensure length between 50-200 characters
140
  return processed[:200].strip() if len(processed) > 50 else "Thank you for your feedback. We value your input and will respond shortly."
@@ -142,44 +143,44 @@ def _process_response(raw_text):
142
  def _generate_text_response(input_text, models):
143
  """Generate optimized text response with timing controls"""
144
  # Emotion analysis
145
- emotion = _analyze_emotion(input_text, models['emotion'])
146
 
147
  # Prompt engineering
148
- prompt = _generate_prompt(input_text, emotion['label'])
149
 
150
  # Text generation with optimized parameters
151
- inputs = models['textgen_tokenizer'](prompt, return_tensors="pt").to('cpu')
152
  outputs = models['textgen_model'].generate(
153
- inputs.input_ids,
154
- max_new_tokens=100, # Strict token limit
155
- temperature=0.7,
156
- top_p=0.9,
157
- do_sample=True,
158
- pad_token_id=models['textgen_tokenizer'].eos_token_id
159
  )
160
 
161
  return _process_response(
162
- models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True)
163
  )
164
 
165
  def _generate_audio_response(text, models):
166
  """Convert text to speech with performance optimizations"""
167
- # Process text input
168
- inputs = models['tts_processor'](text=text, return_tensors="pt")
169
 
170
  # Generate spectrogram
171
  spectrogram = models['tts_model'].generate_speech(
172
- inputs["input_ids"],
173
- models['speaker_embeddings']
174
  )
175
 
176
  # Generate waveform with optimizations
177
- with torch.no_grad(): # Disable gradient calculation
178
- waveform = models['tts_vocoder'](spectrogram)
179
 
180
  # Save audio file
181
- sf.write("response.wav", waveform.numpy(), samplerate=16000)
182
- return "response.wav"
183
 
184
  ##########################################
185
  # Main Application Flow
@@ -187,24 +188,24 @@ def _generate_audio_response(text, models):
187
  def main():
188
  """Primary execution flow"""
189
  # Load models once
190
- ml_models = _load_models()
191
 
192
  # Display interface
193
- user_input = _display_interface()
194
 
195
- if user_input:
196
  # Text generation stage
197
- with st.spinner("🔍 Analyzing emotions and generating response..."):
198
- text_response = _generate_text_response(user_input, ml_models)
199
 
200
  # Display results
201
- st.subheader("📄 Generated Response")
202
- st.markdown(f"```\n{text_response}\n```") # f-string formatted output
203
 
204
  # Audio generation stage
205
- with st.spinner("🔊 Converting to speech..."):
206
- audio_file = _generate_audio_response(text_response, ml_models)
207
- st.audio(audio_file, format="audio/wav")
208
 
209
  if __name__ == "__main__":
210
- main()
 
3
  ##########################################
4
  import streamlit as st # For web interface
5
  from transformers import (
6
+ pipeline, # For loading pre-trained models
7
+ SpeechT5Processor, # For text-to-speech processing
8
+ SpeechT5ForTextToSpeech, # TTS model
9
+ SpeechT5HifiGan, # Vocoder for generating audio waveforms
10
+ AutoModelForCausalLM, # For text generation
11
+ AutoTokenizer # For tokenizing input text
12
  ) # AI model components
13
+
14
+ from datasets import load_dataset # To load voice embeddings
15
+ import torch # For tensor computations
16
+ import soundfile as sf # For handling audio files
17
+ import re # For regular expressions in text processing
18
 
19
  ##########################################
20
  # Initial configuration (MUST be first)
21
  ##########################################
22
  st.set_page_config(
23
+ page_title="Just Comment", # Title of the web app
24
+ page_icon="💬", # Icon displayed in the browser tab
25
+ layout="centered", # Center the layout of the app
26
+ initial_sidebar_state="collapsed" # Start with sidebar collapsed
27
  )
28
 
29
  ##########################################
30
  # Global model loading with caching
31
  ##########################################
32
+ @st.cache_resource(show_spinner=False) # Cache the models for performance
33
  def _load_models():
34
  """Load and cache all ML models with optimized settings"""
35
  return {
36
  # Emotion classification pipeline
37
  'emotion': pipeline(
38
+ "text-classification", # Specify task type
39
+ model="Thea231/jhartmann_emotion_finetuning", # Load the model
40
  truncation=True # Enable text truncation for long inputs
41
  ),
42
 
43
  # Text generation components
44
  'textgen_tokenizer': AutoTokenizer.from_pretrained(
45
+ "Qwen/Qwen1.5-0.5B", # Load tokenizer
46
  use_fast=True # Enable fast tokenization
47
  ),
48
  'textgen_model': AutoModelForCausalLM.from_pretrained(
49
+ "Qwen/Qwen1.5-0.5B", # Load text generation model
50
  torch_dtype=torch.float16 # Use half-precision for faster inference
51
  ),
52
 
53
  # Text-to-speech components
54
+ 'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"), # Load TTS processor
55
+ 'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"), # Load TTS model
56
+ 'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"), # Load vocoder
57
 
58
  # Preloaded speaker embeddings
59
  'speaker_embeddings': torch.tensor(
60
+ load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"] # Load speaker embeddings
61
+ ).unsqueeze(0) # Add an additional dimension for batch processing
62
  }
63
 
64
  ##########################################
 
66
  ##########################################
67
  def _display_interface():
68
  """Render user interface elements"""
69
+ st.title("Just Comment") # Set the main title of the app
70
+ st.markdown("### I'm listening to you, my friend～") # Subheading for user interaction
71
 
72
  return st.text_area(
73
+ "📝 Enter your comment:", # Label for the text area
74
+ placeholder="Type your message here...", # Placeholder text
75
+ height=150, # Height of the text area
76
+ key="user_input" # Unique key for the text area
77
  )
78
 
79
  ##########################################
 
81
  ##########################################
82
  def _analyze_emotion(text, classifier):
83
  """Identify dominant emotion with confidence threshold"""
84
+ results = classifier(text, return_all_scores=True)[0] # Get emotion scores
85
+ valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'} # Define valid emotions
86
+ filtered = [e for e in results if e['label'].lower() in valid_emotions] # Filter results by valid emotions
87
+ return max(filtered, key=lambda x: x['score']) # Return the emotion with the highest score
88
 
89
  def _generate_prompt(text, emotion):
90
  """Create structured prompts for all emotion types"""
 
126
  "Response:"
127
  )
128
  }
129
+ return prompt_templates.get(emotion.lower(), "").format(input=text) # Format and return the appropriate prompt
130
 
131
  def _process_response(raw_text):
132
+ """Clean and format the generated response"""
133
  # Extract text after last "Response:" marker
134
  processed = raw_text.split("Response:")[-1].strip()
135
 
136
  # Remove incomplete sentences
137
  if '.' in processed:
138
+ processed = processed.rsplit('.', 1)[0] + '.' # Ensure the response ends with a period
139
 
140
  # Ensure length between 50-200 characters
141
  return processed[:200].strip() if len(processed) > 50 else "Thank you for your feedback. We value your input and will respond shortly."
 
143
  def _generate_text_response(input_text, models):
144
  """Generate optimized text response with timing controls"""
145
  # Emotion analysis
146
+ emotion = _analyze_emotion(input_text, models['emotion']) # Analyze the emotion of user input
147
 
148
  # Prompt engineering
149
+ prompt = _generate_prompt(input_text, emotion['label']) # Generate prompt based on detected emotion
150
 
151
  # Text generation with optimized parameters
152
+ inputs = models['textgen_tokenizer'](prompt, return_tensors="pt").to('cpu') # Tokenize the prompt
153
  outputs = models['textgen_model'].generate(
154
+ inputs.input_ids, # Input token IDs
155
+ max_new_tokens=100, # Strict token limit for response length
156
+ temperature=0.7, # Control randomness in text generation
157
+ top_p=0.9, # Control diversity in sampling
158
+ do_sample=True, # Enable sampling to generate varied responses
159
+ pad_token_id=models['textgen_tokenizer'].eos_token_id # Use end-of-sequence token for padding
160
  )
161
 
162
  return _process_response(
163
+ models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True) # Decode and process the response
164
  )
165
 
166
  def _generate_audio_response(text, models):
167
  """Convert text to speech with performance optimizations"""
168
+ # Process text input for TTS
169
+ inputs = models['tts_processor'](text=text, return_tensors="pt") # Tokenize input text for TTS
170
 
171
  # Generate spectrogram
172
  spectrogram = models['tts_model'].generate_speech(
173
+ inputs["input_ids"], # Input token IDs for TTS
174
+ models['speaker_embeddings'] # Use preloaded speaker embeddings
175
  )
176
 
177
  # Generate waveform with optimizations
178
+ with torch.no_grad(): # Disable gradient calculation for inference
179
+ waveform = models['tts_vocoder'](spectrogram) # Generate audio waveform from spectrogram
180
 
181
  # Save audio file
182
+ sf.write("response.wav", waveform.numpy(), samplerate=16000) # Save waveform as a WAV file
183
+ return "response.wav" # Return the path to the saved audio file
184
 
185
  ##########################################
186
  # Main Application Flow
 
188
  def main():
189
  """Primary execution flow"""
190
  # Load models once
191
+ ml_models = _load_models() # Load all models and cache them
192
 
193
  # Display interface
194
+ user_input = _display_interface() # Show the user input interface
195
 
196
+ if user_input: # Check if user has entered input
197
  # Text generation stage
198
+ with st.spinner("🔍 Analyzing emotions and generating response..."): # Show loading spinner
199
+ text_response = _generate_text_response(user_input, ml_models) # Generate text response
200
 
201
  # Display results
202
+ st.subheader("📄 Generated Response") # Subheader for response section
203
+ st.markdown(f"```\n{text_response}\n```") # Display generated response in markdown format
204
 
205
  # Audio generation stage
206
+ with st.spinner("🔊 Converting to speech..."): # Show loading spinner
207
+ audio_file = _generate_audio_response(text_response, ml_models) # Generate audio response
208
+ st.audio(audio_file, format="audio/wav") # Play audio file in the app
209
 
210
  if __name__ == "__main__":
211
+ main() # Execute the main function when the script is run