# Hugging Face Space page header captured with this file: "Spaces:" — status "Sleeping".
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import BertForSequenceClassification, BertTokenizer | |
import requests | |
import json | |
import plotly.express as px | |
import pandas as pd | |
# Load the fine-tuned GoEmotions BERT model and its tokenizer from the
# Hugging Face Hub, move the model to the GPU when one is available, and put it
# in inference mode.
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
if torch.cuda.device_count() > 1:
    # Reference DataParallel through the imported `torch` package: a bare `nn`
    # name is not imported, so `nn.DataParallel` raised NameError on any
    # multi-GPU host.
    model = torch.nn.DataParallel(model)
# Disable dropout etc. for deterministic inference.
model.eval()
# Load the per-label decision thresholds (and the label names they belong to)
# published alongside the model as thresholds.json in the Hub repository.
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
# Bounded wait so a stalled connection cannot hang app startup indefinitely.
response = requests.get(thresholds_url, timeout=30)
# Surface HTTP failures (404, rate limiting, ...) directly instead of failing
# later while trying to parse an error page as JSON.
response.raise_for_status()
thresholds_data = response.json()
emotion_labels = thresholds_data["emotion_labels"]
default_thresholds = thresholds_data["thresholds"]
# Labels and thresholds are paired positionally at prediction time; a length
# mismatch would otherwise be silently truncated by zip().
if len(emotion_labels) != len(default_thresholds):
    raise ValueError(
        f"thresholds.json lists {len(emotion_labels)} emotion labels but "
        f"{len(default_thresholds)} thresholds"
    )
# Prediction function
def predict_emotions(text, confidence_threshold=0.0):
    """Classify the emotions expressed in ``text``.

    The text is tokenized (padded/truncated to 128 tokens) and run through the
    fine-tuned BERT model. A sigmoid turns each output logit into an
    independent per-label probability, so several emotions can be reported for
    one text. A label is reported when its probability reaches the larger of
    its optimized threshold (from thresholds.json) and ``confidence_threshold``.

    Args:
        text: Input text from the Gradio textbox. ``None`` or blank text is
            rejected without running the model.
        confidence_threshold: Minimum probability (slider value) every label
            must reach; it can raise but never lower a label's tuned threshold.

    Returns:
        A ``(summary, figure)`` tuple. ``summary`` lists ``label: probability``
        lines in descending probability order and ``figure`` is a Plotly bar
        chart of the same values. When nothing can be predicted, ``figure`` is
        ``None`` and ``summary`` explains why.
    """
    # Guard clause: don't spend a forward pass on empty input (and don't crash
    # on None).
    if text is None or not text.strip():
        return "Please enter some text to analyze.", None

    probabilities = _emotion_probabilities(text)
    predictions = _select_emotions(probabilities, confidence_threshold)
    if not predictions:
        return "No emotions predicted above thresholds.", None

    summary = "\n".join(
        f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions
    )
    return summary, _confidence_chart(predictions)


def _emotion_probabilities(text):
    """Return the model's sigmoid probability for every emotion label of ``text``."""
    encodings = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    # Sigmoid rather than softmax: each label is scored independently
    # (multi-label). [0] selects the single input in the batch.
    return torch.sigmoid(outputs.logits).cpu().numpy()[0]


def _select_emotions(probabilities, confidence_threshold):
    """Return ``(label, probability)`` pairs that pass their threshold, highest first."""
    predictions = []
    for label, prob, thresh in zip(emotion_labels, probabilities, default_thresholds):
        # The user's slider can only make a label stricter than its tuned threshold.
        if prob >= max(thresh, confidence_threshold):
            # Plain float instead of a numpy scalar for the text and chart output.
            predictions.append((label, float(prob)))
    predictions.sort(key=lambda pair: pair[1], reverse=True)
    return predictions


def _confidence_chart(predictions):
    """Build a Plotly bar chart of the predicted emotions and their probabilities."""
    df = pd.DataFrame(predictions, columns=["Emotion", "Confidence"])
    fig = px.bar(
        df,
        x="Emotion",
        y="Confidence",
        color="Emotion",
        text="Confidence",
        title="Emotion Confidence Levels",
        height=400
    )
    fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
    # One color per bar already identifies each emotion, so the legend is redundant.
    fig.update_layout(showlegend=False, margin=dict(t=40, b=40))
    return fig
