Sartc commited on
Commit
64d28a7
·
verified ·
1 Parent(s): 75ae50f

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +134 -0
model.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import lightning
5
+ from safetensors.torch import save_file
6
+
7
+ class Config:
8
+ vocab_size = 50304
9
+ n_epochs = 50
10
+ batch_size = 36
11
+ lr = 3e-4
12
+ wd = 1e-6
13
+ n_embed = 256
14
+ num_blocks = 12
15
+ num_heads = 12
16
+ head_size = n_embed//num_heads
17
+ context_len = 224
18
+ attn_dropout_val = 0.2
19
+ mha_dropout_val = 0.2
20
+ ffn_dropout_val = 0.2
21
+
22
+ class CausalAttentionHead(nn.Module):
23
+ def __init__(self, config):
24
+ super(CausalAttentionHead, self).__init__()
25
+ self.config = config
26
+
27
+ self.query = nn.Linear(config.n_embed, config.head_size, bias=False)
28
+ self.key = nn.Linear(config.n_embed, config.head_size, bias=False)
29
+ self.value = nn.Linear(config.n_embed, config.head_size, bias=False)
30
+ self.attn_drop = nn.Dropout(config.attn_dropout_val)
31
+ # mask for causal attention during training
32
+ self.register_buffer("mask", torch.tril(torch.ones(config.context_len, config.context_len)))
33
+
34
+ def forward(self, x):
35
+ bs, context_len, embed_dim = x.shape
36
+ q, k, v = self.query(x), self.key(x), self.value(x)
37
+ attn_filter = torch.divide(torch.bmm(q, k.transpose(1, 2)), self.config.head_size)
38
+ attn_filter = attn_filter.masked_fill(self.mask[:context_len, :context_len]==0, float("-inf"))
39
+ attn_weights = F.softmax(attn_filter, dim=-1)
40
+ attn_weights = self.attn_drop(attn_weights)
41
+ output = torch.bmm(attn_weights, v)
42
+ return output
43
+
44
+ class MultiHeadedAttention(nn.Module):
45
+ def __init__(self, config):
46
+ super(MultiHeadedAttention, self).__init__()
47
+ self.config = config
48
+ self.heads = nn.ModuleList(
49
+ [CausalAttentionHead(config) for _ in range(config.num_heads)]
50
+ )
51
+ self.proj = nn.Linear(config.num_heads*config.head_size, config.n_embed)
52
+ self.mha_drop = nn.Dropout(config.mha_dropout_val)
53
+
54
+ def forward(self, x):
55
+ mha_output = torch.cat([head(x) for head in self.heads], dim=-1)
56
+ return self.mha_drop(self.proj(mha_output))
57
+
58
+ class FeedForwardNetwork(nn.Module):
59
+ def __init__(self, config):
60
+ super(FeedForwardNetwork, self).__init__()
61
+
62
+ self.ffn = nn.Sequential(
63
+ nn.Linear(config.n_embed, config.n_embed*4),
64
+ nn.GELU(),
65
+ nn.Linear(config.n_embed*4, config.n_embed),
66
+ nn.Dropout()
67
+ )
68
+ def forward(self, x):
69
+ return self.ffn(x)
70
+
71
+ class Block(nn.Module):
72
+ def __init__(self, config):
73
+ super(Block, self).__init__()
74
+ self.mha = MultiHeadedAttention(config)
75
+ self.ln1 = nn.LayerNorm(config.n_embed)
76
+ self.ffn = FeedForwardNetwork(config)
77
+ self.ln2 = nn.LayerNorm(config.n_embed)
78
+
79
+ def forward(self, x):
80
+ x = self.ln1(x+self.mha(x))
81
+ x = self.ln2(x+self.ffn(x))
82
+ return x
83
+
84
+ class GPT(lightning.LightningModule):
85
+ def __init__(self, config):
86
+ super(GPT, self).__init__()
87
+ self.config = config
88
+ self.save_hyperparameters()
89
+ self.token_embedding = nn.Embedding(config.vocab_size, config.n_embed)
90
+ self.positional_embedding = nn.Embedding(config.context_len, config.n_embed)
91
+ self.backbone = nn.Sequential(*[Block(config) for _ in range(config.num_blocks)])
92
+ self.lm_head = nn.Linear(config.n_embed, config.vocab_size)
93
+
94
+ def forward(self, x):
95
+ tok_emb = self.token_embedding(x)
96
+ pos_emb = self.positional_embedding(torch.arange(x.shape[1], device=self.device))
97
+ x = tok_emb+pos_emb
98
+ x = self.backbone(x)
99
+ logits = self.lm_head(x)
100
+ return logits
101
+
102
+ def get_loss(self, predictions, target):
103
+ B, C, V = predictions.shape
104
+ predictions = predictions.view(B*C, V)
105
+ target = target.view(B*C)
106
+ loss = F.cross_entropy(predictions, target)
107
+ return loss
108
+
109
+ def training_step(self, batch, batch_idx):
110
+ text, target = batch
111
+ text = text.long()
112
+ target = target.long()
113
+ logits = self(text)
114
+ loss = self.get_loss(logits, target)
115
+
116
+ self.log('loss', loss.item(), prog_bar=True)
117
+ logs = {'loss': loss}
118
+
119
+ return {"log": logs, "loss": loss}
120
+
121
+ def training_end(self, outputs):
122
+ avg_loss = torch.stack([x['log']['loss'] for x in outputs]).mean()
123
+ logs = {"log": avg_loss}
124
+ print(f"val_loss: {avg_loss}")
125
+ return {"log": logs}
126
+
127
+ def configure_optimizers(self):
128
+ opt = torch.optim.AdamW(self.parameters(), lr=self.config.lr, weight_decay=self.config.wd)
129
+ return [opt], []
130
+
131
+ if __name__ == "__main__":
132
+ config = Config()
133
+ gpt = GPT(config)
134
+ save_file(gpt, "storyGPT.safetensors")