joey1101 commited on
Commit
5e2d609
·
verified ·
1 Parent(s): 05be7ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -66
app.py CHANGED
@@ -1,7 +1,7 @@
1
  ##########################################
2
  # Step 0: Import required libraries
3
  ##########################################
4
- import streamlit as st
5
  from transformers import (
6
  pipeline,
7
  SpeechT5Processor,
@@ -9,14 +9,14 @@ from transformers import (
9
  SpeechT5HifiGan,
10
  AutoModelForCausalLM,
11
  AutoTokenizer
12
- )
13
- from datasets import load_dataset
14
- import torch
15
- import soundfile as sf
16
- import sentencepiece
17
 
18
  ##########################################
19
- # Initial configuration (MUST be first)
20
  ##########################################
21
  st.set_page_config(
22
  page_title="Just Comment",
@@ -29,18 +29,32 @@ st.set_page_config(
29
  # Global model loading with caching
30
  ##########################################
31
  @st.cache_resource(show_spinner=False)
32
- def load_models():
33
- """Load and cache all ML models"""
34
  return {
 
35
  'emotion': pipeline(
36
- "text-classification",
37
- model="Thea231/jhartmann_emotion_finetuning"
 
 
 
 
 
 
 
 
 
 
 
38
  ),
39
- 'textgen_tokenizer': AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B"),
40
- 'textgen_model': AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B"),
41
  'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"),
42
  'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"),
43
  'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),
 
 
44
  'speaker_embeddings': torch.tensor(
45
  load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
46
  ).unsqueeze(0)
@@ -49,96 +63,121 @@ def load_models():
49
  ##########################################
50
  # UI Components
51
  ##########################################
52
- def render_interface():
53
- """Create user interface"""
54
- st.title("Just Comment")
55
  st.markdown("### I'm listening to you, my friend~")
56
 
57
  return st.text_area(
58
  "📝 Enter your comment:",
59
- placeholder="Share your thoughts...",
60
  height=150,
61
  key="user_input"
62
  )
63
 
64
  ##########################################
65
- # Core Logic Components
66
  ##########################################
67
- def analyze_emotion(text, classifier):
68
- """Determine emotion with quick analysis"""
69
  results = classifier(text, return_all_scores=True)[0]
70
- valid_emotions = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
71
  filtered = [e for e in results if e['label'].lower() in valid_emotions]
72
  return max(filtered, key=lambda x: x['score'])
73
 
74
- def generate_prompt(text, emotion):
75
- """Complete prompt templates for all 6 emotions"""
76
  prompt_templates = {
77
  "sadness": (
78
- "Customer expressed sadness: {input}\n"
79
- "Respond with:\n1. Empathetic acknowledgment\n"
80
- "2. Supportive statement\n3. Concrete help offer\nResponse:"
 
81
  ),
82
  "joy": (
83
- "Positive feedback: {input}\n"
84
- "Respond with:\n1. Enthusiastic thanks\n"
85
- "2. Specific compliment\n3. Future engagement\nResponse:"
 
86
  ),
87
  "love": (
88
- "Customer showed affection: {input}\n"
89
- "Respond with:\n1. Warm appreciation\n"
90
- "2. Community building\n3. Exclusive offer\nResponse:"
 
91
  ),
92
  "anger": (
93
- "Angry complaint: {input}\n"
94
- "Respond with:\n1. Sincere apology\n"
95
- "2. Solution steps\n3. Compensation\nResponse:"
 
96
  ),
97
  "fear": (
98
- "Customer expressed concerns: {input}\n"
99
- "Respond with:\n1. Reassurance\n"
100
- "2. Safety measures\n3. Support channels\nResponse:"
 
101
  ),
102
  "surprise": (
103
- "Unexpected feedback: {input}\n"
104
- "Respond with:\n1. Acknowledge uniqueness\n"
105
- "2. Creative solution\n3. Follow-up plan\nResponse:"
 
106
  )
107
  }
108
  return prompt_templates.get(emotion.lower(), "").format(input=text)
109
 
110
- def process_response(output_text):
111
- """Optimized response processing"""
112
- output_text = output_text.split("Response:")[-1].strip()
113
- return output_text[:200] # Strict length control
 
 
 
 
 
 
 
114
 
115
- def generate_text_response(user_input, models):
116
- """Efficient text generation"""
117
- emotion = analyze_emotion(user_input, models['emotion'])
118
- prompt = generate_prompt(user_input, emotion['label'])
 
 
 
119
 
120
- inputs = models['textgen_tokenizer'](prompt, return_tensors="pt")
 
121
  outputs = models['textgen_model'].generate(
122
  inputs.input_ids,
123
- max_new_tokens=150, # Reduced for speed
124
  temperature=0.7,
 
125
  do_sample=True,
126
- top_p=0.9
127
  )
128
 
129
- return process_response(
130
  models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True)
131
  )
132
 
133
- def generate_audio_response(text, models):
134
- """Optimized TTS conversion"""
 
135
  inputs = models['tts_processor'](text=text, return_tensors="pt")
 
 
136
  spectrogram = models['tts_model'].generate_speech(
137
  inputs["input_ids"],
138
  models['speaker_embeddings']
139
  )
140
- with torch.no_grad():
 
 
141
  waveform = models['tts_vocoder'](spectrogram)
 
 
142
  sf.write("response.wav", waveform.numpy(), samplerate=16000)
143
  return "response.wav"
144
 
@@ -146,21 +185,25 @@ def generate_audio_response(text, models):
146
  # Main Application Flow
147
  ##########################################
148
  def main():
149
- ml_models = load_models()
150
- user_input = render_interface()
 
 
 
 
151
 
152
  if user_input:
153
- # Text Generation
154
- with st.spinner("🔍 Analyzing emotions..."):
155
- text_response = generate_text_response(user_input, ml_models)
156
 
157
- # Display Results
158
  st.subheader("📄 Generated Response")
159
- st.success(text_response)
160
 
161
- # Audio Generation
162
- with st.spinner("🔊 Generating voice..."):
163
- audio_file = generate_audio_response(text_response, ml_models)
164
  st.audio(audio_file, format="audio/wav")
165
 
166
  if __name__ == "__main__":
 
1
  ##########################################
2
  # Step 0: Import required libraries
3
  ##########################################
4
+ import streamlit as st # For web interface
5
  from transformers import (
6
  pipeline,
7
  SpeechT5Processor,
 
9
  SpeechT5HifiGan,
10
  AutoModelForCausalLM,
11
  AutoTokenizer
12
+ ) # AI model components
13
+ from datasets import load_dataset # For voice embeddings
14
+ import torch # Tensor computations
15
+ import soundfile as sf # Audio file handling
16
+ import re # Regular expressions for text processing
17
 
18
  ##########################################
19
+ # Initial configuration
20
  ##########################################
21
  st.set_page_config(
22
  page_title="Just Comment",
 
29
  # Global model loading with caching
30
  ##########################################
31
  @st.cache_resource(show_spinner=False)
32
+ def _load_models():
33
+ """Load and cache all ML models with optimized settings"""
34
  return {
35
+ # Emotion classification pipeline
36
  'emotion': pipeline(
37
+ "text-classification",
38
+ model="Thea231/jhartmann_emotion_finetuning",
39
+ truncation=True # Enable text truncation for long inputs
40
+ ),
41
+
42
+ # Text generation components
43
+ 'textgen_tokenizer': AutoTokenizer.from_pretrained(
44
+ "Qwen/Qwen1.5-0.5B",
45
+ use_fast=True # Enable fast tokenization
46
+ ),
47
+ 'textgen_model': AutoModelForCausalLM.from_pretrained(
48
+ "Qwen/Qwen1.5-0.5B",
49
+ torch_dtype=torch.float16 # Use half-precision for faster inference
50
  ),
51
+
52
+ # Text-to-speech components
53
  'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"),
54
  'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"),
55
  'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),
56
+
57
+ # Preloaded speaker embeddings
58
  'speaker_embeddings': torch.tensor(
59
  load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
60
  ).unsqueeze(0)
 
63
  ##########################################
64
  # UI Components
65
  ##########################################
66
+ def _display_interface():
67
+ """Render user interface elements"""
68
+ st.title("🚀 Just Comment")
69
  st.markdown("### I'm listening to you, my friend~")
70
 
71
  return st.text_area(
72
  "📝 Enter your comment:",
73
+ placeholder="Type your message here...",
74
  height=150,
75
  key="user_input"
76
  )
77
 
78
  ##########################################
79
+ # Core Processing Functions
80
  ##########################################
81
+ def _analyze_emotion(text, classifier):
82
+ """Identify dominant emotion with confidence threshold"""
83
  results = classifier(text, return_all_scores=True)[0]
84
+ valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
85
  filtered = [e for e in results if e['label'].lower() in valid_emotions]
86
  return max(filtered, key=lambda x: x['score'])
87
 
88
+ def _generate_prompt(text, emotion):
89
+ """Create structured prompts for all emotion types"""
90
  prompt_templates = {
91
  "sadness": (
92
+ "Sadness detected: {input}\n"
93
+ "Required response structure:\n"
94
+ "1. Empathetic acknowledgment\n2. Support offer\n3. Solution proposal\n"
95
+ "Response:"
96
  ),
97
  "joy": (
98
+ "Joy detected: {input}\n"
99
+ "Required response structure:\n"
100
+ "1. Enthusiastic thanks\n2. Positive reinforcement\n3. Future engagement\n"
101
+ "Response:"
102
  ),
103
  "love": (
104
+ "Affection detected: {input}\n"
105
+ "Required response structure:\n"
106
+ "1. Warm appreciation\n2. Community focus\n3. Exclusive benefit\n"
107
+ "Response:"
108
  ),
109
  "anger": (
110
+ "Anger detected: {input}\n"
111
+ "Required response structure:\n"
112
+ "1. Sincere apology\n2. Action steps\n3. Compensation\n"
113
+ "Response:"
114
  ),
115
  "fear": (
116
+ "Concern detected: {input}\n"
117
+ "Required response structure:\n"
118
+ "1. Reassurance\n2. Safety measures\n3. Support options\n"
119
+ "Response:"
120
  ),
121
  "surprise": (
122
+ "Surprise detected: {input}\n"
123
+ "Required response structure:\n"
124
+ "1. Acknowledge uniqueness\n2. Creative solution\n3. Follow-up\n"
125
+ "Response:"
126
  )
127
  }
128
  return prompt_templates.get(emotion.lower(), "").format(input=text)
129
 
130
+ def _process_response(raw_text):
131
+ """Clean and format generated response"""
132
+ # Extract text after last "Response:" marker
133
+ processed = raw_text.split("Response:")[-1].strip()
134
+
135
+ # Remove incomplete sentences
136
+ if '.' in processed:
137
+ processed = processed.rsplit('.', 1)[0] + '.'
138
+
139
+ # Ensure length between 50-200 characters
140
+ return processed[:200].strip() if len(processed) > 50 else "Thank you for your feedback. We value your input and will respond shortly."
141
 
142
+ def _generate_text_response(input_text, models):
143
+ """Generate optimized text response with timing controls"""
144
+ # Emotion analysis
145
+ emotion = _analyze_emotion(input_text, models['emotion'])
146
+
147
+ # Prompt engineering
148
+ prompt = _generate_prompt(input_text, emotion['label'])
149
 
150
+ # Text generation with optimized parameters
151
+ inputs = models['textgen_tokenizer'](prompt, return_tensors="pt").to('cpu')
152
  outputs = models['textgen_model'].generate(
153
  inputs.input_ids,
154
+ max_new_tokens=100, # Strict token limit
155
  temperature=0.7,
156
+ top_p=0.9,
157
  do_sample=True,
158
+ pad_token_id=models['textgen_tokenizer'].eos_token_id
159
  )
160
 
161
+ return _process_response(
162
  models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True)
163
  )
164
 
165
+ def _generate_audio_response(text, models):
166
+ """Convert text to speech with performance optimizations"""
167
+ # Process text input
168
  inputs = models['tts_processor'](text=text, return_tensors="pt")
169
+
170
+ # Generate spectrogram
171
  spectrogram = models['tts_model'].generate_speech(
172
  inputs["input_ids"],
173
  models['speaker_embeddings']
174
  )
175
+
176
+ # Generate waveform with optimizations
177
+ with torch.no_grad(): # Disable gradient calculation
178
  waveform = models['tts_vocoder'](spectrogram)
179
+
180
+ # Save audio file
181
  sf.write("response.wav", waveform.numpy(), samplerate=16000)
182
  return "response.wav"
183
 
 
185
  # Main Application Flow
186
  ##########################################
187
  def main():
188
+ """Primary execution flow"""
189
+ # Load models once
190
+ ml_models = _load_models()
191
+
192
+ # Display interface
193
+ user_input = _display_interface()
194
 
195
  if user_input:
196
+ # Text generation stage
197
+ with st.spinner("🔍 Analyzing emotions and generating response..."):
198
+ text_response = _generate_text_response(user_input, ml_models)
199
 
200
+ # Display results
201
  st.subheader("📄 Generated Response")
202
+ st.markdown(f"```\n{text_response}\n```") # f-string formatted output
203
 
204
+ # Audio generation stage
205
+ with st.spinner("🔊 Converting to speech..."):
206
+ audio_file = _generate_audio_response(text_response, ml_models)
207
  st.audio(audio_file, format="audio/wav")
208
 
209
  if __name__ == "__main__":