logasanjeev commited on
Commit
c885400
·
verified ·
1 Parent(s): 300e0f8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import BertForSequenceClassification, BertTokenizer
5
+ import requests
6
+ import json
7
+
8
+ # Load model and tokenizer from Hugging Face Hub
9
+ repo_id = "logasanjeev/goemotions-bert"
10
+ model = BertForSequenceClassification.from_pretrained(repo_id)
11
+ tokenizer = BertTokenizer.from_pretrained(repo_id)
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ model.to(device)
14
+ if torch.cuda.device_count() > 1:
15
+ model = nn.DataParallel(model)
16
+ model.eval()
17
+
18
+ # Load optimized thresholds from Hugging Face Hub
19
+ thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
20
+ response = requests.get(thresholds_url)
21
+ thresholds_data = json.loads(response.text)
22
+ emotion_labels = thresholds_data["emotion_labels"]
23
+ best_thresholds = thresholds_data["thresholds"]
24
+
25
+ # Prediction function
26
+ def predict_emotions(text):
27
+ encodings = tokenizer(
28
+ text,
29
+ padding='max_length',
30
+ truncation=True,
31
+ max_length=128,
32
+ return_tensors='pt'
33
+ )
34
+ input_ids = encodings['input_ids'].to(device)
35
+ attention_mask = encodings['attention_mask'].to(device)
36
+
37
+ with torch.no_grad():
38
+ outputs = model(input_ids, attention_mask=attention_mask)
39
+ logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
40
+
41
+ predictions = []
42
+ for i, (logit, thresh) in enumerate(zip(logits, best_thresholds)):
43
+ if logit >= thresh:
44
+ predictions.append((emotion_labels[i], logit))
45
+
46
+ predictions.sort(key=lambda x: x[1], reverse=True)
47
+ if not predictions:
48
+ return "No emotions predicted above thresholds."
49
+
50
+ return "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions])
51
+
52
+ # Gradio interface
53
+ interface = gr.Interface(
54
+ fn=predict_emotions,
55
+ inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
56
+ outputs="text",
57
+ title="GoEmotions BERT Classifier",
58
+ description="Predict emotions using a fine-tuned BERT-base model from logasanjeev/goemotions-bert.",
59
+ examples=[
60
+ "I’m just chilling today.",
61
+ "Thank you for saving my life!",
62
+ "I’m nervous about my exam tomorrow."
63
+ ]
64
+ )
65
+
66
+ if __name__ == "__main__":
67
+ interface.launch()