import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForImageClassification
from torch.nn.functional import sigmoid, softmax
import torch
from PIL import Image
from torchvision import transforms
# Load text emotion model
model_name = "SamLowe/roberta-base-go_emotions"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Load image emotion model (ResNet-50 fine-tuned on FER2013 + CK+)
image_model_name = "Celal11/resnet-50-finetuned-FER2013CKPlus-0.003"
# ResNet checkpoints are image classifiers and have no tokenizer;
# preprocessing is handled by the torchvision transform below.
image_emotion_model = AutoModelForImageClassification.from_pretrained(image_model_name)
# Transform for image preprocessing
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
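# The mean/std above are the standard ImageNet normalization constants.
# Alternative sketch (not used here): let the checkpoint describe its own
# preprocessing. Assumes a transformers version that provides AutoImageProcessor.
#
# from transformers import AutoImageProcessor
# image_processor = AutoImageProcessor.from_pretrained(image_model_name)
# inputs = image_processor(images=image, return_tensors="pt")  # dict with "pixel_values"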
# FER labels
image_labels = [
    "Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral", "Contempt"
]
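# Note: this hard-coded order is assumed to match the checkpoint's own
# image_emotion_model.config.id2label mapping; reading the labels from the config
# directly would avoid a mismatch if the checkpoint orders its classes differently.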
# Analyze image emotion
def analyze_image_emotion(image_path):
    if image_path is None:
        return "No image provided."
    image = Image.open(image_path).convert("RGB")
    img_tensor = image_transform(image).unsqueeze(0)
    with torch.no_grad():
        output = image_emotion_model(img_tensor)
    # Facial expression recognition is single-label, so softmax (rather than sigmoid)
    # turns the logits into a probability distribution over the classes.
    probs = softmax(output.logits, dim=-1)[0]
    top_idx = torch.argmax(probs).item()
    return f"{image_labels[top_idx]} ({probs[top_idx]:.2f})"
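# Illustrative call (hypothetical file name; the score is an example, not a real
# prediction from this model):
#   analyze_image_emotion("face.jpg")  # -> e.g. "Happy (0.93)"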
# Emotion label to icon mapping (covers all 28 go_emotions labels)
emotion_icons = {
"admiration": "๐Ÿ˜",
"amusement": "๐Ÿ˜…",
"anger": "๐Ÿ˜ก",
"annoyance": "๐Ÿ˜‘",
"approval": "๐Ÿ‘",
"caring": "๐Ÿ’—",
"confusion": "๐Ÿค”",
"curiosity": "๐Ÿ˜ฎ",
"desire": "๐Ÿคค",
"disappointment": "๐Ÿ˜ž",
"disapproval": "๐Ÿ‘Ž",
"disgust": "๐Ÿคฎ",
"embarrassment": "๐Ÿ˜ณ",
"excitement": "๐ŸŽ‰",
"fear": "๐Ÿ˜ฑ",
"gratitude": "๐Ÿ™",
"grief": "๐Ÿ˜ญ",
"joy": "๐Ÿ˜ƒ",
"love": "โค๏ธ",
"nervousness": "๐Ÿคง",
"optimism": "๐Ÿ˜Š",
"pride": "๐Ÿ˜Ž",
"realization": "๐Ÿคฏ",
"relief": "๐Ÿ˜Œ",
"remorse": "๐Ÿ˜”",
"sadness": "๐Ÿ˜ข",
"surprise": "๐Ÿ˜ฒ",
"neutral": "๐Ÿ˜"
}
# Analyze text emotion
def get_emotions(text, threshold):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # go_emotions is multi-label, so per-class sigmoid scores are compared to the threshold
    probs = sigmoid(logits)[0]
    # Keep the original class index alongside each label so the reported score matches it
    detected = [(i, model.config.id2label[i]) for i, p in enumerate(probs) if p > threshold]
    icons = [
        f"{emotion_icons.get(label, '')} {label.capitalize()} ({probs[i]:.2f})"
        for i, label in detected
    ]
    return ", ".join(icons) if icons else "No strong emotion detected."
# Combined analysis
def analyze_combined(text, threshold, image):
    text_result = get_emotions(text, threshold)
    image_result = analyze_image_emotion(image)
    return text_result, image_result
# Gradio UI
custom_css = """
body {
    background: linear-gradient(to right, #f9f9f9, #d4ecff);
    font-family: 'Segoe UI', sans-serif;
}
.gr-button {
    background-color: #007BFF !important;
    color: white !important;
    border-radius: 8px !important;
    font-weight: bold;
}
.gr-button:hover {
    background-color: #0056b3 !important;
}
.gr-textbox {
    border-radius: 8px !important;
    border: 1px solid #ccc !important;
    padding: 10px !important;
}
.output-textbox {
    font-size: 1.5rem;
    font-weight: bold;
    color: #333;
    background-color: #f1f9ff;
    border-radius: 8px;
    padding: 10px;
    border: 1px solid #007BFF;
}
"""
demo = gr.Interface(
    fn=analyze_combined,
    inputs=[
        gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
        gr.Slider(minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Threshold"),
        gr.Image(type="filepath", label="Upload Face Photo")
    ],
    outputs=[
        gr.Textbox(label="Detected Text Emotions", elem_classes=["output-textbox"]),
        gr.Textbox(label="Detected Photo Emotion", elem_classes=["output-textbox"])
    ],
    title="🥰 Multi-Modal Emotion Detector",
    description="Analyze emotion from both text and a facial photo. Adjust the threshold for text emotion sensitivity.",
    theme="default",
    css=custom_css
)
demo.launch()
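# When running locally (outside a hosted Space), demo.launch(share=True) would also
# expose a temporary public link; the default launch() is sufficient here.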