|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline |
|
from torch.nn.functional import sigmoid |
|
import torch |
|
|
|
|
|
# Text-emotion model: RoBERTa fine-tuned on the GoEmotions dataset
# (multi-label classification over 28 emotion labels).
model_name = "SamLowe/roberta-base-go_emotions"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Facial-expression classifier used for the uploaded photo.
# NOTE(review): models download on first run — startup requires network access.
image_emotion_pipeline = pipeline("image-classification", model="nateraw/ferplus-emo-resnet34")
|
|
|
|
|
# Emoji icon for each of the 28 GoEmotions labels, used to decorate the
# text-emotion output. The original literals were mojibake (UTF-8 read in a
# wrong encoding; one entry was even split across lines) — restored here.
emotion_icons = {
    "admiration": "😍",
    "amusement": "😂",
    "anger": "😡",
    "annoyance": "😒",
    "approval": "👍",
    "caring": "🤗",
    "confusion": "🤔",
    "curiosity": "🧐",
    "desire": "🤤",
    "disappointment": "😞",
    "disapproval": "👎",
    "disgust": "🤮",
    "embarrassment": "😳",
    "excitement": "🤩",
    "fear": "😱",
    "gratitude": "🙏",
    "grief": "😭",
    "joy": "😄",
    "love": "❤️",
    "nervousness": "😬",
    "optimism": "😊",
    "pride": "😌",
    "realization": "🤯",
    "relief": "😅",
    "remorse": "😔",
    "sadness": "😢",
    "surprise": "😲",
    "neutral": "😐"
}
|
|
|
|
|
def get_emotions(text, threshold):
    """Detect emotions expressed in *text* with the GoEmotions model.

    Args:
        text: Input sentence or paragraph.
        threshold: Minimum per-class sigmoid probability (0-1) for an
            emotion to be reported.

    Returns:
        Comma-separated "<icon> Label (prob)" entries, or a fallback
        message when no emotion clears the threshold.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Multi-label model: an independent sigmoid per class, not a softmax.
    probs = sigmoid(logits)[0]

    # BUG FIX: the previous code filtered labels first and then indexed
    # `probs` with the *filtered* list's index, pairing each shown label with
    # the probability of an unrelated class. Keep (label, prob) together.
    detected = [
        (model.config.id2label[i], p.item())
        for i, p in enumerate(probs)
        if p > threshold
    ]
    icons = [
        f"{emotion_icons.get(label, '')} {label.capitalize()} ({score:.2f})"
        for label, score in detected
    ]

    return ", ".join(icons) if icons else "No strong emotion detected."
|
|
|
|
|
def analyze_image_emotion(image):
    """Classify the facial expression in *image* and report the top prediction.

    Returns "label (score)" for the highest-scoring class, or a short
    message when no image was supplied.
    """
    if image is None:
        return "No image provided."
    predictions = image_emotion_pipeline(image)
    best = predictions[0]
    return "{} ({:.2f})".format(best["label"], best["score"])
|
|
|
|
|
def analyze_combined(text, threshold, image):
    """Run both analyzers and return (text_emotions, photo_emotion) for Gradio."""
    return get_emotions(text, threshold), analyze_image_emotion(image)
|
|
|
|
|
# Custom styling injected into the Gradio app: gradient page background,
# branded blue buttons with hover state, rounded text inputs, and emphasized
# output boxes (applied via elem_classes=["output-textbox"] below).
custom_css = """

body {

    background: linear-gradient(to right, #f9f9f9, #d4ecff);

    font-family: 'Segoe UI', sans-serif;

}

.gr-button {

    background-color: #007BFF !important;

    color: white !important;

    border-radius: 8px !important;

    font-weight: bold;

}

.gr-button:hover {

    background-color: #0056b3 !important;

}

.gr-textbox {

    border-radius: 8px !important;

    border: 1px solid #ccc !important;

    padding: 10px !important;

}

.output-textbox {

    font-size: 1.5rem;

    font-weight: bold;

    color: #333;

    background-color: #f1f9ff;

    border-radius: 8px;

    padding: 10px;

    border: 1px solid #007BFF;

}

"""
|
|
|
# Wire the combined analyzer into a two-input (text + photo), two-output UI.
demo = gr.Interface(
    fn=analyze_combined,
    inputs=[
        gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
        # Slider feeds get_emotions(); lower values report more emotions.
        gr.Slider(minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Threshold"),
        # type="filepath" hands the classifier a path string, which the
        # image-classification pipeline accepts directly.
        gr.Image(type="filepath", label="Upload Face Photo")
    ],
    outputs=[
        gr.Textbox(label="Detected Text Emotions", elem_classes=["output-textbox"]),
        gr.Textbox(label="Detected Photo Emotion", elem_classes=["output-textbox"])
    ],
    # FIX: the title emoji was mojibake ("๐ฅฐ"); restored to the intended 🥰.
    title="🥰 Multi-Modal Emotion Detector",
    description="Analyze emotion from both text and a facial photo. Adjust the threshold for text emotion sensitivity.",
    theme="default",
    css=custom_css
)

demo.launch()
|
|