File size: 1,002 Bytes
a2682b3 a86dbdc a2682b3 a86dbdc a2682b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
class SummarizationModel:
    """Abstractive Portuguese summarization using a fine-tuned PTT5 model.

    Wraps a T5 encoder-decoder: the tokenizer comes from the base PTT5
    Portuguese vocabulary repo, while the weights come from a checkpoint
    fine-tuned for summarization (the two repos are intentionally different).
    """

    # Hard cap on input tokens; T5-base was trained with 512 positions.
    MAX_INPUT_TOKENS = 512

    def __init__(self):
        # Prefer GPU when available; the model and all inputs share this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = T5Tokenizer.from_pretrained(
            'unicamp-dl/ptt5-base-portuguese-vocab'
        )
        self.model = T5ForConditionalGeneration.from_pretrained(
            'recogna-nlp/ptt5-base-summ'
        ).to(self.device)

    def summarize(self, text: str, max_length: int = 256, min_length: int = 128) -> str:
        """Return an abstractive summary of *text*.

        Args:
            text: Source document (Portuguese). Inputs longer than 512
                tokens are truncated.
            max_length: Upper bound on generated summary length, in tokens.
            min_length: Lower bound on generated summary length, in tokens.

        Returns:
            The decoded summary string, with special tokens removed.
        """
        inputs = self.tokenizer.encode(
            text,
            max_length=self.MAX_INPUT_TOKENS,
            truncation=True,
            return_tensors='pt',
        ).to(self.device)
        # inference_mode: generation needs no autograd graph; without it,
        # beam search retains activations and wastes memory/time.
        with torch.inference_mode():
            summary_ids = self.model.generate(
                inputs,
                max_length=max_length,
                min_length=min_length,
                num_beams=4,
                no_repeat_ngram_size=3,  # block verbatim 3-gram repetition
                early_stopping=True,     # stop beams once num_beams finished
            )
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)