import numpy as np
import torch
from nltk.tokenize import sent_tokenize
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
)
class Detector:
    def __init__(self, model_name: str, device: str):
        # The head size depends on the checkpoint variant: binary
        # "classification" checkpoints use 2 labels, "multi-dimension"
        # checkpoints use 3, and regression-style checkpoints use 1.
        if "classification" in model_name:
            num_labels = 2
        elif "multi-dimension" in model_name:
            num_labels = 3
        else:
            num_labels = 1
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = device
        self.model.to(device)
        self.model.eval()
    @torch.no_grad()
    def __call__(self, text, preprocess=True, threshold=None):
        """Score each sentence of `text`.

        Returns a list of (sentence, score) pairs. If `threshold` is given,
        each score is replaced by the boolean `score > threshold`.
        """
        if preprocess:
            # Split the raw text into sentences and rejoin them with the
            # separator token so each sentence ends at a </s> position.
            sents = sent_tokenize(text)
            text = " </s> ".join(sents)
        else:
            # The caller has already inserted " </s> " separators.
            sents = text.split(" </s> ")
        input_ids = self.tokenizer(text, max_length=2048, truncation=True)["input_ids"]
        # Token id 2 is </s> for RoBERTa-style tokenizers; its positions mark
        # the sentence boundaries whose logits are read out below.
        sent_label_idx = [i for i, ids in enumerate(input_ids) if ids == 2]
        tensor_input = torch.tensor([input_ids]).to(self.device)
        outputs = self.model(tensor_input).logits.detach().cpu().numpy()
        # Keep only the logits at the sentence-boundary positions.
        outputs_logits: np.ndarray = outputs[0][sent_label_idx]
        if outputs_logits.shape[1] == 2:
            # Binary head: keep the positive-class logit.
            outputs_logits = outputs_logits[:, 1]
        elif outputs_logits.shape[1] == 3:
            # Multi-dimension head: average the three dimensions.
            outputs_logits = outputs_logits.mean(axis=-1)
        outputs_logits = outputs_logits.flatten()
        if threshold is None:
            return list(zip(sents, outputs_logits.tolist()))
        else:
            return list(zip(sents, (outputs_logits > threshold).tolist()))
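
# Usage sketch (hedged: the checkpoint name below is a hypothetical
# placeholder, not a checkpoint this repo necessarily ships, and
# sent_tokenize requires NLTK's "punkt" data to be downloaded first):
#
#     import nltk
#     nltk.download("punkt")
#
#     detector = Detector("your-org/your-classification-checkpoint", "cpu")
#     # Raw per-sentence scores:
#     pairs = detector("First sentence. Second sentence.")
#     # Boolean per-sentence decisions:
#     flags = detector("First sentence. Second sentence.", threshold=0.5)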