|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoImageProcessor, AutoModelForImageClassification |
|
from torch.nn.functional import sigmoid |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
# Multi-label text emotion classifier, loaded once at import time.
# get_emotions scores its logits with a sigmoid, one independent score per label.
model_name = "SamLowe/roberta-base-go_emotions"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
|
|
|
|
|
# Facial-expression image classifier used by analyze_image_emotion, loaded once
# at import time (presumably FER2013 expression classes, per the model name).
image_model_name = "Celal11/resnet-50-finetuned-FER2013-0.001"
image_processor, image_model = (
    AutoImageProcessor.from_pretrained(image_model_name),
    AutoModelForImageClassification.from_pretrained(image_model_name),
)
|
|
|
|
|
def analyze_image_emotion(image):
    """Classify the facial expression in ``image`` and describe the top class.

    Args:
        image: The uploaded photo (a PIL image when it comes from the
            ``gr.Image(type="pil")`` input), or ``None`` if nothing was uploaded.

    Returns:
        ``"<label> (<probability>)"`` for the single most likely class of the
        image model, or ``"No image provided."`` when ``image`` is ``None``.
    """
    if image is not None:
        batch = image_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            class_logits = image_model(**batch).logits
        # Softmax across the class dimension, then take the first (only) image.
        class_probs = class_logits.softmax(dim=1)[0]
        best = int(class_probs.argmax())
        best_label = image_model.config.id2label[best]
        return f"{best_label} ({class_probs[best].item():.2f})"
    return "No image provided."
|
|
|
|
|
# Display icon for each emotion label the text model can emit. get_emotions
# looks labels up with .get(label, ''), so a label missing here just gets no icon.
# Keys are expected to match the text model's config.id2label names -- confirm
# if the model is swapped.
# Icons are written as \N{...} escapes so this source stays ASCII-only and a
# wrong-encoding round-trip cannot garble (or split) the string literals.
emotion_icons = {
    "admiration": "\N{SMILING FACE WITH HEART-SHAPED EYES}",
    "amusement": "\N{SMILING FACE WITH OPEN MOUTH AND COLD SWEAT}",
    "anger": "\N{POUTING FACE}",
    "annoyance": "\N{UNAMUSED FACE}",
    "approval": "\N{THUMBS UP SIGN}",
    "caring": "\N{REVOLVING HEARTS}",
    "confusion": "\N{THINKING FACE}",
    "curiosity": "\N{FACE WITH OPEN MOUTH}",
    "desire": "\N{DROOLING FACE}",
    "disappointment": "\N{DISAPPOINTED FACE}",
    "disapproval": "\N{THUMBS DOWN SIGN}",
    "disgust": "\N{FACE WITH OPEN MOUTH VOMITING}",
    "embarrassment": "\N{FLUSHED FACE}",
    "excitement": "\N{PARTY POPPER}",
    "fear": "\N{FACE SCREAMING IN FEAR}",
    "gratitude": "\N{PERSON WITH FOLDED HANDS}",
    "grief": "\N{LOUDLY CRYING FACE}",
    "joy": "\N{SMILING FACE WITH OPEN MOUTH AND SMILING EYES}",
    # Heart plus VS16 so it renders as a colored emoji rather than a text glyph.
    "love": "\N{HEAVY BLACK HEART}\N{VARIATION SELECTOR-16}",
    "nervousness": "\N{SNEEZING FACE}",
    "optimism": "\N{RAINBOW}",
    "pride": "\N{SMILING FACE WITH SUNGLASSES}",
    "realization": "\N{SHOCKED FACE WITH EXPLODING HEAD}",
    "relief": "\N{RELIEVED FACE}",
    "remorse": "\N{PENSIVE FACE}",
    "sadness": "\N{CRYING FACE}",
    "surprise": "\N{ASTONISHED FACE}",
    "neutral": "\N{NEUTRAL FACE}",
}
|
|
|
|
|
def get_emotions(text, threshold):
    """Describe every text-model emotion whose score exceeds ``threshold``.

    Args:
        text: The text to analyze. ``None``, empty, or whitespace-only text is
            not sent to the model.
        threshold: Per-label probability cutoff; a label is reported only when
            its score is strictly greater than this value.

    Returns:
        A ", "-joined string of ``"<icon> <Label> (<score>)"`` entries in
        label-index order, ``"No strong emotion detected."`` when no label
        passes the cutoff, or ``"No text provided."`` for empty input
        (mirroring the "No image provided." guard in analyze_image_emotion).
    """
    # Don't classify an empty sequence: the model would still report
    # whatever emotion it assigns to "nothing".
    if not text or not text.strip():
        return "No text provided."

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Sigmoid (not softmax): each label is scored independently, so several
    # emotions can clear the threshold at once.
    scores = sigmoid(logits)[0].tolist()

    # Walk every label with its own index so each printed score belongs to
    # that label; indexing scores by position in a *filtered* label list
    # would pair labels with the wrong scores.
    entries = []
    for idx, score in enumerate(scores):
        if score > threshold:
            label = model.config.id2label[idx]
            entries.append(
                f"{emotion_icons.get(label, '')} {label.capitalize()} ({score:.2f})"
            )

    return ", ".join(entries) if entries else "No strong emotion detected."
