File size: 1,417 Bytes
2086153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Choose the compute device up front: GPU when CUDA is usable, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and classifier are both loaded from the same local directory.
tokenizer = BertTokenizer.from_pretrained('./models/pretrained')
model = BertForSequenceClassification.from_pretrained('./models/pretrained')

# Move the weights to the chosen device and switch to evaluation mode,
# since this script only runs inference.
model.to(device)
model.eval()

def model_predict(text: str):
    """Classify ``text`` with the module-level tokenizer and model.

    The text is tokenized (truncation and padding enabled), moved to the
    module-level ``device``, and passed through ``model`` with gradient
    tracking disabled. The index of the largest logit along dimension 1 is
    taken as the predicted class.

    Args:
        text: The message to classify.

    Returns:
        ``'SPAM'`` when the predicted class index is 1, otherwise ``'HAM'``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(device)

    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        scores = model(**encoded).logits

    predicted_index = torch.argmax(scores, dim=1).item()
    if predicted_index == 1:
        return 'SPAM'
    return 'HAM'

def predict():
    """Run ``model_predict`` on a fixed set of sample messages and print results.

    One line is printed per sample, in listing order, formatted as
    ``"<n>. Predicted class: <label>"`` with ``n`` counting from 1.
    """
    # Sample messages; the trailing comment on each is the label the author
    # expects the model to produce (not checked at runtime).
    samples = (
        "Hello, do you know with this crypto you can be rich? contact us in 88888",  # EXPECT: SPAM
        "Help me richard!",  # EXPECT: HAM
        "You can buy loopstation for 100$, try buyloopstation.com",  # EXPECT: SPAM
        "Mate, I try to contact your phone, where are you?",  # EXPECT: HAM
    )

    # enumerate() supplies the running number that was previously hard-coded
    # into each print statement.
    for number, text in enumerate(samples, start=1):
        predicted_label = model_predict(text)
        print(f"{number}. Predicted class: {predicted_label}")

# Run the sample predictions only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    predict()