File size: 1,002 Bytes
a2682b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a86dbdc
a2682b3
a86dbdc
a2682b3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

class SummarizationModel:
    """Text summarizer built on a T5 conditional-generation checkpoint.

    The tokenizer and the model weights are loaded once in the constructor
    and reused by every :meth:`summarize` call. The model is moved to CUDA
    when ``torch.cuda.is_available()`` is true, otherwise it stays on CPU.
    """

    # Inputs longer than this many tokens are truncated before generation.
    MAX_INPUT_TOKENS = 512

    def __init__(
        self,
        model_name: str = 'recogna-nlp/ptt5-base-summ',
        tokenizer_name: str = 'unicamp-dl/ptt5-base-portuguese-vocab',
    ):
        """Load the tokenizer and the model and place the model on a device.

        Args:
            model_name: Identifier passed to
                ``T5ForConditionalGeneration.from_pretrained``.
            tokenizer_name: Identifier passed to
                ``T5Tokenizer.from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)

    def summarize(self, text: str, max_length: int = 256, min_length: int = 128) -> str:
        """Generate a summary of *text* with beam search.

        Args:
            text: Source text. It is truncated to ``MAX_INPUT_TOKENS`` tokens.
            max_length: Forwarded to ``generate`` as ``max_length``.
            min_length: Forwarded to ``generate`` as ``min_length`` after
                being capped at ``max_length``.

        Returns:
            The decoded first generated sequence with special tokens
            skipped, or ``""`` when *text* is empty or whitespace only.
        """
        # Nothing to summarize: without this guard the model would be forced
        # to emit at least ``min_length`` tokens out of an empty input.
        if not text or not text.strip():
            return ""

        # The default min_length (128) would exceed any smaller max_length a
        # caller passes on its own; cap it so the constraints stay feasible.
        min_length = min(min_length, max_length)

        inputs = self.tokenizer.encode(
            text,
            max_length=self.MAX_INPUT_TOKENS,
            truncation=True,
            return_tensors='pt'
        ).to(self.device)

        summary_ids = self.model.generate(
            inputs,
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True,
        )

        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)