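# app.py: "Just Comment". Classifies the emotion of a user comment, drafts an
# empathetic reply with a small causal LM, and reads the reply aloud via
# SpeechT5 text-to-speech.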

##########################################
# Step 0: Import required libraries
##########################################
import streamlit as st  # Web interface framework
from transformers import (
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    AutoModelForCausalLM,
    AutoTokenizer,
)  # AI model components
from datasets import load_dataset  # Voice embeddings
import torch  # Tensor computation
import soundfile as sf  # Audio file handling
import time  # Execution timing

##########################################
# Initial configuration (MUST be first)
##########################################
st.set_page_config(
    page_title="Just Comment",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

##########################################
# Optimized model loading with caching
##########################################
@st.cache_resource(show_spinner=False)
def _load_models():
    """Load and cache all models once per server process."""
    # Pick the best available device; fall back to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is only reliable on GPU; use float32 on CPU
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Emotion classifier (truncates and pads long inputs automatically)
    emotion_pipe = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        device=device,
        truncation=True,
        padding=True,
    )

    # Small causal LM for reply generation, loaded in half precision on GPU
    textgen_tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        use_fast=True,
    )
    textgen_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        torch_dtype=dtype,
        device_map="auto",
    )

    # SpeechT5 text-to-speech components
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(
        "microsoft/speecht5_tts",
        torch_dtype=dtype,
    ).to(device)
    tts_vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan",
        torch_dtype=dtype,
    ).to(device)

    # Preload one speaker embedding (xvector 7306 from CMU ARCTIC);
    # match the TTS model's dtype to avoid mixed-precision errors
    speaker_embeddings = torch.tensor(
        load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
    ).unsqueeze(0).to(device=device, dtype=dtype)

    return {
        'emotion': emotion_pipe,
        'textgen_tokenizer': textgen_tokenizer,
        'textgen_model': textgen_model,
        'tts_processor': tts_processor,
        'tts_model': tts_model,
        'tts_vocoder': tts_vocoder,
        'speaker_embeddings': speaker_embeddings,
        'device': device,
    }
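
# Note: st.cache_resource keeps the returned dict alive for the whole server
# process, so Streamlit reruns and new sessions reuse the already-loaded
# models instead of reloading weights on every interaction.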

##########################################
# UI Components
##########################################
def _display_interface():
    """Render the input interface and return the user's comment."""
    st.title("Just Comment")
    st.markdown("### I'm listening to you, my friend~")
    return st.text_area(
        "📝 Enter your comment:",
        placeholder="Type your message here...",
        height=150,
        key="user_input",
    )

##########################################
# Core Processing Functions
##########################################
def _analyze_emotion(text, classifier):
    """Classify the dominant emotion of the (length-capped) input text."""
    start_time = time.time()
    # top_k=None returns scores for every label (replaces the deprecated
    # return_all_scores=True); input is capped at 512 characters
    results = classifier(text[:512], top_k=None)
    valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
    # Pick the highest-scoring known emotion; default to neutral
    dominant = max(
        (e for e in results if e['label'].lower() in valid_emotions),
        key=lambda x: x['score'],
        default={'label': 'neutral', 'score': 1.0},
    )
    st.write(f"⏱️ Emotion analysis time: {time.time() - start_time:.2f}s")
    return dominant
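
# Illustrative only: _analyze_emotion("Great service, thank you!", clf)
# would return something like {'label': 'joy', 'score': 0.97}.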

def _generate_prompt(text, emotion):
    """Build an emotion-specific prompt (plain templates filled via .format)."""
    prompt_templates = {
        "sadness": "Sadness detected: {input}\nRespond with: 1. Empathy 2. Support 3. Solution\nResponse:",
        "joy": "Joy detected: {input}\nRespond with: 1. Thanks 2. Appreciation 3. Engagement\nResponse:",
        "love": "Love detected: {input}\nRespond with: 1. Warmth 2. Community 3. Exclusive Offer\nResponse:",
        "anger": "Anger detected: {input}\nRespond with: 1. Apology 2. Action 3. Compensation\nResponse:",
        "fear": "Fear detected: {input}\nRespond with: 1. Reassurance 2. Safety 3. Support\nResponse:",
        "surprise": "Surprise detected: {input}\nRespond with: 1. Acknowledgement 2. Solution 3. Follow-up\nResponse:",
        "neutral": "Feedback: {input}\nRespond professionally:\n1. Acknowledgement\n2. Assistance\n3. Next Steps\nResponse:",
    }
    # Cap the quoted comment at 300 characters to keep the prompt short
    return prompt_templates[emotion.lower()].format(input=text[:300])
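
# Illustrative only: for emotion 'anger' and text "My order arrived broken",
# the resulting prompt is:
#   Anger detected: My order arrived broken
#   Respond with: 1. Apology 2. Action 3. Compensation
#   Response: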

def _process_response(raw_text):
    """Extract and tidy the model's reply from the raw generation."""
    # Keep only the text after the last "Response:" marker
    response = raw_text.split("Response:")[-1].strip()
    # Drop any trailing sentence fragment
    if '.' in response:
        response = response.rsplit('.', 1)[0] + '.'
    # Cap at 200 characters; fall back to a stock reply if too short
    return response[:200] if len(response) > 50 else "Thank you for your feedback. We'll respond shortly."
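
# Illustrative only: "We are sorry. A refund is on the way. We also" becomes
# "We are sorry. A refund is on the way." (the trailing fragment is dropped).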

def _generate_text(user_input, models):
    """Run emotion analysis, prompt building, and reply generation."""
    start_time = time.time()
    # Classify the comment's emotion
    emotion = _analyze_emotion(user_input, models['emotion'])
    # Build the matching prompt
    prompt = _generate_prompt(user_input, emotion['label'])
    # Tokenize and generate
    inputs = models['textgen_tokenizer'](
        prompt,
        return_tensors="pt",
        max_length=128,
        truncation=True,
    ).to(models['device'])
    outputs = models['textgen_model'].generate(
        **inputs,  # pass attention_mask along with input_ids
        max_new_tokens=80,  # strict limit keeps latency low
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=models['textgen_tokenizer'].eos_token_id,
    )
    # Decode and post-process
    generated = models['textgen_tokenizer'].decode(
        outputs[0],
        skip_special_tokens=True,
    )
    st.write(f"⏱️ Text generation time: {time.time() - start_time:.2f}s")
    return _process_response(generated)

def _generate_speech(text, models):
    """Synthesize the reply as a 16 kHz WAV file via SpeechT5."""
    start_time = time.time()
    # Tokenize the (length-capped) reply text
    inputs = models['tts_processor'](
        text=text[:150],
        return_tensors="pt",
    ).to(models['device'])
    # Generate the spectrogram, then vocode it into a waveform
    with torch.inference_mode():
        spectrogram = models['tts_model'].generate_speech(
            inputs["input_ids"],
            models['speaker_embeddings'],
        )
        waveform = models['tts_vocoder'](spectrogram)
    # SpeechT5 outputs 16 kHz audio; cast to float32 for soundfile
    sf.write("response.wav", waveform.cpu().float().numpy(), 16000)
    st.write(f"⏱️ Speech synthesis time: {time.time() - start_time:.2f}s")
    return "response.wav"

##########################################
# Main Application Flow
##########################################
def main():
    """Wire the UI to the model pipeline."""
    # Load (cached) models first
    ml_models = _load_models()
    # Render the interface and collect input
    user_input = _display_interface()
    if user_input:
        total_start = time.time()
        # Generate the text reply
        with st.spinner("🚀 Analyzing & generating response..."):
            text_response = _generate_text(user_input, ml_models)
        # Display the result
        st.subheader("📄 Generated Response")
        st.markdown(f"```\n{text_response}\n```")
        # Convert the reply to speech
        with st.spinner("🔊 Converting to speech..."):
            audio_file = _generate_speech(text_response, ml_models)
        st.audio(audio_file, format="audio/wav")
        st.write(f"⏱️ Total execution time: {time.time() - total_start:.2f}s")


if __name__ == "__main__":
    main()
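
# To run locally (assuming streamlit, transformers, datasets, torch, and
# soundfile are installed):
#   streamlit run app.py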