File size: 8,486 Bytes
152d61c
5e4841e
152d61c
5e4841e
f551227
7abe73c
 
 
 
 
 
5e4841e
f551227
 
 
 
c39c802
0e85ac7
5e4841e
0e85ac7
f551227
05be7ae
0e85ac7
5e4841e
0e85ac7
 
152d61c
5e4841e
152d61c
0a4b920
5e4841e
f551227
 
6597a2f
f551227
6597a2f
 
 
 
5e4841e
6597a2f
 
f551227
5e4841e
 
6597a2f
 
 
 
 
5e4841e
6597a2f
 
 
 
 
 
 
 
 
 
f551227
5e4841e
6597a2f
 
 
7abe73c
5e4841e
 
 
 
 
 
 
 
7abe73c
3970052
152d61c
5e4841e
152d61c
5e4841e
f551227
 
 
 
05be7ae
5e4841e
0a4b920
5e4841e
0a4b920
e4cf4e2
152d61c
5e4841e
7abe73c
5e4841e
f551227
 
 
 
5e4841e
f551227
6597a2f
5e4841e
6597a2f
0a4b920
5e4841e
f551227
5e4841e
f551227
 
 
 
 
 
 
0a4b920
f551227
 
0a4b920
5e4841e
f551227
 
 
 
 
 
 
 
5e4841e
6597a2f
 
5e4841e
6597a2f
5e4841e
0a4b920
f551227
5e4841e
0a4b920
f551227
 
0a4b920
5e2d609
0a4b920
5e4841e
0a4b920
 
f551227
5e4841e
f551227
5e4841e
f551227
 
 
0a4b920
5e4841e
f551227
5e4841e
f551227
6597a2f
5e4841e
5e2d609
f551227
5e4841e
6597a2f
5e4841e
6597a2f
5e4841e
6597a2f
f551227
5e4841e
152d61c
 
5e4841e
152d61c
 
f551227
 
 
7abe73c
f551227
 
 
 
 
 
 
 
 
 
 
 
c39c802
f551227
152d61c
f551227
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
##########################################
# Step 0: Essential imports
##########################################
import streamlit as st  # Web interface
from transformers import (  # AI components: emotion analysis, TTS, and text generation
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    AutoModelForCausalLM,
    AutoTokenizer
)
from datasets import load_dataset  # To load speaker embeddings dataset
import torch  # For tensor operations
import soundfile as sf  # For writing audio files
import sentencepiece  # Required for SpeechT5Processor tokenization

##########################################
# Initial configuration (MUST BE FIRST)
##########################################
# set_page_config has to run before any other Streamlit command, so it is
# executed at import time, ahead of every UI call below.
st.set_page_config(  # Configure the web page
    page_title="Just Comment",
    page_icon="💬",  # Speech-balloon emoji (the previous literal was mis-encoded UTF-8)
    layout="centered"
)

