import torch
import pytorch_lightning as pl
import transformers as hf
import numpy as np


class LitModel(pl.LightningModule):
    ''' pytorch-lightning wrapper around a seq2seq model and its tokenizer '''

    def __init__(self, model, tokenizer, learning_rate = 5e-5):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate

    def freeze_embeds(self):
        ''' freeze the shared, positional and token embedding parameters of the model '''
        freeze_params(self.model.model.shared)
        for stack in (self.model.model.encoder, self.model.model.decoder):
            freeze_params(stack.embed_positions)
            freeze_params(stack.embed_tokens)

    def forward(self, input_ids, **kwargs):
        ''' delegate to the wrapped model '''
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        ''' Adam over all parameters of this module with the configured learning rate '''
        return torch.optim.Adam(self.parameters(), lr = self.learning_rate)

    def _compute_loss(self, batch):
        ''' shared train/validation logic: returns the cross-entropy loss for one batch

        batch is indexed positionally: batch[0] source ids, batch[1] source
        attention mask, batch[2] target ids (the order produced by the
        TensorDatasets in SummaryDataModule).
        '''
        src_ids, src_mask, target_ids = batch[0], batch[1], batch[2]

        # use the tokenizer owned by this module, not a module-level global,
        # so the pad id always matches the model being trained
        pad_token_id = self.tokenizer.pad_token_id

        # shift the decoder tokens right
        decoder_input_ids = shift_tokens_right(target_ids, pad_token_id)

        # run the model and get the logits
        outputs = self(
            src_ids,
            attention_mask = src_mask,
            decoder_input_ids = decoder_input_ids,
            use_cache = False
        )
        logits = outputs[0]

        # calculate the loss on the unshifted tokens, ignoring padding positions
        return torch.nn.functional.cross_entropy(
            logits.view(-1, logits.shape[-1]),
            target_ids.view(-1),
            ignore_index = pad_token_id
        )

    def training_step(self, batch, batch_idx):
        ''' compute the training loss for one batch '''
        return {'loss': self._compute_loss(batch)}

    def validation_step(self, batch, batch_idx):
        ''' compute and log the validation loss for one batch (logged under the key 'loss') '''
        loss = self._compute_loss(batch)
        self.log('loss', loss)
        return {'loss': loss}

    def generate(self, text, min_length = 40, max_length = 256, eval_beams = 4, early_stopping = True):
        ''' generate text

        text must provide 'input_ids' and 'attention_mask' entries (e.g. the
        output of encode_sentences). Returns a list with one decoded string
        per generated sequence.
        '''
        generated = self.model.generate(
            text['input_ids'],
            attention_mask = text['attention_mask'],
            use_cache = True,
            # NOTE(review): decoding is started from the pad token; confirm this
            # matches the decoder start token the checkpoint was trained with
            decoder_start_token_id = self.tokenizer.pad_token_id,
            min_length = min_length,
            max_length = max_length,
            num_beams = eval_beams,
            early_stopping = early_stopping
        )
        return [
            self.tokenizer.decode(
                w,
                skip_special_tokens = True,
                clean_up_tokenization_spaces = True
            )
            for w in generated
        ]


def freeze_params(model):
    ''' freeze layers of model for faster training

    Sets requires_grad to False on every parameter yielded by
    model.parameters().
    '''
    for param in model.parameters():
        param.requires_grad = False