# Custom CSS passed to gr.Blocks(css=...) for the app's look.
# - .gr-panel / .gr-button: rounded, shadowed panels and blue buttons.
#   NOTE(review): these class names depend on the installed Gradio version's
#   DOM; confirm they still match anything.
# - #title / #description / #theme-toggle: style the header elements and the
#   dark-mode button by id.
# - .dark-mode ...: alternate colors applied while an ancestor (theme_js
#   toggles it on <body>) carries the `dark-mode` class.
custom_css = """
body {
font-family: 'Segoe UI', Arial, sans-serif;
}
.gr-panel {
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
background: linear-gradient(145deg, #ffffff, #f0f4f8);
}
.gr-button {
border-radius: 8px;
background: #007bff;
color: white;
padding: 10px 20px;
transition: background 0.3s;
}
.gr-button:hover {
background: #0056b3;
}
#title {
font-size: 2.5em;
color: #1a3c6e;
text-align: center;
margin-bottom: 20px;
}
#description {
font-size: 1.1em;
color: #333;
text-align: center;
max-width: 700px;
margin: 0 auto;
}
#theme-toggle {
position: absolute;
top: 20px;
right: 20px;
}
.dark-mode {
background: #1a1a1a;
color: #e0e0e0;
}
.dark-mode .gr-panel {
background: linear-gradient(145deg, #2a2a2a, #3a3a3a);
}
.dark-mode #title {
color: #66b3ff;
}
.dark-mode #description {
color: #b0b0b0;
}
"""
# JavaScript source defining toggleTheme(), which adds/removes the `dark-mode`
# class on <body>; the `.dark-mode ...` rules in custom_css restyle the page
# while that class is present.
# NOTE(review): browsers do not execute <script> elements inserted via
# innerHTML, so embedding this string in a gr.HTML <script> tag may leave
# toggleTheme undefined at click time — confirm how it is loaded.
theme_js = """
function toggleTheme() {
document.body.classList.toggle('dark-mode');
}
"""
# Gradio Blocks UI: header, dark-mode toggle, text input + confidence slider,
# prediction outputs (text summary and bar chart), and clickable examples.
with gr.Blocks(css=custom_css) as demo:
    # Header. The inner <div> carries the styled id; the Markdown components get
    # no elem_id, because giving the wrapper the same id duplicated it (invalid
    # HTML) and let the em-based font-size/margins of #title/#description apply
    # to both nested elements.
    gr.Markdown("<div id='title'>GoEmotions BERT Classifier</div>")
    gr.Markdown(
        """
        <div id='description'>
        Predict emotions from text using a fine-tuned BERT-base model.
        Explore 28 emotions with optimized thresholds (Micro F1: 0.6025).
        Try examples or enter your own text!
        </div>
        """
    )

    # Theme toggle button. Markup given to gr.HTML is inserted as HTML, and
    # browsers never execute <script> tags inserted that way, so a
    # toggleTheme() defined in an embedded <script> was undefined on click.
    # The handler is therefore inline.
    with gr.Row():
        gr.HTML(
            "<button id='theme-toggle' "
            "onclick=\"document.body.classList.toggle('dark-mode')\">"
            "Toggle Dark Mode</button>"
        )

    # Main input and output
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Enter Your Text",
                placeholder="Type something like 'I’m just chilling today'...",
                lines=3
            )
            confidence_slider = gr.Slider(
                minimum=0.0,
                maximum=0.9,
                value=0.0,
                step=0.05,
                label="Minimum Confidence Threshold",
                info="Adjust to filter low-confidence predictions"
            )
            submit_btn = gr.Button("Predict Emotions", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Predicted Emotions", lines=5)
            output_plot = gr.Plot(label="Emotion Confidence Chart")

    # Example carousel: clicking an example fills the text input.
    examples = gr.Examples(
        examples=[
            "I’m just chilling today.",
            "Thank you for saving my life!",
            "I’m nervous about my exam tomorrow.",
            "I love my new puppy so much!",
            "I’m so relieved the storm passed."
        ],
        inputs=text_input,
        label="Try These Examples"
    )

    # Bind prediction: (text, slider) -> (summary text, chart)
    submit_btn.click(
        fn=predict_emotions,
        inputs=[text_input, confidence_slider],
        outputs=[output_text, output_plot]
    )
# Start the Gradio server only when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()