##########################################
# Step 0: Import required libraries
##########################################
import streamlit as st  # For web interface
from transformers import (
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    AutoModelForCausalLM,
    AutoTokenizer
)  # AI model components
from datasets import load_dataset  # For voice embeddings
import torch  # Tensor computations
import soundfile as sf  # Audio file handling
##########################################
# Initial configuration (MUST be first)
##########################################
st.set_page_config(
    page_title="Just Comment",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed"
)
##########################################
# Global model loading with caching
##########################################
@st.cache_resource  # Cache the models so they are loaded only once
def _load_models():
    """Load and cache all ML models with optimized settings"""
    return {
        # Emotion classification pipeline
        'emotion': pipeline(
            "text-classification",
            model="Thea231/jhartmann_emotion_finetuning",
            truncation=True  # Enable text truncation for long inputs
        ),
        # Text generation components
        'textgen_tokenizer': AutoTokenizer.from_pretrained(
            "Qwen/Qwen1.5-0.5B",
            use_fast=True  # Enable fast tokenization
        ),
        'textgen_model': AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B",
            torch_dtype=torch.float32  # Full precision: float16 ops are unreliable on CPU
        ),
        # Text-to-speech components
        'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"),
        'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"),
        'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),
        # Preloaded speaker embeddings
        'speaker_embeddings': torch.tensor(
            load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
        ).unsqueeze(0)
    }
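# Note: st.cache_resource keeps a single shared copy of this dictionary across
# Streamlit reruns and sessions, so the weights are downloaded and loaded only
# once per server process.
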
##########################################
# UI Components
##########################################
def _display_interface():
    """Render user interface elements"""
    st.title("Just Comment")
    st.markdown("### I'm listening to you, my friend～")
    return st.text_area(
        "📝 Enter your comment:",
        placeholder="Type your message here...",
        height=150,
        key="user_input"
    )
##########################################
# Core Processing Functions
##########################################
def _analyze_emotion(text, classifier):
    """Identify the dominant emotion among the supported labels"""
    results = classifier(text, return_all_scores=True)[0]
    valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
    filtered = [e for e in results if e['label'].lower() in valid_emotions]
    return max(filtered, key=lambda x: x['score'])
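# Illustrative output shape (scores assumed, not from a real run):
#   classifier("I love this!", return_all_scores=True)
#   -> [[{'label': 'joy', 'score': 0.91}, {'label': 'love', 'score': 0.05}, ...]]
# `results` above is the inner list; max() then picks the top-scoring valid label.
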
def _generate_prompt(text, emotion):
    """Create structured prompts for all emotion types"""
    prompt_templates = {
        "sadness": (
            "Sadness detected: {input}\n"
            "Required response structure:\n"
            "1. Empathetic acknowledgment\n2. Support offer\n3. Solution proposal\n"
            "Response:"
        ),
        "joy": (
            "Joy detected: {input}\n"
            "Required response structure:\n"
            "1. Enthusiastic thanks\n2. Positive reinforcement\n3. Future engagement\n"
            "Response:"
        ),
        "love": (
            "Affection detected: {input}\n"
            "Required response structure:\n"
            "1. Warm appreciation\n2. Community focus\n3. Exclusive benefit\n"
            "Response:"
        ),
        "anger": (
            "Anger detected: {input}\n"
            "Required response structure:\n"
            "1. Sincere apology\n2. Action steps\n3. Compensation\n"
            "Response:"
        ),
        "fear": (
            "Concern detected: {input}\n"
            "Required response structure:\n"
            "1. Reassurance\n2. Safety measures\n3. Support options\n"
            "Response:"
        ),
        "surprise": (
            "Surprise detected: {input}\n"
            "Required response structure:\n"
            "1. Acknowledge uniqueness\n2. Creative solution\n3. Follow-up\n"
            "Response:"
        )
    }
    # Fall back to a generic template so an unrecognized label never yields an empty prompt
    default_template = (
        "Comment received: {input}\n"
        "Required response structure:\n"
        "1. Acknowledgment\n2. Helpful answer\n3. Follow-up offer\n"
        "Response:"
    )
    return prompt_templates.get(emotion.lower(), default_template).format(input=text)
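# Illustrative result (sample comment assumed, label "anger"):
#   _generate_prompt("My order arrived broken", "anger")
#   -> "Anger detected: My order arrived broken\nRequired response structure:\n"
#      "1. Sincere apology\n2. Action steps\n3. Compensation\nResponse:"
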
def _process_response(raw_text):
    """Clean and format the generated response"""
    # Extract text after the last "Response:" marker
    processed = raw_text.split("Response:")[-1].strip()
    # Drop a trailing incomplete sentence
    if '.' in processed:
        processed = processed.rsplit('.', 1)[0] + '.'
    # Cap at 200 characters; below 50, fall back to a stock acknowledgment
    return processed[:200].strip() if len(processed) > 50 else "Thank you for your feedback. We value your input and will respond shortly."
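# Illustrative behavior (sample strings assumed): a generation that trails off
# mid-sentence is cut back to the last complete sentence:
#   _process_response("prompt Response: Thanks for flagging this! Our team "
#                     "will ship a fix in the next release. We al")
#   -> "Thanks for flagging this! Our team will ship a fix in the next release."
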
def _generate_text_response(input_text, models):
    """Generate a text response via emotion-conditioned prompting"""
    # Emotion analysis
    emotion = _analyze_emotion(input_text, models['emotion'])
    # Prompt engineering
    prompt = _generate_prompt(input_text, emotion['label'])
    # Text generation with bounded sampling parameters
    inputs = models['textgen_tokenizer'](prompt, return_tensors="pt").to('cpu')
    outputs = models['textgen_model'].generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # Silence the missing-mask warning
        max_new_tokens=100,  # Strict token limit
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=models['textgen_tokenizer'].eos_token_id
    )
    return _process_response(
        models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True)
    )
def _generate_audio_response(text, models):
    """Convert text to speech with performance optimizations"""
    # Tokenize the text input
    inputs = models['tts_processor'](text=text, return_tensors="pt")
    # Generate a mel spectrogram conditioned on the speaker embedding
    spectrogram = models['tts_model'].generate_speech(
        inputs["input_ids"],
        models['speaker_embeddings']
    )
    # Run the vocoder without gradient tracking
    with torch.no_grad():
        waveform = models['tts_vocoder'](spectrogram)
    # Save a 16 kHz WAV file (SpeechT5's native sample rate)
    sf.write("response.wav", waveform.numpy(), samplerate=16000)
    return "response.wav"
##########################################
# Main Application Flow
##########################################
def main():
    """Primary execution flow"""
    # Load models once (cached by st.cache_resource)
    ml_models = _load_models()
    # Display interface
    user_input = _display_interface()
    if user_input:
        # Text generation stage
        with st.spinner("🔍 Analyzing emotions and generating response..."):
            text_response = _generate_text_response(user_input, ml_models)
        # Display results
        st.subheader("📄 Generated Response")
        st.markdown(f"```\n{text_response}\n```")
        # Audio generation stage
        with st.spinner("🔊 Converting to speech..."):
            audio_file = _generate_audio_response(text_response, ml_models)
            st.audio(audio_file, format="audio/wav")

if __name__ == "__main__":
    main()