import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from torch.nn.functional import sigmoid
import torch
# Load text emotion model
model_name = "SamLowe/roberta-base-go_emotions"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
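# Note: go_emotions is a multi-label dataset (28 labels), so this model
# emits an independent score per label; get_emotions() below therefore
# applies a sigmoid to each logit and thresholds labels individually
# instead of taking a softmax over all labels.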
# Load image emotion classification pipeline
image_emotion_pipeline = pipeline("image-classification", model="nateraw/ferplus-emo-resnet34")
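# The pipeline returns a list of {"label": ..., "score": ...} dicts sorted
# by descending score, so results[0] in analyze_image_emotion() below is
# the top prediction.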
# Emotion label to icon mapping (subset)
emotion_icons = {
    "admiration": "😍",
    "amusement": "😅",
    "anger": "😡",
    "annoyance": "😑",
    "approval": "👍",
    "caring": "💗",
    "confusion": "🤔",
    "curiosity": "😮",
    "desire": "🤤",
    "disappointment": "😞",
    "disapproval": "👎",
    "disgust": "🤮",
    "embarrassment": "😳",
    "excitement": "🎉",
    "fear": "😱",
    "gratitude": "🙏",
    "grief": "😭",
    "joy": "😃",
    "love": "❤️",
    "nervousness": "🤧",
    "optimism": "😊",
    "pride": "😎",
    "realization": "🤯",
    "relief": "😌",
    "remorse": "😔",
    "sadness": "😢",
    "surprise": "😲",
    "neutral": "😐"
}
# Analyze text emotion
def get_emotions(text, threshold):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = sigmoid(logits)[0]
    # Pair each label that clears the threshold with its own probability;
    # indexing probs by the position in the filtered label list (as the
    # original code did) would mismatch labels and scores.
    detected = [(model.config.id2label[i], probs[i].item()) for i, p in enumerate(probs) if p > threshold]
    icons = [f"{emotion_icons.get(label, '')} {label.capitalize()} ({score:.2f})" for label, score in detected]
    return ", ".join(icons) if icons else "No strong emotion detected."
# Analyze image emotion
def analyze_image_emotion(image):
    if image is None:
        return "No image provided."
    results = image_emotion_pipeline(image)
    top = results[0]
    return f"{top['label']} ({top['score']:.2f})"
# Combined analysis
def analyze_combined(text, threshold, image):
    text_result = get_emotions(text, threshold)
    image_result = analyze_image_emotion(image)
    return text_result, image_result
# Gradio UI
custom_css = """
body {
background: linear-gradient(to right, #f9f9f9, #d4ecff);
font-family: 'Segoe UI', sans-serif;
}
.gr-button {
background-color: #007BFF !important;
color: white !important;
border-radius: 8px !important;
font-weight: bold;
}
.gr-button:hover {
background-color: #0056b3 !important;
}
.gr-textbox {
border-radius: 8px !important;
border: 1px solid #ccc !important;
padding: 10px !important;
}
.output-textbox {
font-size: 1.5rem;
font-weight: bold;
color: #333;
background-color: #f1f9ff;
border-radius: 8px;
padding: 10px;
border: 1px solid #007BFF;
}
"""
demo = gr.Interface(
    fn=analyze_combined,
    inputs=[
        gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
        gr.Slider(minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Threshold"),
        gr.Image(type="filepath", label="Upload Face Photo")
    ],
    outputs=[
        gr.Textbox(label="Detected Text Emotions", elem_classes=["output-textbox"]),
        gr.Textbox(label="Detected Photo Emotion", elem_classes=["output-textbox"])
    ],
    title="🥰 Multi-Modal Emotion Detector",
    description="Analyze emotion from both text and a facial photo. Adjust the threshold for text emotion sensitivity.",
    theme="default",
    css=custom_css
)
demo.launch()