import json
import textwrap
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated; use torch's implementation
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)

class NewsSummaryModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Pretrained T5 encoder-decoder; return_dict=True gives named outputs.
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base", return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        # torch.optim.AdamW; the transformers copy of this optimizer has been removed.
        return AdamW(self.parameters(), lr=0.0001)
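

# A minimal sketch of the Dataset this LightningModule expects, assuming a
# pandas DataFrame with "text" and "summary" columns. The class name, column
# names, task prefix, and max token lengths are illustrative assumptions; all
# the module above actually requires is the four batch keys read in the
# *_step methods.
class NewsSummaryDataset(Dataset):
    def __init__(self, data, tokenizer, text_max_token_len=512, summary_max_token_len=128):
        self.data = data
        self.tokenizer = tokenizer
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        text_encoding = self.tokenizer(
            "summarize: " + row["text"],  # T5 uses a task prefix
            max_length=self.text_max_token_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        summary_encoding = self.tokenizer(
            row["summary"],
            max_length=self.summary_max_token_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        labels = summary_encoding["input_ids"].flatten()
        labels[labels == self.tokenizer.pad_token_id] = -100  # mask padding out of the loss
        return dict(
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            labels=labels,
            labels_attention_mask=summary_encoding["attention_mask"].flatten(),
        )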
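
# A minimal end-to-end training sketch wiring the module to the ModelCheckpoint
# and TensorBoardLogger imported above. The toy DataFrame, batch size, epoch
# count, and paths are illustrative assumptions; a real run would load a news
# dataset from disk instead.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = NewsSummaryModel()

# Toy data purely so the sketch runs end to end.
df = pd.DataFrame({
    "text": ["The quick brown fox jumps over the lazy dog."] * 8,
    "summary": ["Fox jumps over dog."] * 8,
})
train_df, val_df = df.iloc[:6], df.iloc[6:]

train_loader = DataLoader(NewsSummaryDataset(train_df, tokenizer), batch_size=2, shuffle=True)
val_loader = DataLoader(NewsSummaryDataset(val_df, tokenizer), batch_size=2)

# Keep the single best checkpoint by validation loss, matching the "val_loss"
# key logged in validation_step.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)
logger = TensorBoardLogger("lightning_logs", name="news-summary")

trainer = pl.Trainer(logger=logger, callbacks=[checkpoint_callback], max_epochs=3)
trainer.fit(model, train_loader, val_loader)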