# summarizer.py — uploaded by shoukaku (commit e319ff3)
import torch
import pytorch_lightning as pl
import transformers as hf
import numpy as np
class LitModel(pl.LightningModule):
    ''' pytorch-lightning wrapper around a seq2seq summarization model

    model: Hugging Face-style encoder-decoder; freeze_embeds() expects it to
        expose model.model.shared / .encoder / .decoder (BART-like layout).
    tokenizer: tokenizer matching the model; its pad_token_id is used to build
        decoder inputs, to mask padding out of the loss and (by default) as the
        decoder start token in generate().
    learning_rate: learning rate for the Adam optimizer.
    '''

    def __init__(self, model, tokenizer, learning_rate = 5e-5):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate

    def freeze_embeds(self):
        ''' freeze the shared embedding plus the token and positional
        embeddings of both the encoder and the decoder '''
        freeze_params(self.model.model.shared)
        for stack in (self.model.model.encoder, self.model.model.decoder):
            freeze_params(stack.embed_positions)
            freeze_params(stack.embed_tokens)

    def forward(self, input_ids, **kwargs):
        ''' run the wrapped model; keyword arguments are passed straight through '''
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        ''' Adam over all module parameters at self.learning_rate '''
        return torch.optim.Adam(self.parameters(), lr = self.learning_rate)

    def _compute_loss(self, batch):
        ''' teacher-forced cross-entropy loss for one batch

        batch: sequence whose first three items are (src_ids, src_mask,
            target_ids), as yielded by the SummaryDataModule dataloaders.
        '''
        src_ids, src_mask, target_ids = batch[0], batch[1], batch[2]
        # use the tokenizer held by this module, not a module-level global
        pad_token_id = self.tokenizer.pad_token_id
        # the decoder is fed the targets shifted one position to the right
        decoder_input_ids = shift_tokens_right(target_ids, pad_token_id)
        outputs = self(
            src_ids,
            attention_mask = src_mask,
            decoder_input_ids = decoder_input_ids,
            use_cache = False
        )
        logits = outputs[0]
        # score against the unshifted targets; pad positions do not count
        return torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            target_ids.reshape(-1),
            ignore_index = pad_token_id
        )

    def training_step(self, batch, batch_idx):
        ''' compute the training loss for one batch '''
        loss = self._compute_loss(batch)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        ''' compute and log the validation loss for one batch '''
        loss = self._compute_loss(batch)
        # NOTE(review): logged under 'loss' (unchanged key, in case something
        # monitors it); a distinct name such as 'val_loss' would be clearer
        self.log('loss', loss)
        return {'loss': loss}

    def generate(self, text, min_length = 40, max_length = 256, eval_beams = 4, early_stopping = True, decoder_start_token_id = None):
        ''' generate text with beam search and decode it to strings

        text: tokenizer output providing 'input_ids' and 'attention_mask'
            (e.g. the tokenizer called with return_tensors='pt').
        min_length / max_length: length bounds passed to model.generate.
        eval_beams: number of beams.
        early_stopping: passed to model.generate.
        decoder_start_token_id: first decoder token; None keeps the previous
            behaviour of using the tokenizer's pad_token_id.

        Returns one decoded string (special tokens removed) per generated
        sequence.
        '''
        if decoder_start_token_id is None:
            # NOTE(review): training (shift_tokens_right) starts the decoder
            # with each target's last non-pad token, not pad — confirm pad is
            # the intended start token here
            decoder_start_token_id = self.tokenizer.pad_token_id
        generated = self.model.generate(
            text['input_ids'],
            attention_mask = text['attention_mask'],
            use_cache = True,
            decoder_start_token_id = decoder_start_token_id,
            min_length = min_length,
            max_length = max_length,
            num_beams = eval_beams,
            early_stopping = early_stopping
        )
        return [
            self.tokenizer.decode(
                sequence,
                skip_special_tokens = True,
                clean_up_tokenization_spaces = True
            )
            for sequence in generated
        ]
def freeze_params(model):
    ''' freeze every parameter of model for faster training

    Sets requires_grad = False on each parameter so no gradients are
    computed for it.
    '''
    for param in model.parameters():
        param.requires_grad = False
