##########################################
# Step 0: Essential imports
##########################################
import streamlit as st  # Web interface
from transformers import (  # AI components: emotion analysis, TTS, and text generation
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    AutoModelForCausalLM,
    AutoTokenizer
)
from datasets import load_dataset  # To load speaker embeddings dataset
import torch  # For tensor operations
import soundfile as sf  # For writing audio files
import sentencepiece  # noqa: F401 -- required by the SpeechT5 tokenizer; imported to fail fast if missing
##########################################
# Initial configuration (MUST BE FIRST)
##########################################
st.set_page_config(  # Configure the web page
    page_title="Just Comment",
    page_icon="💬",
    layout="centered"
)
##########################################
# Optimized model loader with caching
##########################################
@st.cache_resource(show_spinner=False)  # Cache the models so they load once per process, not on every rerun
def _load_components():
    """Load and cache all models with hardware optimization."""
    device = "cuda" if torch.cuda.is_available() else "cpu"  # Detect available device
    dtype = torch.float16 if device == "cuda" else torch.float32  # FP16 is only safe on GPU
    # Emotion classifier (fast and truncated)
    emotion_pipe = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        device=device,
        truncation=True
    )
    # Text generator (half precision on GPU with automatic device mapping)
    text_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")
    text_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        torch_dtype=dtype,
        device_map="auto"
    )
    # TTS system (half precision on GPU, full precision on CPU)
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(
        "microsoft/speecht5_tts",
        torch_dtype=dtype
    ).to(device)
    tts_vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan",
        torch_dtype=dtype
    ).to(device)
    # Preloaded voice profile for TTS, cast to the model dtype to avoid mismatches
    speaker_emb = torch.tensor(
        load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
    ).unsqueeze(0).to(device, dtype=dtype)
    return {
        "emotion": emotion_pipe,
        "text_model": text_model,
        "text_tokenizer": text_tokenizer,
        "tts_processor": tts_processor,
        "tts_model": tts_model,
        "tts_vocoder": tts_vocoder,
        "speaker_emb": speaker_emb,
        "device": device
    }
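# Note: with @st.cache_resource the dict above is built once and the same object is
# reused on every Streamlit rerun. During development you can force a reload with:
#   st.cache_resource.clear()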
##########################################
# User interface components
##########################################
def _show_interface():
    """Render the input interface."""
    st.title("🚀 Just Comment")  # Display the title with a rocket icon
    st.markdown("### I'm listening to you, my friend～")  # Display the friendly subtitle
    return st.text_area(  # Return the user's comment input
        "📝 Enter your comment:",
        placeholder="Share your thoughts...",
        height=150,
        key="input"
    )
##########################################
# Core processing functions
##########################################
def _fast_emotion(text, analyzer):
    """Rapid emotion detection with input length limit."""
    result = analyzer(text[:256], return_all_scores=True)[0]  # Analyze only the first 256 characters for speed
    valid_emotions = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
    # Select the highest-scoring valid emotion, or default to neutral
    return max(
        (e for e in result if e['label'].lower() in valid_emotions),
        key=lambda x: x['score'],
        default={'label': 'neutral', 'score': 0}
    )
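# Example (the score below is illustrative, not real model output):
#   _fast_emotion("The parcel arrived broken and nobody answered my emails.", emotion_pipe)
#   -> {'label': 'anger', 'score': 0.93}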
def _build_prompt(text, emotion):
    """Template-based prompt engineering in continuous prose (no bullet points)."""
    templates = {
        "sadness": "I sensed sadness in your comment: {text}. We are truly sorry and are here to support you.",
        "joy": "Your comment radiates joy: {text}. Thank you for your bright feedback; we look forward to serving you even better.",
        "love": "Your message exudes love: {text}. We appreciate your heartfelt words and cherish our connection with you.",
        "anger": "I understand your comment reflects anger: {text}. Please accept our sincere apologies as we work to resolve your concerns.",
        "fear": "It seems you feel fear in your comment: {text}. We want to reassure you that your safety and satisfaction are our priority.",
        "surprise": "Your comment conveys surprise: {text}. We are delighted by your experience and will strive to exceed your expectations.",
        "neutral": "Thank you for your comment: {text}. We remain committed to providing you with outstanding service."
    }
    # Build and return a continuous prompt with the user comment truncated to 200 characters
    return templates.get(emotion.lower(), templates["neutral"]).format(text=text[:200])
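# Example (follows directly from the template dict above):
#   _build_prompt("The delivery was late", "anger")
#   -> "I understand your comment reflects anger: The delivery was late. Please accept
#      our sincere apologies as we work to resolve your concerns."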
def _generate_response(text, models):
    """Optimized text generation pipeline using the detected emotion."""
    # Detect the dominant emotion quickly
    detected = _fast_emotion(text, models["emotion"])
    # Build a prompt in continuous sentences based on the detected emotion
    prompt = _build_prompt(text, detected["label"])
    print(f"Generated prompt: {prompt}")  # Debug output
    # Tokenize the prompt for the Qwen model
    inputs = models["text_tokenizer"](
        prompt,
        return_tensors="pt",
        max_length=100,
        truncation=True
    ).to(models["device"])
    # Generate the response with a balanced length (roughly 50-120 new tokens)
    output = models["text_model"].generate(
        inputs.input_ids,
        max_new_tokens=120,  # Upper bound on newly generated tokens
        min_new_tokens=50,   # Lower bound on new tokens to ensure completeness
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=models["text_tokenizer"].eos_token_id
    )
    # The template prompt doubles as the opening of the reply, so keep the full
    # decoded text (prompt plus continuation) rather than splitting it apart
    response = models["text_tokenizer"].decode(output[0], skip_special_tokens=True).strip()
    print(f"Generated response: {response}")  # Debug output
    # Cap the reply length; 200 characters is a rough stand-in for the 50-200 word target
    return response[:200]
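# Illustrative run (sampling is stochastic, so this output is a made-up sketch):
#   _generate_response("The delivery was late", models)
#   -> "I understand your comment reflects anger: The delivery was late. Please accept
#      our sincere apologies as we work to resolve your concerns. Our team is already..."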
def _text_to_speech(text, models):
    """Efficiently synthesize speech for the given text."""
    inputs = models["tts_processor"](
        text=text[:150],  # Limit TTS input to 150 characters to keep synthesis fast
        return_tensors="pt"
    ).to(models["device"])
    with torch.inference_mode():  # Fast, no-grad inference
        spectrogram = models["tts_model"].generate_speech(
            inputs["input_ids"],
            models["speaker_emb"]
        )
        audio = models["tts_vocoder"](spectrogram)
    # Cast to float32 before saving; soundfile cannot write float16 arrays
    sf.write("output.wav", audio.cpu().float().numpy(), 16000)  # SpeechT5 outputs 16 kHz audio
    return "output.wav"
##########################################
# Main application flow
##########################################
def main():
    """Primary execution controller."""
    components = _load_components()  # Load all models and components (cached after the first run)
    user_input = _show_interface()  # Render input interface and capture user comment
    if user_input:  # If a comment is provided
        with st.spinner("💬 Generating response..."):
            generated_response = _generate_response(user_input, components)  # Respond based on emotion and text
        st.subheader("📄 Response")
        st.markdown(
            f"<p style='color:#3498DB; font-size:20px;'>{generated_response}</p>",
            unsafe_allow_html=True
        )  # Display the generated response with inline styling
        with st.spinner("🔊 Synthesizing audio..."):
            audio_file = _text_to_speech(generated_response, components)  # Convert response to speech
        st.audio(audio_file, format="audio/wav", start_time=0)  # Embed an audio player for the response
        print(f"Final generated response: {generated_response}")  # Debug output
# Run the main function when the script is executed
if __name__ == "__main__":
    main()  # Call the main function
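# A minimal requirements sketch for this Space (package names only; versions are
# deliberately unpinned -- pin to whatever your environment has tested):
#   streamlit
#   transformers
#   datasets
#   torch
#   soundfile
#   sentencepiece
#   accelerate  # assumed: needed by device_map="auto" in _load_components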