import gradio as gr import torch import numpy as np from transformers import BertForSequenceClassification, BertTokenizer import requests import json import plotly.express as px import pandas as pd # Load model and tokenizer from Hugging Face Hub repo_id = "logasanjeev/goemotions-bert" model = BertForSequenceClassification.from_pretrained(repo_id) tokenizer = BertTokenizer.from_pretrained(repo_id) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.to(device) if torch.cuda.device_count() > 1: model = nn.DataParallel(model) model.eval() # Load optimized thresholds from Hugging Face Hub thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json" response = requests.get(thresholds_url) thresholds_data = json.loads(response.text) emotion_labels = thresholds_data["emotion_labels"] default_thresholds = thresholds_data["thresholds"] # Prediction function def predict_emotions(text, confidence_threshold=0.0): encodings = tokenizer( text, padding='max_length', truncation=True, max_length=128, return_tensors='pt' ) input_ids = encodings['input_ids'].to(device) attention_mask = encodings['attention_mask'].to(device) with torch.no_grad(): outputs = model(input_ids, attention_mask=attention_mask) logits = torch.sigmoid(outputs.logits).cpu().numpy()[0] # Apply thresholds with user-defined confidence boost predictions = [] for i, (logit, thresh) in enumerate(zip(logits, default_thresholds)): adjusted_thresh = max(thresh, confidence_threshold) if logit >= adjusted_thresh: predictions.append((emotion_labels[i], logit)) predictions.sort(key=lambda x: x[1], reverse=True) if not predictions: return "No emotions predicted above thresholds.", None # Format output text_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions]) # Create bar chart df = pd.DataFrame(predictions, columns=["Emotion", "Confidence"]) fig = px.bar( df, x="Emotion", y="Confidence", color="Emotion", text="Confidence", title="Emotion Confidence Levels", height=400 ) fig.update_traces(texttemplate='%{text:.2f}', textposition='auto') fig.update_layout(showlegend=False, margin=dict(t=40, b=40)) return text_output, fig # Custom CSS for modern UI custom_css = """ body { font-family: 'Segoe UI', Arial, sans-serif; } .gr-panel { border-radius: 12px; box-shadow: 0 4px 12px rgba(0,0,0,0.1); background: linear-gradient(145deg, #ffffff, #f0f4f8); } .gr-button { border-radius: 8px; background: #007bff; color: white; padding: 10px 20px; transition: background 0.3s; } .gr-button:hover { background: #0056b3; } #title { font-size: 2.5em; color: #1a3c6e; text-align: center; margin-bottom: 20px; } #description { font-size: 1.1em; color: #333; text-align: center; max-width: 700px; margin: 0 auto; } #theme-toggle { position: absolute; top: 20px; right: 20px; } .dark-mode { background: #1a1a1a; color: #e0e0e0; } .dark-mode .gr-panel { background: linear-gradient(145deg, #2a2a2a, #3a3a3a); } .dark-mode #title { color: #66b3ff; } .dark-mode #description { color: #b0b0b0; } """ # JavaScript for theme toggle theme_js = """ function toggleTheme() { document.body.classList.toggle('dark-mode'); } """ # Gradio Blocks UI with gr.Blocks(css=custom_css) as demo: # Header gr.Markdown("