class SummaryDataModule(pl.LightningDataModule):
    ''' pytorch-lightning data module for the summarization dataframe

    Takes the first num_examples rows of a dataframe with 'source' and
    'target' text columns, shuffles and splits them into train / validation /
    test sets, tokenizes each set with encode_sentences() and serves them via
    TensorDataset-backed dataloaders.
    '''

    def __init__(self, tokenizer, dataframe, batch_size, num_examples = 20000):
        super().__init__()
        self.tokenizer = tokenizer
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.num_examples = num_examples

    def prepare_data(self, split = (0.6, 0.2, 0.2), random_state = None):
        ''' shuffle the first num_examples rows and split them

        split: (train, validation, test) fractions. Only the first two are
            read; the test set receives every remaining row.
        random_state: optional seed forwarded to DataFrame.sample; the default
            None keeps the shuffle (and therefore the split) unseeded.
        '''
        # NOTE(review): Lightning documents prepare_data as running on a single
        # process, so attributes set here may be missing on other ranks in
        # distributed runs — consider moving the split into setup()
        self.data = self.dataframe[:self.num_examples]
        n_rows = len(self.data)
        train_end = int(split[0] * n_rows)
        validate_end = int((split[0] + split[1]) * n_rows)
        shuffled = self.data.sample(frac = 1, random_state = random_state)
        self.train = shuffled.iloc[:train_end]
        self.validate = shuffled.iloc[train_end:validate_end]
        self.test = shuffled.iloc[validate_end:]

    def setup(self, stage = None):
        ''' tokenize the train / validation / test splits

        Lightning can call setup() more than once (e.g. for 'fit' and later
        for 'test'). Splits that are already encoded are kept as they are;
        re-encoding them would fail because the encoded dicts have no
        'source' / 'target' columns.
        '''
        self.train = self._encode_split(self.train)
        self.validate = self._encode_split(self.validate)
        self.test = self._encode_split(self.test)

    def _encode_split(self, split):
        ''' tokenize one dataframe split; return an already-encoded dict unchanged '''
        if isinstance(split, dict):
            return split
        return encode_sentences(self.tokenizer, split['source'], split['target'])

    def _make_loader(self, encoded, shuffle = False):
        ''' build a DataLoader yielding (input_ids, attention_mask, labels) batches '''
        dataset = torch.utils.data.TensorDataset(
            encoded['input_ids'],
            encoded['attention_mask'],
            encoded['labels']
        )
        sampler = torch.utils.data.RandomSampler(dataset) if shuffle else None
        return torch.utils.data.DataLoader(
            dataset,
            sampler = sampler,
            batch_size = self.batch_size
        )

    def train_dataloader(self):
        ''' training batches in random order '''
        return self._make_loader(self.train, shuffle = True)

    def val_dataloader(self):
        ''' validation batches in dataset order '''
        return self._make_loader(self.validate)

    def test_dataloader(self):
        ''' test batches in dataset order '''
        return self._make_loader(self.test)
def shift_tokens_right(input_ids, pad_token_id):
    ''' build decoder inputs by shifting a 2-D batch of token ids one step right

    Every row moves right by one position (its final column is dropped) and
    position 0 is filled with the token at index (number of non-pad tokens - 1).
    That is the row's last real token only when padding sits at the end of
    the row — assumed right-padded input, TODO confirm for all callers.
    input_ids itself is left unmodified; a new tensor is returned.
    '''
    # per-row index of the token to wrap around, shaped (batch, 1) for gather
    last_real = input_ids.ne(pad_token_id).sum(dim = 1).sub(1).unsqueeze(-1)
    wrapped = input_ids.gather(1, last_real).squeeze()

    shifted = input_ids.clone()
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = wrapped
    return shifted
def encode_sentences(tokenizer, source_sentences, target_sentences, max_length = 128, pad_to_max_length = True, return_tensors = 'pt'):
    ''' tokenize parallel source / target texts into model-ready tensors

    tokenizer: callable tokenizer (Hugging Face-style) accepting a list of
        strings and returning a mapping with 'input_ids' / 'attention_mask'.
    source_sentences / target_sentences: iterables of strings (e.g. dataframe
        columns); each is tokenized in a single batched call.
    max_length: every sequence is truncated to at most this many tokens.
    pad_to_max_length: if True, pad every row to max_length; otherwise pad
        each column only to its own longest row, so rows can still be stacked.
    return_tensors: tensor type requested from the tokenizer.

    Returns a dict with 'input_ids' and 'attention_mask' from the source
    texts and 'labels' holding the target token ids.
    '''
    # 'longest' (not None, which is not a valid padding strategy) keeps rows
    # rectangular when not padding to max_length
    padding = 'max_length' if pad_to_max_length else 'longest'

    def _encode(sentences):
        # one batched tokenizer call per column instead of one call per sentence
        return tokenizer(
            list(sentences),
            max_length = max_length,
            padding = padding,
            truncation = True,
            return_tensors = return_tensors,
            add_prefix_space = True
        )

    source = _encode(source_sentences)
    target = _encode(target_sentences)
    return {
        'input_ids': source['input_ids'],
        'attention_mask': source['attention_mask'],
        'labels': target['input_ids']
    }
# pretrained checkpoint shared by the tokenizer and the model loaded below;
# both are loaded at import time of this module
_CHECKPOINT = 'sshleifer/distilbart-cnn-12-6'
tokenizer = hf.BartTokenizer.from_pretrained(_CHECKPOINT, add_prefix_space = True)
base_model = hf.BartForConditionalGeneration.from_pretrained(_CHECKPOINT)