# Comment_Reply / app.py
# Page header captured with the file: "Update app.py" by joey1101, commit d7ef86b (verified), 7.49 kB.
##########################################
# Step 0: Import required libraries
##########################################
import streamlit as st # For web interface
from transformers import (
pipeline,
SpeechT5Processor,
SpeechT5ForTextToSpeech,
SpeechT5HifiGan,
AutoModelForCausalLM,
AutoTokenizer
) # AI model components
from datasets import load_dataset # For voice embeddings
import torch # Tensor computations
import soundfile as sf # Audio file handling
import re # Regular expressions for text processing
##########################################
# Initial configuration (MUST be first)
##########################################
# Page-level settings, applied at import time ahead of every other st.* call
# in this file.
_PAGE_SETTINGS = {
    "page_title": "Just Comment",
    "page_icon": "๐Ÿ’ฌ",
    "layout": "centered",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
##########################################
# Global model loading with caching
##########################################
@st.cache_resource(show_spinner=False)
def _load_models():
    """Instantiate every model object the app uses and return them in one dict.

    The function is wrapped in ``st.cache_resource`` so the (expensive)
    loading below is not repeated on every Streamlit rerun.

    Returns:
        dict with the keys ``'emotion'``, ``'textgen_tokenizer'``,
        ``'textgen_model'``, ``'tts_processor'``, ``'tts_model'``,
        ``'tts_vocoder'`` and ``'speaker_embeddings'``.
    """
    textgen_checkpoint = "Qwen/Qwen1.5-0.5B"
    tts_checkpoint = "microsoft/speecht5_tts"

    # Emotion classifier; truncation keeps over-long comments from failing.
    emotion_classifier = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        truncation=True,
    )

    # Causal language model (loaded in half precision) and its fast tokenizer.
    reply_tokenizer = AutoTokenizer.from_pretrained(
        textgen_checkpoint,
        use_fast=True,
    )
    reply_model = AutoModelForCausalLM.from_pretrained(
        textgen_checkpoint,
        torch_dtype=torch.float16,
    )

    # SpeechT5 text-to-speech stack: processor, acoustic model and vocoder.
    speech_processor = SpeechT5Processor.from_pretrained(tts_checkpoint)
    speech_model = SpeechT5ForTextToSpeech.from_pretrained(tts_checkpoint)
    speech_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    # One fixed x-vector row is used as the voice; unsqueeze(0) adds a
    # leading dimension to the loaded vector.
    xvector_rows = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    voice_embedding = torch.tensor(xvector_rows[7306]["xvector"]).unsqueeze(0)

    return {
        'emotion': emotion_classifier,
        'textgen_tokenizer': reply_tokenizer,
        'textgen_model': reply_model,
        'tts_processor': speech_processor,
        'tts_model': speech_model,
        'tts_vocoder': speech_vocoder,
        'speaker_embeddings': voice_embedding,
    }
##########################################
# UI Components
##########################################
def _display_interface():
    """Draw the page heading and the comment box.

    Returns:
        Whatever ``st.text_area`` returns for the comment widget (its
        current text).
    """
    st.title("Just Comment")
    st.markdown("### I'm listening to you, my friend๏ฝž")

    comment_box_options = {
        "placeholder": "Type your message here...",
        "height": 150,
        "key": "user_input",
    }
    typed_comment = st.text_area("๐Ÿ“ Enter your comment:", **comment_box_options)
    return typed_comment