##########################################
# Optimized model loader with caching
##########################################
@st.cache_resource(show_spinner=False)
def _load_components():
    """Load and cache every model the app needs.

    Wrapped in ``st.cache_resource``, so the models are loaded once and the
    same objects are reused on later Streamlit reruns.

    Half precision (float16) is only used when CUDA is available: many CPU
    kernels are not implemented for float16, and those that are run slowly.

    Returns:
        dict with keys:
            ``emotion``        text-classification pipeline,
            ``text_model``     Qwen causal LM,
            ``text_tokenizer`` its tokenizer,
            ``tts_processor``, ``tts_model``, ``tts_vocoder``  SpeechT5 stack,
            ``speaker_emb``    speaker x-vector with a leading batch dim of 1,
                               on ``device`` and in the TTS model's dtype,
            ``device``         "cuda" or "cpu".
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"  # Detect available device
    dtype = torch.float16 if device == "cuda" else torch.float32  # FP16 only on GPU

    # Emotion classifier; truncation=True makes the tokenizer clip over-long inputs
    emotion_pipe = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        device=device,
        truncation=True,
    )

    # Text generator (device-dependent precision, automatic device placement)
    text_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")
    text_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        torch_dtype=dtype,
        device_map="auto",
    )

    # TTS system: acoustic model + vocoder, both in the chosen precision
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(
        "microsoft/speecht5_tts",
        torch_dtype=dtype,
    ).to(device)
    tts_vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan",
        torch_dtype=dtype,
    ).to(device)

    # Preloaded voice profile (entry 7306 of the x-vector dataset).
    # torch.tensor() builds a float32 tensor from the stored list; cast it to
    # the TTS dtype so it does not clash with float16 weights on GPU.
    speaker_emb = torch.tensor(
        load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
    ).unsqueeze(0).to(device=device, dtype=dtype)

    return {
        "emotion": emotion_pipe,
        "text_model": text_model,
        "text_tokenizer": text_tokenizer,
        "tts_processor": tts_processor,
        "tts_model": tts_model,
        "tts_vocoder": tts_vocoder,
        "speaker_emb": speaker_emb,
        "device": device
    }

##########################################
# User interface components
##########################################
def _show_interface():
    """Render the page header and the comment input box.

    Returns:
        str: the current contents of the comment text area (an empty string
        until the user types something).
    """
    st.title("🚀 Just Comment")  # Display the title with a rocket icon
    st.markdown("### I'm listening to you, my friend~")  # Display the friendly subtitle
    return st.text_area(  # Return user's comment input
        "📝 Enter your comment:",
        placeholder="Share your thoughts...",
        height=150,
        key="input"  # widget key; also exposes the value as st.session_state["input"]
    )

##########################################
# Core processing functions
##########################################
def _fast_emotion(text, analyzer):
    """Rapid emotion detection with input length limit."""
    result = analyzer(text[:256], return_all_scores=True)[0]  # Analyze only the first 256 characters for speed
    valid_emotions = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
    # Select the emotion from valid ones or default to neutral
    return max(
        (e for e in result if e['label'].lower() in valid_emotions),
        key=lambda x: x['score'],
        default={'label': 'neutral', 'score': 0}
    )

def _build_prompt(text, emotion):
    """Template-based prompt engineering in continuous prose (no bullet points)."""
    templates = {
        "sadness": "I sensed sadness in your comment: {text}. We are truly sorry and are here to support you.",
        "joy": "Your comment radiates joy: {text}. Thank you for your bright feedback; we look forward to serving you even better.",
        "love": "Your message exudes love: {text}. We appreciate your heartfelt words and cherish our connection with you.",
        "anger": "I understand your comment reflects anger: {text}. Please accept our sincere apologies as we work to resolve your concerns.",
        "fear": "It seems you feel fear in your comment: {text}. We want to reassure you that your safety and satisfaction are our priority.",
        "surprise": "Your comment conveys surprise: {text}. We are delighted by your experience and will strive to exceed your expectations.",
        "neutral": "Thank you for your comment: {text}. We remain committed to providing you with outstanding service."
    }
    # Build and return a continuous prompt with the user comment truncated to 200 characters
    return templates.get(emotion.lower(), templates["neutral"]).format(text=text[:200])

def _generate_response(text, models):
    """Generate a short reply to ``text`` tailored to its detected emotion.

    Steps:
        1. Classify the emotion.
        2. Build an emotion-specific prompt.
        3. Sample a continuation from the causal LM.
        4. Return only the newly generated text.

    Args:
        text: the user's comment.
        models: the dict returned by ``_load_components`` (uses ``emotion``,
            ``text_model``, ``text_tokenizer`` and ``device``).

    Returns:
        str: at most 200 characters of reply text. If the model produced no
        usable text, the templated prompt is returned instead.
    """
    # Detect the dominant emotion quickly
    detected = _fast_emotion(text, models["emotion"])
    # Build prompt based on detected emotion (continuous sentences)
    prompt = _build_prompt(text, detected["label"])
    print(f"Generated prompt: {prompt}")  # Print prompt using f-string for debugging

    tokenizer = models["text_tokenizer"]
    # Encode the prompt (capped at 100 tokens) on the inference device
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=100,
        truncation=True
    ).to(models["device"])

    # min_new_tokens/max_new_tokens bound only the generated part; the old
    # min_length also counted prompt tokens and so rarely had any effect.
    output = models["text_model"].generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=120,  # Upper bound on reply tokens
        min_new_tokens=50,   # Lower bound so the reply is not cut short
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    # For a decoder-only model, output[0] is prompt + continuation; drop the
    # prompt tokens so only the model's reply is decoded.
    input_len = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
    # If the model emitted a "Response:" marker, keep only what follows it
    response = generated.split("Response:")[-1].strip()
    if not response:
        # Nothing usable was generated; fall back to the templated prompt so
        # the display and TTS steps always receive non-empty text.
        response = prompt
    print(f"Generated response: {response}")  # Debug print using f-string
    # Character-based cap as an approximate length limit
    return response[:200]

def _text_to_speech(text, models):
    """Synthesize ``text`` to speech and save it as a 16 kHz WAV file.

    Only the first 150 characters are spoken, to bound synthesis time.

    Args:
        text: the text to speak.
        models: dict from ``_load_components`` (uses ``tts_processor``,
            ``tts_model``, ``tts_vocoder``, ``speaker_emb`` and ``device``).

    Returns:
        str: the path ``"output.wav"`` (relative to the working directory).
        NOTE(review): the filename is fixed, so concurrent sessions may
        overwrite each other's audio.
    """
    tts_model = models["tts_model"]
    inputs = models["tts_processor"](
        text=text[:150],  # Limit text length for TTS to 150 characters
        return_tensors="pt"
    ).to(models["device"])
    # Match the speaker embedding to the model's dtype (float16 on GPU);
    # a float32 embedding against float16 weights fails with a dtype mismatch.
    speaker_emb = models["speaker_emb"].to(tts_model.dtype)

    with torch.inference_mode():  # No-grad inference
        spectrogram = tts_model.generate_speech(inputs["input_ids"], speaker_emb)
        audio = models["tts_vocoder"](spectrogram)

    # soundfile only writes float32/float64/int16/int32 data, so upcast any
    # float16 vocoder output before saving.
    sf.write("output.wav", audio.float().cpu().numpy(), 16000)
    return "output.wav"

##########################################
# Main application flow
##########################################
def main():
    """Render the app and, once a comment is entered, reply in text and audio.

    Model loading goes through the cached ``_load_components``, so only the
    first run pays the loading cost.
    """
    import html  # Used to escape model output before embedding it in raw HTML

    components = _load_components()  # Load all models and components
    user_input = _show_interface()  # Render input interface and capture user comment

    if not user_input:  # Nothing to respond to yet
        return

    with st.spinner("🔍 Generating response..."):
        generated_response = _generate_response(user_input, components)

    st.subheader("📄 Response")
    # unsafe_allow_html is needed for the inline styling, so escape the text:
    # it is model output and may contain the user's own comment.
    st.markdown(
        f"<p style='color:#3498DB; font-size:20px;'>{html.escape(generated_response)}</p>",
        unsafe_allow_html=True
    )

    with st.spinner("🔊 Synthesizing audio..."):
        audio_file = _text_to_speech(generated_response, components)  # Convert response to speech
        st.audio(audio_file, format="audio/wav", start_time=0)  # Audio player starting at 0 s (no autoplay)
    print(f"Final generated response: {generated_response}")  # Debug output using f-string


# Run the main function when the script is executed
if __name__ == "__main__":
    main()  # Call the main function