|
|
|
|
|
def analyze_combined(text, threshold, image):
    """Run the text and image analyzers for one Gradio submission.

    Args:
        text: Text passed to get_emotions.
        threshold: Text-emotion cutoff passed to get_emotions.
        image: Photo passed to analyze_image_emotion.

    Returns:
        ``(text_result, image_result)``, in the order of the two output boxes.
    """
    # Text is analyzed first, then the image.
    return get_emotions(text, threshold), analyze_image_emotion(image)
|
|
|
|
|
# Stylesheet handed to gr.Interface via css=custom_css.
# .output-textbox targets the two result Textboxes, which set
# elem_classes=["output-textbox"]; @keyframes fadeIn drives their entry animation.
# NOTE(review): the .gr-button / .gr-textbox / .gr-slider selectors rely on
# Gradio-internal class names that may differ between Gradio versions --
# confirm they still match the installed version.
custom_css = """

body {

background-color: #f5f6f8;

font-family: 'Inter', sans-serif;

color: #2e2e2e;

padding: 20px;

}



h1, h2, h3, h4 {

font-weight: 600;

color: #1a1a1a;

}



.gradio-container {

max-width: 900px;

margin: auto;

padding: 30px;

background-color: #ffffff;

border-radius: 12px;

box-shadow: 0 10px 25px rgba(0, 0, 0, 0.05);

}



.gr-button {

background-color: #4f46e5 !important;

color: white !important;

font-weight: 500;

border-radius: 8px !important;

padding: 10px 20px;

}



.gr-button:hover {

background-color: #4338ca !important;

}



.gr-textbox, .gr-slider {

border-radius: 8px !important;

border: 1px solid #e5e7eb !important;

padding: 12px;

font-size: 1rem;

}



.output-textbox {

font-size: 1.2rem;

color: #111827;

background-color: #f9fafb;

border-radius: 10px;

padding: 16px;

border: 1px solid #d1d5db;

animation: fadeIn 0.8s ease-in-out both;

}



@keyframes fadeIn {

0% {

opacity: 0;

transform: translateY(10px);

}

100% {

opacity: 1;

transform: translateY(0);

}

}

"""
|
|
|
|
|
# The UI. The input components are listed in the same order as the
# parameters of analyze_combined(text, threshold, image). The outputs follow
# the order of its (text_result, image_result) return tuple.
demo = gr.Interface(
    fn=analyze_combined,
    inputs=[
        gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
        # Feeds get_emotions' per-label cutoff.
        gr.Slider(minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Threshold"),
        # type="pil" hands analyze_image_emotion a PIL image (or None when empty).
        gr.Image(type="pil", label="Upload Face Photo"),
    ],
    outputs=[
        gr.Textbox(label="Detected Text Emotions", elem_classes=["output-textbox"]),
        gr.Textbox(label="Detected Photo Emotion", elem_classes=["output-textbox"]),
    ],
    # Escape keeps the title emoji ASCII-safe in the source file.
    title="\N{SMILING FACE WITH SMILING EYES AND THREE HEARTS} Multi-Modal Emotion Detector",
    description="Analyze emotion from both text and a facial photo. Adjust the threshold for text emotion sensitivity.",
    theme="default",
    css=custom_css,
)


# Starts the server as soon as this module runs (no __main__ guard).
demo.launch()
|
|