Create model.py
model.py
ADDED
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning
from safetensors.torch import save_file

class Config:
    vocab_size = 50304
    n_epochs = 50
    batch_size = 36
    lr = 3e-4
    wd = 1e-6
    n_embed = 256
    num_blocks = 12
    num_heads = 12
    head_size = n_embed//num_heads
    context_len = 224
    attn_dropout_val = 0.2
    mha_dropout_val = 0.2
    ffn_dropout_val = 0.2

class CausalAttentionHead(nn.Module):
    def __init__(self, config):
        super(CausalAttentionHead, self).__init__()
        self.config = config

        self.query = nn.Linear(config.n_embed, config.head_size, bias=False)
        self.key = nn.Linear(config.n_embed, config.head_size, bias=False)
        self.value = nn.Linear(config.n_embed, config.head_size, bias=False)
        self.attn_drop = nn.Dropout(config.attn_dropout_val)
        # mask for causal attention during training
        self.register_buffer("mask", torch.tril(torch.ones(config.context_len, config.context_len)))

    def forward(self, x):
        bs, context_len, embed_dim = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        # scaled dot-product attention: scale the scores by sqrt(head_size)
        attn_filter = torch.divide(torch.bmm(q, k.transpose(1, 2)), self.config.head_size**0.5)
        attn_filter = attn_filter.masked_fill(self.mask[:context_len, :context_len]==0, float("-inf"))
        attn_weights = F.softmax(attn_filter, dim=-1)
        attn_weights = self.attn_drop(attn_weights)
        output = torch.bmm(attn_weights, v)
        return output

class MultiHeadedAttention(nn.Module):
    def __init__(self, config):
        super(MultiHeadedAttention, self).__init__()
        self.config = config
        self.heads = nn.ModuleList(
            [CausalAttentionHead(config) for _ in range(config.num_heads)]
        )
        self.proj = nn.Linear(config.num_heads*config.head_size, config.n_embed)
        self.mha_drop = nn.Dropout(config.mha_dropout_val)

    def forward(self, x):
        # concatenate the per-head outputs and project back to the embedding size
        mha_output = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.mha_drop(self.proj(mha_output))

class FeedForwardNetwork(nn.Module):
    def __init__(self, config):
        super(FeedForwardNetwork, self).__init__()

        self.ffn = nn.Sequential(
            nn.Linear(config.n_embed, config.n_embed*4),
            nn.GELU(),
            nn.Linear(config.n_embed*4, config.n_embed),
            nn.Dropout(config.ffn_dropout_val)
        )

    def forward(self, x):
        return self.ffn(x)

class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()
        self.mha = MultiHeadedAttention(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ffn = FeedForwardNetwork(config)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        x = self.ln1(x+self.mha(x))
        x = self.ln2(x+self.ffn(x))
        return x

class GPT(lightning.LightningModule):
    def __init__(self, config):
        super(GPT, self).__init__()
        self.config = config
        self.save_hyperparameters()
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embed)
        self.positional_embedding = nn.Embedding(config.context_len, config.n_embed)
        self.backbone = nn.Sequential(*[Block(config) for _ in range(config.num_blocks)])
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)
        # per-step losses collected in training_step and averaged at epoch end
        self.training_step_losses = []

    def forward(self, x):
        tok_emb = self.token_embedding(x)
        pos_emb = self.positional_embedding(torch.arange(x.shape[1], device=self.device))
        x = tok_emb+pos_emb
        x = self.backbone(x)
        logits = self.lm_head(x)
        return logits

    def get_loss(self, predictions, target):
        # flatten (batch, context, vocab) logits for token-level cross-entropy
        B, C, V = predictions.shape
        predictions = predictions.view(B*C, V)
        target = target.view(B*C)
        loss = F.cross_entropy(predictions, target)
        return loss

    def training_step(self, batch, batch_idx):
        text, target = batch
        text = text.long()
        target = target.long()
        logits = self(text)
        loss = self.get_loss(logits, target)

        self.log('loss', loss, prog_bar=True)
        self.training_step_losses.append(loss.detach())

        return loss

    def on_train_epoch_end(self):
        # average the per-step losses collected during the epoch
        avg_loss = torch.stack(self.training_step_losses).mean()
        print(f"avg_train_loss: {avg_loss}")
        self.training_step_losses.clear()

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.config.lr, weight_decay=self.config.wd)
        return [opt], []

if __name__ == "__main__":
    config = Config()
    gpt = GPT(config)
    # save_file expects a flat dict of tensors, so serialize the state_dict
    save_file(gpt.state_dict(), "storyGPT.safetensors")
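
As a quick sanity check, here is a minimal sketch (not part of the commit) of how model.py might be exercised: build the model, run a dummy forward pass, and reload the weights that the __main__ block writes with safetensors. The filename sanity_check.py and the random token batch are assumptions for illustration only.

# sanity_check.py (illustrative; assumes model.py and storyGPT.safetensors are in the working directory)
import torch
from safetensors.torch import load_file

from model import Config, GPT

config = Config()
gpt = GPT(config)

# dummy batch of token ids with shape (batch, context_len)
tokens = torch.randint(0, config.vocab_size, (2, config.context_len))
logits = gpt(tokens)
print(logits.shape)  # expected: torch.Size([2, 224, 50304])

# reload the weights written by model.py's __main__ block
state_dict = load_file("storyGPT.safetensors")
gpt.load_state_dict(state_dict)

Actual training would presumably pair the module with a lightning.Trainer and a DataLoader yielding (text, target) token batches, which this commit does not include.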