##########################################
# Step 0: Import required libraries
##########################################
import streamlit as st              # Web interface framework
from transformers import (
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    AutoModelForCausalLM,
    AutoTokenizer
)                                   # AI model components
from datasets import load_dataset   # Speaker voice embeddings
import torch                        # Tensor computation
import soundfile as sf              # Audio file handling
import time                         # Execution timing
##########################################
# Initial configuration (MUST be first)
##########################################
st.set_page_config(
    page_title="Just Comment",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed"
)
##########################################
# Optimized model loading with caching
##########################################
@st.cache_resource  # Cache across reruns so models load once per process
def _load_models():
    """Load and cache all models with maximum optimization."""
    # Device-agnostic model loading; use half precision only on GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    # Load emotion classifier with optimized settings
    emotion_pipe = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        device=device,
        truncation=True,
        padding=True
    )
    # Load text generation model in half precision (fp16 on GPU, fp32 on CPU)
    textgen_tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        use_fast=True
    )
    textgen_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        torch_dtype=dtype,
        device_map="auto"
    )
    # Load TTS components with hardware acceleration
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(
        "microsoft/speecht5_tts",
        torch_dtype=dtype
    ).to(device)
    tts_vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan",
        torch_dtype=dtype
    ).to(device)
    # Preload a 512-dim x-vector speaker embedding (must match the model dtype)
    speaker_embeddings = torch.tensor(
        load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
    ).unsqueeze(0).to(device=device, dtype=dtype)
    return {
        'emotion': emotion_pipe,
        'textgen_tokenizer': textgen_tokenizer,
        'textgen_model': textgen_model,
        'tts_processor': tts_processor,
        'tts_model': tts_model,
        'tts_vocoder': tts_vocoder,
        'speaker_embeddings': speaker_embeddings,
        'device': device
    }
##########################################
# UI Components
##########################################
def _display_interface():
    """Render the user interface and return the entered comment."""
    st.title("Just Comment")
    st.markdown("### I'm listening to you, my friend~")
    return st.text_area(
        "📝 Enter your comment:",
        placeholder="Type your message here...",
        height=150,
        key="user_input"
    )
##########################################
# Core Processing Functions
##########################################
def _analyze_emotion(text, classifier):
    """Classify the dominant emotion in a length-limited input."""
    start_time = time.time()
    # Limit input to 512 characters to bound classification time
    results = classifier(text[:512], return_all_scores=True)[0]
    valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
    # Pick the highest-scoring valid emotion, falling back to neutral
    dominant = max(
        (e for e in results if e['label'].lower() in valid_emotions),
        key=lambda x: x['score'],
        default={'label': 'neutral', 'score': 1.0}
    )
    st.write(f"⏱️ Emotion analysis time: {time.time()-start_time:.2f}s")
    return dominant
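# Illustrative shape of `results` above (labels real, scores made up):
#   [{'label': 'anger', 'score': 0.91}, {'label': 'joy', 'score': 0.03}, ...]
# `max(...)` would then select {'label': 'anger', 'score': 0.91} as `dominant`.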
def _generate_prompt(text, emotion):
    """Select a response-structure template for the detected emotion."""
    # Plain strings with {input} placeholders; filled via .format() below
    prompt_templates = {
        "sadness": "Sadness detected: {input}\nRespond with: 1. Empathy 2. Support 3. Solution\nResponse:",
        "joy": "Joy detected: {input}\nRespond with: 1. Thanks 2. Appreciation 3. Engagement\nResponse:",
        "love": "Love detected: {input}\nRespond with: 1. Warmth 2. Community 3. Exclusive Offer\nResponse:",
        "anger": "Anger detected: {input}\nRespond with: 1. Apology 2. Action 3. Compensation\nResponse:",
        "fear": "Fear detected: {input}\nRespond with: 1. Reassurance 2. Safety 3. Support\nResponse:",
        "surprise": "Surprise detected: {input}\nRespond with: 1. Acknowledgement 2. Solution 3. Follow-up\nResponse:",
        "neutral": "Feedback: {input}\nRespond professionally:\n1. Acknowledgement\n2. Assistance\n3. Next Steps\nResponse:"
    }
    # Limit input to 300 characters before substitution
    return prompt_templates[emotion.lower()].format(input=text[:300])
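# Worked example (hypothetical input):
#   _generate_prompt("My order arrived broken", "anger") produces:
#       Anger detected: My order arrived broken
#       Respond with: 1. Apology 2. Action 3. Compensation
#       Response: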
def _process_response(raw_text):
    """Extract and sanity-check the model's reply."""
    # Keep only the text after the last "Response:" marker
    response = raw_text.split("Response:")[-1].strip()
    # Drop any trailing sentence fragment
    if '.' in response:
        response = response.rsplit('.', 1)[0] + '.'
    # Cap at 200 chars; fall back to a canned reply if suspiciously short
    return response[:200] if len(response) > 50 else "Thank you for your feedback. We'll respond shortly."
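# Worked example (hypothetical generation output):
#   raw = "...Response: We sincerely apologize for the damaged item. A replacement has been shipped today. Also"
#   _process_response(raw)
#   -> "We sincerely apologize for the damaged item. A replacement has been shipped today."
# The dangling fragment "Also" is dropped; replies of 50 characters or
# fewer are replaced by the canned fallback above.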
def _generate_text(user_input, models):
    """Emotion-conditioned text generation pipeline."""
    start_time = time.time()
    # Emotion analysis
    emotion = _analyze_emotion(user_input, models['emotion'])
    # Build the emotion-specific prompt
    prompt = _generate_prompt(user_input, emotion['label'])
    # Tokenize and generate
    inputs = models['textgen_tokenizer'](
        prompt,
        return_tensors="pt",
        max_length=128,
        truncation=True
    ).to(models['device'])
    outputs = models['textgen_model'].generate(
        **inputs,               # passes input_ids and attention_mask
        max_new_tokens=80,      # strict limit for speed
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=models['textgen_tokenizer'].eos_token_id
    )
    # Decode and post-process
    generated = models['textgen_tokenizer'].decode(
        outputs[0],
        skip_special_tokens=True
    )
    st.write(f"⏱️ Text generation time: {time.time()-start_time:.2f}s")
    return _process_response(generated)
def _generate_speech(text, models):
    """Hardware-accelerated speech synthesis (SpeechT5 outputs 16 kHz mono)."""
    start_time = time.time()
    # Tokenize text (limited to 150 characters)
    inputs = models['tts_processor'](
        text=text[:150],
        return_tensors="pt"
    ).to(models['device'])
    # Generate audio
    with torch.inference_mode():
        spectrogram = models['tts_model'].generate_speech(
            inputs["input_ids"],
            models['speaker_embeddings']
        )
        waveform = models['tts_vocoder'](spectrogram)
    # Cast to float32 before saving (soundfile cannot write float16)
    sf.write("response.wav", waveform.cpu().float().numpy(), 16000)
    st.write(f"⏱️ Speech synthesis time: {time.time()-start_time:.2f}s")
    return "response.wav"
##########################################
# Main Application Flow
##########################################
def main():
    """Wire the cached models to the interface."""
    # Load (cached) models first
    ml_models = _load_models()
    # Display interface
    user_input = _display_interface()
    if user_input:
        total_start = time.time()
        # Text generation
        with st.spinner("🚀 Analyzing & generating response..."):
            text_response = _generate_text(user_input, ml_models)
        # Display results
        st.subheader("📄 Generated Response")
        st.markdown(f"```\n{text_response}\n```")
        # Audio generation
        with st.spinner("🔊 Converting to speech..."):
            audio_file = _generate_speech(text_response, ml_models)
        st.audio(audio_file, format="audio/wav")
        st.write(f"⏱️ Total execution time: {time.time()-total_start:.2f}s")
if __name__ == "__main__":
    main()
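# A minimal way to run this locally, assuming the file is saved as app.py
# (exact dependency versions are not pinned here; sentencepiece is needed by
# SpeechT5Processor and accelerate by device_map="auto"):
#   pip install streamlit transformers datasets torch soundfile sentencepiece accelerate
#   streamlit run app.py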