Govi2020 commited on
Commit
c2e9676
·
verified ·
1 Parent(s): 0dd2eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -4
app.py CHANGED
@@ -1,6 +1,21 @@
1
- # Use a pipeline as a high-level helper
2
- from transformers import pipeline
3
 
4
- pipe = pipeline("text-classification", model="priyabrat/bert_categorisation")
 
5
 
6
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, BertForSequenceClassification
3
 
4
+ tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity")
5
+ model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-yelp-polarity")
6
 
7
+ inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
8
+
9
+ with torch.no_grad():
10
+ logits = model(**inputs).logits
11
+
12
+ predicted_class_id = logits.argmax().item()
13
+ model.config.id2label[predicted_class_id]
14
+
15
+ # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
16
+ num_labels = len(model.config.id2label)
17
+ model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-yelp-polarity", num_labels=num_labels)
18
+
19
+ labels = torch.tensor([1])
20
+ loss = model(**inputs, labels=labels).loss
21
+ round(loss.item(), 2)