logasanjeev commited on
Commit
f4dcd5d
·
verified ·
1 Parent(s): 8a395cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -20
app.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  from transformers import BertForSequenceClassification, BertTokenizer
5
  import requests
6
  import json
 
 
7
 
8
  # Load model and tokenizer from Hugging Face Hub
9
  repo_id = "logasanjeev/goemotions-bert"
@@ -20,10 +22,10 @@ thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
20
  response = requests.get(thresholds_url)
21
  thresholds_data = json.loads(response.text)
22
  emotion_labels = thresholds_data["emotion_labels"]
23
- best_thresholds = thresholds_data["thresholds"]
24
 
25
  # Prediction function
26
- def predict_emotions(text):
27
  encodings = tokenizer(
28
  text,
29
  padding='max_length',
@@ -38,30 +40,162 @@ def predict_emotions(text):
38
  outputs = model(input_ids, attention_mask=attention_mask)
39
  logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
40
 
 
41
  predictions = []
42
- for i, (logit, thresh) in enumerate(zip(logits, best_thresholds)):
43
- if logit >= thresh:
 
44
  predictions.append((emotion_labels[i], logit))
45
 
46
  predictions.sort(key=lambda x: x[1], reverse=True)
47
  if not predictions:
48
- return "No emotions predicted above thresholds."
49
 
50
- return "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Gradio interface
53
- interface = gr.Interface(
54
- fn=predict_emotions,
55
- inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
56
- outputs="text",
57
- title="GoEmotions BERT Classifier",
58
- description="Predict emotions using a fine-tuned BERT-base model from logasanjeev/goemotions-bert.",
59
- examples=[
60
- "I’m just chilling today.",
61
- "Thank you for saving my life!",
62
- "I’m nervous about my exam tomorrow."
63
- ]
64
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
 
66
  if __name__ == "__main__":
67
- interface.launch()
 
4
  from transformers import BertForSequenceClassification, BertTokenizer
5
  import requests
6
  import json
7
+ import plotly.express as px
8
+ import pandas as pd
9
 
10
  # Load model and tokenizer from Hugging Face Hub
11
  repo_id = "logasanjeev/goemotions-bert"
 
22
  response = requests.get(thresholds_url)
23
  thresholds_data = json.loads(response.text)
24
  emotion_labels = thresholds_data["emotion_labels"]
25
+ default_thresholds = thresholds_data["thresholds"]
26
 
27
  # Prediction function
28
+ def predict_emotions(text, confidence_threshold=0.0):
29
  encodings = tokenizer(
30
  text,
31
  padding='max_length',
 
40
  outputs = model(input_ids, attention_mask=attention_mask)
41
  logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
42
 
43
+ # Apply thresholds with user-defined confidence boost
44
  predictions = []
45
+ for i, (logit, thresh) in enumerate(zip(logits, default_thresholds)):
46
+ adjusted_thresh = max(thresh, confidence_threshold)
47
+ if logit >= adjusted_thresh:
48
  predictions.append((emotion_labels[i], logit))
49
 
50
  predictions.sort(key=lambda x: x[1], reverse=True)
51
  if not predictions:
52
+ return "No emotions predicted above thresholds.", None
53
 
54
+ # Format output
55
+ text_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions])
56
+
57
+ # Create bar chart
58
+ df = pd.DataFrame(predictions, columns=["Emotion", "Confidence"])
59
+ fig = px.bar(
60
+ df,
61
+ x="Emotion",
62
+ y="Confidence",
63
+ color="Emotion",
64
+ text="Confidence",
65
+ title="Emotion Confidence Levels",
66
+ height=400
67
+ )
68
+ fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
69
+ fig.update_layout(showlegend=False, margin=dict(t=40, b=40))
70
+
71
+ return text_output, fig
72
 
73
+ # Custom CSS for modern UI
74
+ custom_css = """
75
+ body {
76
+ font-family: 'Segoe UI', Arial, sans-serif;
77
+ }
78
+ .gr-panel {
79
+ border-radius: 12px;
80
+ box-shadow: 0 4px 12px rgba(0,0,0,0.1);
81
+ background: linear-gradient(145deg, #ffffff, #f0f4f8);
82
+ }
83
+ .gr-button {
84
+ border-radius: 8px;
85
+ background: #007bff;
86
+ color: white;
87
+ padding: 10px 20px;
88
+ transition: background 0.3s;
89
+ }
90
+ .gr-button:hover {
91
+ background: #0056b3;
92
+ }
93
+ #title {
94
+ font-size: 2.5em;
95
+ color: #1a3c6e;
96
+ text-align: center;
97
+ margin-bottom: 20px;
98
+ }
99
+ #description {
100
+ font-size: 1.1em;
101
+ color: #333;
102
+ text-align: center;
103
+ max-width: 700px;
104
+ margin: 0 auto;
105
+ }
106
+ #theme-toggle {
107
+ position: absolute;
108
+ top: 20px;
109
+ right: 20px;
110
+ }
111
+ .dark-mode {
112
+ background: #1a1a1a;
113
+ color: #e0e0e0;
114
+ }
115
+ .dark-mode .gr-panel {
116
+ background: linear-gradient(145deg, #2a2a2a, #3a3a3a);
117
+ }
118
+ .dark-mode #title {
119
+ color: #66b3ff;
120
+ }
121
+ .dark-mode #description {
122
+ color: #b0b0b0;
123
+ }
124
+ """
125
+
126
+ # JavaScript for theme toggle
127
+ theme_js = """
128
+ function toggleTheme() {
129
+ document.body.classList.toggle('dark-mode');
130
+ }
131
+ """
132
+
133
+ # Gradio Blocks UI
134
+ with gr.Blocks(css=custom_css) as demo:
135
+ # Header
136
+ gr.Markdown("<div id='title'>GoEmotions BERT Classifier</div>", elem_id="title")
137
+ gr.Markdown(
138
+ """
139
+ <div id='description'>
140
+ Predict emotions from text using a fine-tuned BERT-base model.
141
+ Explore 28 emotions with optimized thresholds (Micro F1: 0.6025).
142
+ Try examples or enter your own text!
143
+ </div>
144
+ """,
145
+ elem_id="description"
146
+ )
147
+
148
+ # Theme toggle button
149
+ with gr.Row():
150
+ gr.HTML(
151
+ """
152
+ <button id='theme-toggle' onclick='toggleTheme()'>Toggle Dark Mode</button>
153
+ <script>{}</script>
154
+ """.format(theme_js)
155
+ )
156
+
157
+ # Main input and output
158
+ with gr.Row():
159
+ with gr.Column(scale=1):
160
+ text_input = gr.Textbox(
161
+ label="Enter Your Text",
162
+ placeholder="Type something like 'I’m just chilling today'...",
163
+ lines=3
164
+ )
165
+ confidence_slider = gr.Slider(
166
+ minimum=0.0,
167
+ maximum=0.9,
168
+ value=0.0,
169
+ step=0.05,
170
+ label="Minimum Confidence Threshold",
171
+ info="Adjust to filter low-confidence predictions"
172
+ )
173
+ submit_btn = gr.Button("Predict Emotions", variant="primary")
174
+
175
+ with gr.Column(scale=1):
176
+ output_text = gr.Textbox(label="Predicted Emotions", lines=5)
177
+ output_plot = gr.Plot(label="Emotion Confidence Chart")
178
+
179
+ # Example carousel
180
+ examples = gr.Examples(
181
+ examples=[
182
+ "I’m just chilling today.",
183
+ "Thank you for saving my life!",
184
+ "I’m nervous about my exam tomorrow.",
185
+ "I love my new puppy so much!",
186
+ "I’m so relieved the storm passed."
187
+ ],
188
+ inputs=text_input,
189
+ label="Try These Examples"
190
+ )
191
+
192
+ # Bind prediction
193
+ submit_btn.click(
194
+ fn=predict_emotions,
195
+ inputs=[text_input, confidence_slider],
196
+ outputs=[output_text, output_plot]
197
+ )
198
 
199
+ # Launch
200
  if __name__ == "__main__":
201
+ demo.launch()