fajarah commited on
Commit
89891d1
Β·
verified Β·
1 Parent(s): 3d45c60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -18
app.py CHANGED
@@ -1,14 +1,14 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- from torch.nn.functional import softmax
4
  import torch
5
 
6
- # Load model and tokenizer
7
- model_name = "bhadresh-savani/distilbert-base-uncased-emotion"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
 
11
- # Emotion label to icon mapping
12
  emotion_icons = {
13
  "admiration": "😍",
14
  "amusement": "πŸ˜…",
@@ -40,15 +40,18 @@ emotion_icons = {
40
  "neutral": "😐"
41
  }
42
 
43
- # Prediction function
44
- def get_emotion(text):
45
- inputs = tokenizer(text, return_tensors="pt")
46
- outputs = model(**inputs)
47
- probs = softmax(outputs.logits, dim=1)
48
- predicted_class = torch.argmax(probs).item()
49
- label = model.config.id2label[predicted_class]
50
- icon = emotion_icons.get(label, "")
51
- return f"{icon} {label.capitalize()}"
 
 
 
52
 
53
  # Gradio UI
54
  custom_css = """
@@ -82,11 +85,11 @@ body {
82
  """
83
 
84
  demo = gr.Interface(
85
- fn=get_emotion,
86
- inputs=gr.Textbox(lines=3, placeholder="What's on your mind today?", label="Your Text"),
87
- outputs=gr.Textbox(label="Detected Emotion", elem_classes=["output-textbox"]),
88
- title="πŸ₯° Emotion Detector",
89
- description="Type a sentence below and hit Submit to reveal the emotion behind your words.",
90
  theme="default",
91
  css=custom_css
92
  )
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ from torch.nn.functional import sigmoid
4
  import torch
5
 
6
+ # Load multi-label GoEmotions model
7
+ model_name = "joeddav/distilbert-base-uncased-go-emotions-group"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
 
11
+ # Emotion label to icon mapping (subset)
12
  emotion_icons = {
13
  "admiration": "😍",
14
  "amusement": "πŸ˜…",
 
40
  "neutral": "😐"
41
  }
42
 
43
+ # Prediction function for multi-label output
44
+ def get_emotions(text):
45
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
46
+ with torch.no_grad():
47
+ logits = model(**inputs).logits
48
+ probs = sigmoid(logits)[0]
49
+
50
+ threshold = 0.3 # You can adjust this threshold
51
+ labels = [model.config.id2label[i] for i, p in enumerate(probs) if p > threshold]
52
+ icons = [emotion_icons.get(label, '') + ' ' + label.capitalize() for label in labels]
53
+
54
+ return ", ".join(icons) if icons else "No strong emotion detected."
55
 
56
  # Gradio UI
57
  custom_css = """
 
85
  """
86
 
87
  demo = gr.Interface(
88
+ fn=get_emotions,
89
+ inputs=gr.Textbox(lines=5, placeholder="Write a sentence or a full paragraph...", label="Your Text"),
90
+ outputs=gr.Textbox(label="Detected Emotions", elem_classes=["output-textbox"]),
91
+ title="πŸ₯° Multi-Label Emotion Detector",
92
+ description="Enter a sentence or paragraph to detect multiple emotions present in the text.",
93
  theme="default",
94
  css=custom_css
95
  )