import torch
import pytorch_lightning as pl
import transformers as hf
import numpy as np
class LitModel(pl.LightningModule):
    ''' pytorch-lightning module wrapping a seq2seq model and its tokenizer for summarization '''
    def __init__(self, model, tokenizer, learning_rate = 5e-5):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.learning_rate = learning_rate
    def freeze_embeds(self):
        ''' freeze the embedding parameters (shared, token and positional) for faster fine-tuning '''
        freeze_params(self.model.model.shared)
        for module in [self.model.model.encoder, self.model.model.decoder]:
            freeze_params(module.embed_positions)
            freeze_params(module.embed_tokens)
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
        return optimizer
    def training_step(self, batch, batch_idx):
        # load the data into variables
        src_ids, src_mask = batch[0], batch[1]
        target_ids = batch[2]
        # shift the decoder tokens right (BART uses the shifted targets as decoder input)
        decoder_input_ids = shift_tokens_right(target_ids, self.tokenizer.pad_token_id)
        # run the model and get the logits
        outputs = self(
            src_ids,
            attention_mask = src_mask,
            decoder_input_ids = decoder_input_ids,
            use_cache = False
        )
        logits = outputs[0]
        # create the loss function, ignoring positions that hold padding
        f_loss = torch.nn.CrossEntropyLoss(ignore_index = self.tokenizer.pad_token_id)
        # calculate the loss on the unshifted tokens
        loss = f_loss(logits.view(-1, logits.shape[-1]), target_ids.view(-1))
        return {'loss': loss}
    def validation_step(self, batch, batch_idx):
        src_ids, src_mask = batch[0], batch[1]
        target_ids = batch[2]
        decoder_input_ids = shift_tokens_right(target_ids, self.tokenizer.pad_token_id)
        outputs = self(
            src_ids,
            attention_mask = src_mask,
            decoder_input_ids = decoder_input_ids,
            use_cache = False
        )
        logits = outputs[0]
        f_loss = torch.nn.CrossEntropyLoss(ignore_index = self.tokenizer.pad_token_id)
        loss = f_loss(logits.view(-1, logits.shape[-1]), target_ids.view(-1))
        self.log('val_loss', loss)
        return {'loss': loss}
    def generate(self, text, min_length = 40, max_length = 256, eval_beams = 4, early_stopping = True):
        ''' generate a summary for tokenized input text (a dict with input_ids and attention_mask) using beam search '''
        generated = self.model.generate(
            text['input_ids'],
            attention_mask = text['attention_mask'],
            use_cache = True,
            decoder_start_token_id = self.tokenizer.pad_token_id,
            min_length = min_length,
            max_length = max_length,
            num_beams = eval_beams,
            early_stopping = early_stopping
        )
        return [self.tokenizer.decode(
            w,
            skip_special_tokens = True,
            clean_up_tokenization_spaces = True
        ) for w in generated]
def freeze_params(model):
    ''' freeze the parameters of a module so they are skipped during training '''
    for param in model.parameters():
        param.requires_grad = False
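# Illustrative sanity check (not part of the original Space): after freeze_params,
# no parameter of the module should require gradients. The tiny Embedding below is
# an arbitrary stand-in, chosen only so the check is self-contained.
_frozen_probe = torch.nn.Embedding(10, 4)
freeze_params(_frozen_probe)
assert all(not p.requires_grad for p in _frozen_probe.parameters())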
class SummaryDataModule(pl.LightningDataModule):
    ''' pytorch-lightning dataloading module '''
    def __init__(self, tokenizer, dataframe, batch_size, num_examples = 20000):
        super().__init__()
        self.tokenizer = tokenizer
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.num_examples = num_examples
    def prepare_data(self, split = (0.6, 0.2, 0.2)):
        ''' loads the first num_examples rows and splits them into train / validation / test '''
        self.data = self.dataframe[:self.num_examples]
        # shuffle, then cut at 60% and 80% of the data to get a 60/20/20 split
        self.train, self.validate, self.test = np.split(
            self.data.sample(frac = 1),
            [
                int(split[0] * len(self.data)),
                int((split[0] + split[1]) * len(self.data))
            ]
        )
    def setup(self, stage):
        ''' tokenize and encode the 'source' and 'target' columns of each split '''
        self.train = encode_sentences(self.tokenizer, self.train['source'], self.train['target'])
        self.validate = encode_sentences(self.tokenizer, self.validate['source'], self.validate['target'])
        self.test = encode_sentences(self.tokenizer, self.test['source'], self.test['target'])
    def train_dataloader(self):
        dataset = torch.utils.data.TensorDataset(
            self.train['input_ids'],
            self.train['attention_mask'],
            self.train['labels']
        )
        train_data = torch.utils.data.DataLoader(
            dataset,
            sampler = torch.utils.data.RandomSampler(dataset),
            batch_size = self.batch_size
        )
        return train_data
    def val_dataloader(self):
        dataset = torch.utils.data.TensorDataset(
            self.validate['input_ids'],
            self.validate['attention_mask'],
            self.validate['labels']
        )
        val_data = torch.utils.data.DataLoader(
            dataset,
            batch_size = self.batch_size
        )
        return val_data
    def test_dataloader(self):
        dataset = torch.utils.data.TensorDataset(
            self.test['input_ids'],
            self.test['attention_mask'],
            self.test['labels']
        )
        test_data = torch.utils.data.DataLoader(
            dataset,
            batch_size = self.batch_size
        )
        return test_data
def shift_tokens_right(input_ids, pad_token_id):
    ''' shift label tokens one position to the right to build BART decoder inputs '''
    prev_output_tokens = input_ids.clone()
    # index of the last non-pad token (the EOS token) in each row
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim = 1) - 1).unsqueeze(-1)
    # place EOS at position 0 and shift everything else one step to the right
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens
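# Illustrative example (hand-made ids, not from the original Space): with BART's
# special tokens <s> = 0, </s> = 2 and <pad> = 1, the EOS token is moved to the
# front and the remaining tokens shift one step to the right.
_example_ids = torch.tensor([[0, 8774, 232, 2, 1, 1]])
assert torch.equal(
    shift_tokens_right(_example_ids, pad_token_id = 1),
    torch.tensor([[2, 0, 8774, 232, 2, 1]])
)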
def encode_sentences(tokenizer, source_sentences, target_sentences, max_length = 128, pad_to_max_length = True, return_tensors = 'pt'):
    ''' tokenize source and target sentences into padded id tensors and attention masks '''
    input_ids = []
    attention_masks = []
    target_ids = []
    for s in source_sentences:
        encoded_dict = tokenizer(
            s,
            max_length = max_length,
            padding = 'max_length' if pad_to_max_length else False,
            truncation = True,
            return_tensors = return_tensors,
            add_prefix_space = True
        )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])
    input_ids = torch.cat(input_ids, dim = 0)
    attention_masks = torch.cat(attention_masks, dim = 0)
    for s in target_sentences:
        encoded_dict = tokenizer(
            s,
            max_length = max_length,
            padding = 'max_length' if pad_to_max_length else False,
            truncation = True,
            return_tensors = return_tensors,
            add_prefix_space = True
        )
        target_ids.append(encoded_dict['input_ids'])
    target_ids = torch.cat(target_ids, dim = 0)
    batch = {
        'input_ids': input_ids,
        'attention_mask': attention_masks,
        'labels': target_ids
    }
    return batch
tokenizer = hf.BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6', add_prefix_space = True)
base_model = hf.BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')
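# Illustrative end-to-end sketch (commented out; not part of the original Space).
# It assumes a hypothetical pandas DataFrame with 'source' and 'target' text columns,
# e.g. loaded from a CSV called 'summaries.csv':
#
# import pandas as pd
# df = pd.read_csv('summaries.csv')
# summary_data = SummaryDataModule(tokenizer, df, batch_size = 8)
# model = LitModel(base_model, tokenizer)
# model.freeze_embeds()
# trainer = pl.Trainer(max_epochs = 1)
# trainer.fit(model, datamodule = summary_data)
#
# # summarize a new document with beam search
# inputs = tokenizer('Some long article text ...', return_tensors = 'pt')
# print(model.generate(inputs)[0])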