##########################################
# Step 0: Import required libraries
##########################################
import streamlit as st  # Web interface framework
from transformers import (
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    AutoModelForCausalLM,
    AutoTokenizer
)  # AI model components
from datasets import load_dataset  # Voice embeddings
import torch  # Tensor computation
import soundfile as sf  # Audio file handling
import time  # Execution timing

##########################################
# Initial configuration (MUST be first)
##########################################
st.set_page_config(
    page_title="Just Comment",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed"
)

##########################################
# Optimized model loading with caching
##########################################
@st.cache_resource(show_spinner=False)
def _load_models():
    """Load and cache models with maximum optimization"""
    # Initialize device-agnostic model loading
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Load the emotion classifier (tokenizer truncation/padding enabled)
    emotion_pipe = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        device=device,
        truncation=True,
        padding=True
    )
    
    # Load text generation model in half precision on GPU; fall back to
    # float32 on CPU, where half-precision ops are unsupported for generation
    textgen_tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        use_fast=True
    )
    textgen_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto"
    )
    
    # Load TTS components; SpeechT5 emits mel spectrograms, which the
    # HiFi-GAN vocoder converts to waveforms. Use float16 only on GPU.
    tts_dtype = torch.float16 if device == "cuda" else torch.float32
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(
        "microsoft/speecht5_tts",
        torch_dtype=tts_dtype
    ).to(device)
    tts_vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan",
        torch_dtype=tts_dtype
    ).to(device)
    
    # Preload speaker embeddings and match the TTS model's dtype to avoid
    # float16/float32 mismatches at inference time
    speaker_embeddings = torch.tensor(
        load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
    ).unsqueeze(0).to(device=device, dtype=tts_dtype)
    
    return {
        'emotion': emotion_pipe,
        'textgen_tokenizer': textgen_tokenizer,
        'textgen_model': textgen_model,
        'tts_processor': tts_processor,
        'tts_model': tts_model,
        'tts_vocoder': tts_vocoder,
        'speaker_embeddings': speaker_embeddings,
        'device': device
    }
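
# Note (illustrative): st.cache_resource means repeated calls within a session
# return the same dict without reloading, e.g.
#   models = _load_models()  # slow on first run (downloads model weights)
#   models = _load_models()  # fast: served from the resource cache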

##########################################
# UI Components
##########################################
def _display_interface():
    """Render optimized user interface"""
    st.title("Just Comment")
    st.markdown(f"### I'm listening to you, my friend~")  # f-string usage
    
    return st.text_area(
        "📝 Enter your comment:",
        placeholder="Type your message here...",
        height=150,
        key="user_input"
    )

##########################################
# Core Processing Functions
##########################################
def _analyze_emotion(text, classifier):
    """Fast emotion analysis with early stopping"""
    start_time = time.time()
    results = classifier(text[:512], return_all_scores=True)[0]  # Limit input length
    valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
    
    # Find dominant emotion
    dominant = max(
        (e for e in results if e['label'].lower() in valid_emotions),
        key=lambda x: x['score'],
        default={'label': 'neutral', 'score': 1.0}
    )
    
    st.write(f"⏱️ Emotion analysis time: {time.time()-start_time:.2f}s")
    return dominant
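
# Illustrative only (actual scores vary by model and input): the classifier
# returns one score per label, e.g.
#   classifier("I love this product!", return_all_scores=True)[0]
#   -> [{'label': 'joy', 'score': 0.93}, {'label': 'anger', 'score': 0.01}, ...]
# _analyze_emotion keeps the highest-scoring label among valid_emotions.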

def _generate_prompt(text, emotion):
    """Build an emotion-specific prompt that fixes the response structure"""
    prompt_templates = {
        "sadness": "Sadness detected: {input}\nRespond with: 1. Empathy 2. Support 3. Solution\nResponse:",
        "joy": "Joy detected: {input}\nRespond with: 1. Thanks 2. Appreciation 3. Engagement\nResponse:",
        "love": "Love detected: {input}\nRespond with: 1. Warmth 2. Community 3. Exclusive Offer\nResponse:",
        "anger": "Anger detected: {input}\nRespond with: 1. Apology 2. Action 3. Compensation\nResponse:",
        "fear": "Fear detected: {input}\nRespond with: 1. Reassurance 2. Safety 3. Support\nResponse:",
        "surprise": "Surprise detected: {input}\nRespond with: 1. Acknowledgement 2. Solution 3. Follow-up\nResponse:",
        "neutral": "Feedback: {input}\nRespond professionally:\n1. Acknowledgement\n2. Assistance\n3. Next Steps\nResponse:"
    }
    return prompt_templates[emotion.lower()].format(input=text[:300])  # Limit input length
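
# Illustrative only: for emotion 'anger' and text "My order arrived broken",
# the rendered prompt is:
#   Anger detected: My order arrived broken
#   Respond with: 1. Apology 2. Action 3. Compensation
#   Response: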

def _process_response(raw_text):
    """Fast response processing with validation"""
    # Extract response after last marker
    response = raw_text.split("Response:")[-1].strip()
    
    # Ensure complete sentences
    if '.' in response:
        response = response.rsplit('.', 1)[0] + '.'
    
    # Length control
    return response[:200] if len(response) > 50 else "Thank you for your feedback. We'll respond shortly."
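
# Illustrative only: given raw model output
#   "Anger detected: ...\nResponse: We sincerely apologize. A replacement ships today. Also"
# _process_response returns "We sincerely apologize. A replacement ships today."
# (text after the last "Response:", trimmed to the last complete sentence).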

def _generate_text(user_input, models):
    """Ultra-fast text generation pipeline"""
    start_time = time.time()
    
    # Emotion analysis
    emotion = _analyze_emotion(user_input, models['emotion'])
    
    # Generate prompt
    prompt = _generate_prompt(user_input, emotion['label'])
    
    # Tokenize and generate
    inputs = models['textgen_tokenizer'](
        prompt,
        return_tensors="pt",
        max_length=128,
        truncation=True
    ).to(models['device'])
    
    outputs = models['textgen_model'].generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # Avoid pad-token ambiguity warnings
        max_new_tokens=80,  # Strict limit for speed
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=models['textgen_tokenizer'].eos_token_id
    )
    
    # Decode and process
    generated = models['textgen_tokenizer'].decode(
        outputs[0],
        skip_special_tokens=True
    )
    
    st.write(f"⏱️ Text generation time: {time.time()-start_time:.2f}s")
    return _process_response(generated)

def _generate_speech(text, models):
    """Hardware-accelerated speech synthesis"""
    start_time = time.time()
    
    # Process text
    inputs = models['tts_processor'](
        text=text[:150],  # Limit text length
        return_tensors="pt"
    ).to(models['device'])
    
    # Generate audio
    with torch.inference_mode():
        spectrogram = models['tts_model'].generate_speech(
            inputs["input_ids"],
            models['speaker_embeddings']
        )
        waveform = models['tts_vocoder'](spectrogram)
    
    # Save the waveform; cast to float32 first, since soundfile rejects float16
    sf.write("response.wav", waveform.float().cpu().numpy(), 16000)
    
    st.write(f"⏱️ Speech synthesis time: {time.time()-start_time:.2f}s")
    return "response.wav"

##########################################
# Main Application Flow
##########################################
def main():
    """Optimized execution flow"""
    # Load models first
    ml_models = _load_models()
    
    # Display interface
    user_input = _display_interface()
    
    if user_input:
        total_start = time.time()
        
        # Text generation
        with st.spinner("🚀 Analyzing & generating response..."):
            text_response = _generate_text(user_input, ml_models)
        
        # Display results
        st.subheader(f"📄 Generated Response")
        st.markdown(f"```\n{text_response}\n```")
        
        # Audio generation
        with st.spinner("🔊 Converting to speech..."):
            audio_file = _generate_speech(text_response, ml_models)
            st.audio(audio_file, format="audio/wav")
        
        st.write(f"⏱️ Total execution time: {time.time()-total_start:.2f}s")

if __name__ == "__main__":
    main()