import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoImageProcessor, AutoModelForImageClassification
from torch.nn.functional import sigmoid
import torch
from PIL import Image
# Load text emotion model
model_name = "SamLowe/roberta-base-go_emotions"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Load image emotion model
image_model_name = "Celal11/resnet-50-finetuned-FER2013-0.001"
image_processor = AutoImageProcessor.from_pretrained(image_model_name)
image_model = AutoModelForImageClassification.from_pretrained(image_model_name)
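# Note (added context, not from the original upload): the FER2013 checkpoint is
# single-label over 7 facial expressions (angry, disgust, fear, happy, neutral,
# sad, surprise), so image scores use a softmax below, while the GoEmotions
# text model is multi-label and uses per-label sigmoids.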
# Analyze image emotion using processor and model
def analyze_image_emotion(image):
    if image is None:
        return "No image provided."
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = image_model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    pred_idx = torch.argmax(probs).item()
    label = image_model.config.id2label[pred_idx]
    score = probs[pred_idx].item()
    return f"{label} ({score:.2f})"
# Emotion label to icon mapping (subset)
emotion_icons = {
    "admiration": "😍",
    "amusement": "😅",
    "anger": "😡",
    "annoyance": "😑",
    "approval": "👍",
    "caring": "💗",
    "confusion": "🤔",
    "curiosity": "😮",
    "desire": "🤤",
    "disappointment": "😞",
    "disapproval": "👎",
    "disgust": "🤮",
    "embarrassment": "😳",
    "excitement": "🎉",
    "fear": "😱",
    "gratitude": "🙏",
    "grief": "😭",
    "joy": "😃",
    "love": "❤️",
    "nervousness": "🤧",
    "optimism": "😊",
    "pride": "😎",
    "realization": "🤯",
    "relief": "😌",
    "remorse": "😔",
    "sadness": "😢",
    "surprise": "😲",
    "neutral": "😐"
}
# Analyze text emotion
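# GoEmotions is a multi-label dataset (28 classes including "neutral"), so each
# logit gets an independent sigmoid rather than a softmax across classes, and
# any number of emotions can clear the threshold at once.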
def get_emotions(text, threshold):
    if not text or not text.strip():
        return "No text provided."
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = sigmoid(logits)[0]
    # Keep each label's original class index so its score is read from the
    # right position in probs.
    detected = [(i, model.config.id2label[i]) for i, p in enumerate(probs) if p > threshold]
    icons = [f"{emotion_icons.get(label, '')} {label.capitalize()} ({probs[i]:.2f})" for i, label in detected]
    return ", ".join(icons) if icons else "No strong emotion detected."
# Combined analysis
def analyze_combined(text, threshold, image):
    text_result = get_emotions(text, threshold)
    image_result = analyze_image_emotion(image)
    return text_result, image_result
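# Illustrative local check (hypothetical file path; the Gradio UI below is the
# real entry point):
#   analyze_combined("Thank you so much!", 0.3, Image.open("face.jpg"))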
# Gradio UI
custom_css = """
body {
    background-color: #f5f6f8;
    font-family: 'Inter', sans-serif;
    color: #2e2e2e;
    padding: 20px;
}
h1, h2, h3, h4 {
    font-weight: 600;
    color: #1a1a1a;
}
.gradio-container {
    max-width: 900px;
    margin: auto;
    padding: 30px;
    background-color: #ffffff;
    border-radius: 12px;
    box-shadow: 0 10px 25px rgba(0, 0, 0, 0.05);
}
.gr-button {
    background-color: #4f46e5 !important;
    color: white !important;
    font-weight: 500;
    border-radius: 8px !important;
    padding: 10px 20px;
}
.gr-button:hover {
    background-color: #4338ca !important;
}
.gr-textbox, .gr-slider {
    border-radius: 8px !important;
    border: 1px solid #e5e7eb !important;
    padding: 12px;
    font-size: 1rem;
}
.output-textbox {
    font-size: 1.2rem;
    color: #111827;
    background-color: #f9fafb;
    border-radius: 10px;
    padding: 16px;
    border: 1px solid #d1d5db;
    animation: fadeIn 0.8s ease-in-out both;
}
@keyframes fadeIn {
    0% {
        opacity: 0;
        transform: translateY(10px);
    }
    100% {
        opacity: 1;
        transform: translateY(0);
    }
}
"""
demo = gr.Interface(
    fn=analyze_combined,
    inputs=[
        gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
        gr.Slider(minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Threshold"),
        gr.Image(type="pil", label="Upload Face Photo")
    ],
    outputs=[
        gr.Textbox(label="Detected Text Emotions", elem_classes=["output-textbox"]),
        gr.Textbox(label="Detected Photo Emotion", elem_classes=["output-textbox"])
    ],
    title="🥰 Multi-Modal Emotion Detector",
    description="Analyze emotion from both text and a facial photo. Adjust the threshold for text emotion sensitivity.",
    theme="default",
    css=custom_css
)
demo.launch()
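# On Hugging Face Spaces, launch() needs no arguments; for a local run, Gradio's
# built-in share option (demo.launch(share=True)) serves a temporary public URL.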