File size: 1,824 Bytes
22e1b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import Any
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
)
from nltk.tokenize import sent_tokenize
import torch
import numpy as np


class Detector:
    """Score sentences of a text with a token-classification model.

    Sentences are joined with ``" </s> "`` separators. The model's logits are
    read at every position whose token id equals ``SENT_TOKEN_ID``, giving one
    score per sentence.
    """

    # Token id whose positions carry the per-sentence logits. In RoBERTa /
    # Longformer-style vocabularies id 2 is ``</s>``, matching the separator
    # inserted in ``__call__``. This is presumably the intended backbone;
    # confirm before pairing this class with a tokenizer whose vocabulary differs.
    SENT_TOKEN_ID = 2

    def __init__(self, model_name, device):
        """Load the model and tokenizer named ``model_name`` and move the model to ``device``.

        The classification head size is inferred from the name: 2 labels if it
        contains "classification", 3 if it contains "multi-dimension",
        otherwise 1.
        """
        if "classification" in model_name:
            num_labels = 2
        elif "multi-dimension" in model_name:
            num_labels = 3
        else:
            num_labels = 1
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = device

        self.model.to(device)
        # Switch to evaluation mode; this object is only used for inference.
        self.model.eval()

    @torch.no_grad()
    def __call__(self, text, preprocess=True, threshold=None, max_length=2048):
        """Score each sentence of ``text``.

        Args:
            text: Raw text when ``preprocess`` is True; otherwise sentences
                already joined with ``" </s> "``.
            preprocess: If True, split ``text`` with NLTK ``sent_tokenize`` and
                re-join the pieces with ``" </s> "``. If False, recover the
                sentences by splitting ``text`` on ``" </s> "``.
            threshold: If None, return raw scores. Otherwise return whether
                each score is strictly greater than ``threshold``.
            max_length: Token limit passed to the tokenizer (truncation is
                always enabled). Defaults to 2048.

        Returns:
            A list of ``(sentence, score)`` tuples, or ``(sentence, bool)``
            tuples when ``threshold`` is given. Each score is the model logit
            read at a ``SENT_TOKEN_ID`` position, reduced as follows:
              - 2-label heads: the second column is used.
              - 3-label heads: the columns are averaged.
              - otherwise: the single value is used.
            Sentences and marker positions are paired with ``zip``, so the list
            is as long as the shorter of the two. For example, sentences whose
            markers were removed by truncation get no score.
        """
        if preprocess:
            sents = sent_tokenize(text)
            text = " </s> ".join(sents)
        else:
            sents = text.split(" </s> ")
        input_ids = self.tokenizer(text, max_length=max_length, truncation=True)[
            "input_ids"
        ]

        # Positions whose logits are taken as the per-sentence scores.
        sent_label_idx = [
            i for i, tok in enumerate(input_ids) if tok == self.SENT_TOKEN_ID
        ]

        tensor_input = torch.tensor([input_ids]).to(self.device)
        outputs = self.model(tensor_input).logits.detach().cpu().numpy()
        # Batch of one: keep only the rows at the marker positions -> (k, num_labels).
        outputs_logits: np.ndarray = outputs[0][sent_label_idx]

        if outputs_logits.shape[1] == 2:
            # Use the logit of label index 1.
            outputs_logits = outputs_logits[:, 1]
        elif outputs_logits.shape[1] == 3:
            # Collapse the three outputs into a single score.
            outputs_logits = outputs_logits.mean(axis=-1)
        outputs_logits = outputs_logits.flatten()

        if threshold is None:
            return list(zip(sents, outputs_logits.tolist()))
        return list(zip(sents, (outputs_logits > threshold).tolist()))