mmchowdhury committed
Commit a7a9ad2 · verified · 1 Parent(s): a6befd9

Upload model.py

Files changed (1):
  1. model.py +78 -0
model.py ADDED
import json
import textwrap
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch.optim import AdamW  # transformers' AdamW is deprecated in favor of the PyTorch optimizer
from torch.utils.data import Dataset, DataLoader
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)
from tqdm.auto import tqdm


class NewsSummaryModel(pl.LightningModule):
    """LightningModule that fine-tunes a pretrained T5 model for news summarization."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base", return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        # T5 computes the cross-entropy loss internally when labels are provided.
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=1e-4)
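For context, the module expects batches keyed as text_input_ids, text_attention_mask, labels, and labels_attention_mask, but this commit does not include the data pipeline or training script. The sketch below is a hypothetical illustration of how the uploaded module could be trained; the NewsSummaryDataset class, its parameters, and the placeholder data are assumptions for illustration, not part of the upload.

import lightning as pl
from torch.utils.data import Dataset, DataLoader
from transformers import T5TokenizerFast as T5Tokenizer

from model import NewsSummaryModel  # the file uploaded in this commit


class NewsSummaryDataset(Dataset):
    """Hypothetical dataset of (article text, reference summary) pairs."""

    def __init__(self, pairs, tokenizer, text_max_len=512, summary_max_len=128):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.summary_max_len = summary_max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        text, summary = self.pairs[idx]
        enc = self.tokenizer(text, max_length=self.text_max_len,
                             padding="max_length", truncation=True, return_tensors="pt")
        dec = self.tokenizer(summary, max_length=self.summary_max_len,
                             padding="max_length", truncation=True, return_tensors="pt")
        labels = dec["input_ids"].squeeze(0).clone()
        labels[labels == self.tokenizer.pad_token_id] = -100  # mask padding out of the loss
        return {
            "text_input_ids": enc["input_ids"].squeeze(0),
            "text_attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels,
            "labels_attention_mask": dec["attention_mask"].squeeze(0),
        }


if __name__ == "__main__":
    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    pairs = [("Some article text ...", "Its one-line summary ...")]  # placeholder data
    loader = DataLoader(NewsSummaryDataset(pairs, tokenizer), batch_size=8, shuffle=True)

    model = NewsSummaryModel()
    trainer = pl.Trainer(max_epochs=3)
    trainer.fit(model, loader)

For inference after training, one would tokenize an article, call model.model.generate(...) on the resulting input_ids, and decode the generated token ids with the same tokenizer.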