# Spam/ham classifier inference demo (BERT sequence classification).
from transformers import BertTokenizer, BertForSequenceClassification
import torch
# Load tokenizer and fine-tuned classifier weights from the local checkpoint
# directory. NOTE(review): assumes './models/pretrained' exists relative to the
# working directory — confirm before deploying.
tokenizer = BertTokenizer.from_pretrained('./models/pretrained')
model = BertForSequenceClassification.from_pretrained('./models/pretrained')
# Prefer GPU when available; inference falls back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Disable dropout/batch-norm training behavior for deterministic inference.
model.eval()
def model_predict(text: str):
    """Classify *text* with the module-level BERT model.

    Returns 'SPAM' when the argmax class index is 1, otherwise 'HAM'.
    Runs under torch.no_grad() since no gradients are needed at inference.
    """
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True
    ).to(device)
    with torch.no_grad():
        logits = model(**encoded).logits
    # Class index 1 is the spam label in this checkpoint's head.
    label_idx = int(torch.argmax(logits, dim=1))
    return 'SPAM' if label_idx == 1 else 'HAM'
def predict():
    """Run smoke-test predictions over a few sample messages.

    Prints one numbered line per sample; output is identical to the original
    copy-pasted sequence of calls. The expected label for each sample is kept
    alongside the text for reviewer reference.
    """
    # (text, expected label) pairs — expectations are informational only.
    samples = [
        ("Hello, do you know with this crypto you can be rich? contact us in 88888", "SPAM"),
        ("Help me richard!", "HAM"),
        ("You can buy loopstation for 100$, try buyloopstation.com", "SPAM"),
        ("Mate, I try to contact your phone, where are you?", "HAM"),
    ]
    for i, (text, _expected) in enumerate(samples, start=1):
        predicted_label = model_predict(text)
        print(f"{i}. Predicted class: {predicted_label}")
if __name__ == "__main__":
predict() |