class SummaryDataModule(pl.LightningDataModule):
    ''' pytorch-lightning dataloading module

    dataframe is expected to expose 'source' and 'target' columns (these are
    the columns read in setup).
    '''

    def __init__(self, tokenizer, dataframe, batch_size, num_examples = 20000):
        super().__init__()
        self.tokenizer = tokenizer
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.num_examples = num_examples
        # True once the splits have been replaced by their encoded tensors
        self._encoded = False

    def prepare_data(self, split = (0.6, 0.2, 0.2)):
        ''' loads and splits data

        Keeps the first num_examples rows, shuffles them, and cuts them into
        train / validate / test parts. Only split[0] and split[1] are read;
        the test part is whatever remains.
        '''
        self.data = self.dataframe[:self.num_examples]
        total = len(self.data)
        train_end = int(split[0] * total)
        validate_end = int((split[0] + split[1]) * total)

        shuffled = self.data.sample(frac = 1)
        self.train = shuffled.iloc[:train_end]
        self.validate = shuffled.iloc[train_end:validate_end]
        self.test = shuffled.iloc[validate_end:]
        # the splits are raw rows again, so setup has to encode them
        self._encoded = False

    def setup(self, stage = None):
        ''' tokenize the three splits in place

        Safe to call more than once: the splits are encoded only the first
        time after prepare_data, since the encoded dicts no longer carry the
        'source' / 'target' columns.
        '''
        if getattr(self, '_encoded', False):
            return
        self.train = encode_sentences(self.tokenizer, self.train['source'], self.train['target'])
        self.validate = encode_sentences(self.tokenizer, self.validate['source'], self.validate['target'])
        self.test = encode_sentences(self.tokenizer, self.test['source'], self.test['target'])
        self._encoded = True

    @staticmethod
    def _as_dataset(encoded):
        ''' wrap an encoded split as (input_ids, attention_mask, labels) tensors '''
        return torch.utils.data.TensorDataset(
            encoded['input_ids'],
            encoded['attention_mask'],
            encoded['labels']
        )

    def train_dataloader(self):
        ''' randomly sampled batches over the training split '''
        dataset = self._as_dataset(self.train)
        return torch.utils.data.DataLoader(
            dataset,
            sampler = torch.utils.data.RandomSampler(dataset),
            batch_size = self.batch_size
        )

    def val_dataloader(self):
        ''' in-order batches over the validation split '''
        return torch.utils.data.DataLoader(
            self._as_dataset(self.validate),
            batch_size = self.batch_size
        )

    def test_dataloader(self):
        ''' in-order batches over the test split '''
        return torch.utils.data.DataLoader(
            self._as_dataset(self.test),
            batch_size = self.batch_size
        )


def shift_tokens_right(input_ids, pad_token_id):
    ''' shift each row one position to the right, wrapping the last non-pad token to the front

    input_ids is indexed as a 2-D (row, position) tensor. The token at index
    (number of non-pad tokens - 1) of each row is copied to position 0, and
    positions 1.. receive the original positions 0..-2. This picks the final
    real token only when padding is on the right -- TODO confirm for callers
    that pad on the left.
    '''
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim = 1) - 1).unsqueeze(-1)
    # squeeze only the gathered column so a batch of one stays 1-D
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze(-1)
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens


def _tokenize_each(tokenizer, sentences, max_length, pad_to_max_length, return_tensors):
    ''' tokenize sentences one at a time and return the list of per-sentence encodings '''
    return [
        tokenizer(
            s,
            max_length = max_length,
            # False (not None) is the tokenizer's "no padding" value
            padding = 'max_length' if pad_to_max_length else False,
            truncation = True,
            return_tensors = return_tensors,
            add_prefix_space = True
        )
        for s in sentences
    ]


def encode_sentences(tokenizer, source_sentences, target_sentences, max_length = 128, pad_to_max_length = True, return_tensors = 'pt'):
    ''' tokenize source and target sentences into a batch dict

    Returns a dict with 'input_ids' and 'attention_mask' built from
    source_sentences and 'labels' built from target_sentences, each
    concatenated along dim 0. Concatenation needs per-sentence tensors of
    equal length, so pad_to_max_length = False only works when all sentences
    tokenize to the same length; an empty sentence collection makes torch.cat
    raise.
    '''
    source_encodings = _tokenize_each(
        tokenizer, source_sentences, max_length, pad_to_max_length, return_tensors
    )
    target_encodings = _tokenize_each(
        tokenizer, target_sentences, max_length, pad_to_max_length, return_tensors
    )

    return {
        'input_ids': torch.cat([e['input_ids'] for e in source_encodings], dim = 0),
        'attention_mask': torch.cat([e['attention_mask'] for e in source_encodings], dim = 0),
        'labels': torch.cat([e['input_ids'] for e in target_encodings], dim = 0)
    }


# NOTE: these run at import time and load the pretrained checkpoint
tokenizer = hf.BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6', add_prefix_space = True)
base_model = hf.BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')