Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
-
from torch.nn.functional import
|
4 |
import torch
|
5 |
|
6 |
-
# Load
|
7 |
-
model_name = "
|
8 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
9 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
10 |
|
11 |
-
# Emotion label to icon mapping
|
12 |
emotion_icons = {
|
13 |
"admiration": "π",
|
14 |
"amusement": "π
",
|
@@ -40,15 +40,18 @@ emotion_icons = {
|
|
40 |
"neutral": "π"
|
41 |
}
|
42 |
|
43 |
-
# Prediction function
|
44 |
-
def
|
45 |
-
inputs = tokenizer(text, return_tensors="pt")
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
52 |
|
53 |
# Gradio UI
|
54 |
custom_css = """
|
@@ -82,11 +85,11 @@ body {
|
|
82 |
"""
|
83 |
|
84 |
demo = gr.Interface(
|
85 |
-
fn=
|
86 |
-
inputs=gr.Textbox(lines=
|
87 |
-
outputs=gr.Textbox(label="Detected
|
88 |
-
title="π₯° Emotion Detector",
|
89 |
-
description="
|
90 |
theme="default",
|
91 |
css=custom_css
|
92 |
)
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
+
from torch.nn.functional import sigmoid
|
4 |
import torch
|
5 |
|
6 |
+
# Load multi-label GoEmotions model
|
7 |
+
model_name = "joeddav/distilbert-base-uncased-go-emotions-group"
|
8 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
9 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
10 |
|
11 |
+
# Emotion label to icon mapping (subset)
|
12 |
emotion_icons = {
|
13 |
"admiration": "π",
|
14 |
"amusement": "π
",
|
|
|
40 |
"neutral": "π"
|
41 |
}
|
42 |
|
43 |
+
# Prediction function for multi-label output
|
44 |
+
def get_emotions(text):
|
45 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
46 |
+
with torch.no_grad():
|
47 |
+
logits = model(**inputs).logits
|
48 |
+
probs = sigmoid(logits)[0]
|
49 |
+
|
50 |
+
threshold = 0.3 # You can adjust this threshold
|
51 |
+
labels = [model.config.id2label[i] for i, p in enumerate(probs) if p > threshold]
|
52 |
+
icons = [emotion_icons.get(label, '') + ' ' + label.capitalize() for label in labels]
|
53 |
+
|
54 |
+
return ", ".join(icons) if icons else "No strong emotion detected."
|
55 |
|
56 |
# Gradio UI
|
57 |
custom_css = """
|
|
|
85 |
"""
|
86 |
|
87 |
demo = gr.Interface(
|
88 |
+
fn=get_emotions,
|
89 |
+
inputs=gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
|
90 |
+
outputs=gr.Textbox(label="Detected Emotions", elem_classes=["output-textbox"]),
|
91 |
+
title="π₯° Multi-Label Emotion Detector",
|
92 |
+
description="Enter a sentence or paragraph to detect multiple emotions present in the text.",
|
93 |
theme="default",
|
94 |
css=custom_css
|
95 |
)
|