##########################################
# Core Processing Functions
##########################################
def _analyze_emotion(text, classifier):
"""Identify dominant emotion with confidence threshold"""
results = classifier(text, return_all_scores=True)[0]
valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
filtered = [e for e in results if e['label'].lower() in valid_emotions]
return max(filtered, key=lambda x: x['score'])
def _generate_prompt(text, emotion):
"""Create structured prompts for all emotion types"""
prompt_templates = {
"sadness": (
"Sadness detected: {input}\n"
"Required response structure:\n"
"1. Empathetic acknowledgment\n2. Support offer\n3. Solution proposal\n"
"Response:"
),
"joy": (
"Joy detected: {input}\n"
"Required response structure:\n"
"1. Enthusiastic thanks\n2. Positive reinforcement\n3. Future engagement\n"
"Response:"
),
"love": (
"Affection detected: {input}\n"
"Required response structure:\n"
"1. Warm appreciation\n2. Community focus\n3. Exclusive benefit\n"
"Response:"
),
"anger": (
"Anger detected: {input}\n"
"Required response structure:\n"
"1. Sincere apology\n2. Action steps\n3. Compensation\n"
"Response:"
),
"fear": (
"Concern detected: {input}\n"
"Required response structure:\n"
"1. Reassurance\n2. Safety measures\n3. Support options\n"
"Response:"
),
"surprise": (
"Surprise detected: {input}\n"
"Required response structure:\n"
"1. Acknowledge uniqueness\n2. Creative solution\n3. Follow-up\n"
"Response:"
)
}
return prompt_templates.get(emotion.lower(), "").format(input=text)
def _process_response(raw_text):
"""Clean and format generated response"""
# Extract text after last "Response:" marker
processed = raw_text.split("Response:")[-1].strip()
# Remove incomplete sentences
if '.' in processed:
processed = processed.rsplit('.', 1)[0] + '.'
# Ensure length between 50-200 characters
return processed[:200].strip() if len(processed) > 50 else "Thank you for your feedback. We value your input and will respond shortly."
def _generate_text_response(input_text, models):
    """Produce the written reply for a user comment.

    Classifies the comment's emotion, builds the matching prompt, samples a
    continuation from the language model and post-processes it.

    Args:
        input_text: The user's comment.
        models: Dict of loaded models; the keys ``'emotion'``,
            ``'textgen_tokenizer'`` and ``'textgen_model'`` are read here.

    Returns:
        The cleaned reply string returned by ``_process_response``.
    """
    # Emotion analysis
    emotion = _analyze_emotion(input_text, models['emotion'])

    # Prompt engineering
    prompt = _generate_prompt(input_text, emotion['label'])
    if not prompt.strip():
        # Nothing to condition on: skip generation and let the
        # post-processor supply its stock reply.
        return _process_response("")

    tokenizer = models['textgen_tokenizer']
    generator = models['textgen_model']

    # Put the encoded prompt on whatever device the model lives on.
    encoded = tokenizer(prompt, return_tensors="pt").to(generator.device)
    output_ids = generator.generate(
        encoded["input_ids"],
        # Passed explicitly: pad and EOS share one id here, so the mask
        # cannot be inferred reliably from the input ids.
        attention_mask=encoded["attention_mask"],
        max_new_tokens=100,  # Hard cap on reply length
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return _process_response(decoded)
def _generate_audio_response(text, models):
    """Synthesize *text* as speech and write it to ``response.wav``.

    Args:
        text: The reply to speak.
        models: Dict of loaded models; the keys ``'tts_processor'``,
            ``'tts_model'``, ``'speaker_embeddings'`` and ``'tts_vocoder'``
            are read here.

    Returns:
        The path of the written file, always ``"response.wav"``.
    """
    # Text -> model input ids
    encoded_text = models['tts_processor'](text=text, return_tensors="pt")
    token_ids = encoded_text["input_ids"]

    # Input ids + speaker embedding -> spectrogram
    mel = models['tts_model'].generate_speech(
        token_ids,
        models['speaker_embeddings'],
    )

    # Spectrogram -> waveform; no gradients are needed for inference.
    with torch.no_grad():
        audio = models['tts_vocoder'](mel)

    # Written at a 16 kHz sample rate.
    samples = audio.numpy()
    sf.write("response.wav", samples, samplerate=16000)
    return "response.wav"
##########################################
# Main Application Flow
##########################################
def main():
    """Run one pass of the app: load models, read the comment, reply in text and audio.

    Nothing is generated while the comment box is empty or holds only
    whitespace.
    """
    # Load models once
    ml_models = _load_models()

    # Display interface
    user_input = _display_interface()

    # Whitespace-only input is truthy but carries no comment to answer, so
    # normalise before deciding whether to run the models.
    comment = (user_input or "").strip()
    if not comment:
        return

    # Text generation stage
    with st.spinner("๐Ÿ” Analyzing emotions and generating response..."):
        text_response = _generate_text_response(comment, ml_models)

    # Display results
    st.subheader("๐Ÿ“„ Generated Response")
    st.markdown(f"```\n{text_response}\n```")  # Shown as a fenced code block

    # Audio generation stage
    with st.spinner("๐Ÿ”Š Converting to speech..."):
        audio_file = _generate_audio_response(text_response, ml_models)
    st.audio(audio_file, format="audio/wav")


if __name__ == "__main__":
    main()