File size: 2,471 Bytes
23e46f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
import torch
# Pretrained emotion classifier and its matching tokenizer, loaded once at
# import time via from_pretrained (both must use the same checkpoint name).
model_name = "bhadresh-savani/distilbert-base-uncased-emotion"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
# Emotion label to icon mapping.
# Keys are label strings as they appear in model.config.id2label; a label
# that is missing here is shown without an icon.
# NOTE(review): an earlier copy of this table was mis-encoded (UTF-8 bytes
# read as Thai TIS-620). Glyphs whose bytes survived were restored exactly;
# entries marked "reconstructed" kept only the lead byte and were replaced
# with a plausible emoji from the same Unicode range -- confirm if the exact
# artwork matters.
# NOTE(review): the loaded checkpoint presumably emits only a subset of
# these labels -- check model.config.id2label; extra keys are harmless.
emotion_icons = {
    "admiration": "😍",  # reconstructed
    "amusement": "😅",
    "anger": "😡",
    "annoyance": "😒",  # reconstructed
    "approval": "👍",  # reconstructed
    "caring": "💖",  # reconstructed
    "confusion": "🤔",
    "curiosity": "😮",
    "desire": "🤤",
    "disappointment": "😞",  # reconstructed
    "disapproval": "👎",  # reconstructed
    "disgust": "🤮",
    "embarrassment": "😳",
    "excitement": "😆",  # reconstructed
    "fear": "😱",
    "gratitude": "🙏",  # reconstructed
    "grief": "😭",
    "joy": "😂",  # reconstructed
    "love": "\u2764\ufe0f",  # red heart + emoji-presentation selector
    "nervousness": "🤧",
    "optimism": "😊",  # reconstructed
    "pride": "😎",  # reconstructed
    "realization": "🤯",
    "relief": "😌",  # reconstructed
    "remorse": "😔",  # reconstructed
    "sadness": "😢",
    "surprise": "😲",
    "neutral": "😐",  # reconstructed
}
# Prediction function
def get_emotion(text):
    """Return the emotion predicted for *text*, prefixed with its emoji.

    Args:
        text: Sentence(s) typed into the Gradio textbox.

    Returns:
        ``"<icon> <Label>"`` for the highest-scoring class of ``model``: the
        label comes from ``model.config.id2label`` and the icon from
        ``emotion_icons``. If the label has no icon, only the capitalized
        label is returned (no stray leading space). Empty or whitespace-only
        input returns ``""`` without running the model.
    """
    # Nothing to classify -- skip the forward pass instead of labelling an
    # essentially empty token sequence.
    if not text or not text.strip():
        return ""
    # truncation=True clips over-long input to the tokenizer's maximum length
    # so long pastes don't exceed what the model can process.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    # softmax is monotonic, so the arg-max over logits picks the same class as
    # the arg-max over probabilities; [0] selects the single row in the batch.
    predicted_class = int(torch.argmax(logits, dim=-1)[0])
    label = model.config.id2label[predicted_class]
    icon = emotion_icons.get(label, "")
    return f"{icon} {label.capitalize()}" if icon else label.capitalize()
# Gradio UI
# Stylesheet passed to gr.Interface(css=custom_css) below.
# `.output-textbox` is attached to the result box through elem_classes;
# `.gr-button` / `.gr-textbox` are meant to target Gradio's built-in widgets.
# NOTE(review): Gradio's internal class names vary between releases -- confirm
# the .gr-* selectors still match the installed version, otherwise those
# rules silently have no effect.
custom_css = """
body {
background: linear-gradient(to right, #f9f9f9, #d4ecff);
font-family: 'Segoe UI', sans-serif;
}
.gr-button {
background-color: #007BFF !important;
color: white !important;
border-radius: 8px !important;
font-weight: bold;
}
.gr-button:hover {
background-color: #0056b3 !important;
}
.gr-textbox {
border-radius: 8px !important;
border: 1px solid #ccc !important;
padding: 10px !important;
}
.output-textbox {
font-size: 1.5rem;
font-weight: bold;
color: #333;
background-color: #f1f9ff;
border-radius: 8px;
padding: 10px;
border: 1px solid #007BFF;
}
"""
# Text-in / text-out interface: the textbox content is passed to get_emotion
# and its return string is shown in the "Detected Emotion" box, which carries
# the `output-textbox` CSS class.
demo = gr.Interface(
    fn=get_emotion,
    inputs=gr.Textbox(lines=3, placeholder="What's on your mind today?", label="Your Text"),
    outputs=gr.Textbox(label="Detected Emotion", elem_classes=["output-textbox"]),
    # Smiling face with hearts (U+1F970).
    title="🥰 Emotion Detector",
    description="Type a sentence below and hit Submit to reveal the emotion behind your words.",
    theme="default",
    css=custom_css,
)
# Start the Gradio app.
demo.launch()
|