import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scale the embeddings by sqrt(d_model), as in "Attention Is All You Need".
        return self.embedding(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Precompute the sinusoidal positional encodings once: (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(pos * div)  # even dimensions
        pe[:, 1::2] = torch.cos(pos * div)  # odd dimensions

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        self.register_buffer('pe', pe)  # saved with the model, but not a trainable parameter

    def forward(self, x):
        # Add the (non-trainable) positional encoding for the current sequence length.
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)


class LayerNormalization(nn.Module):
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))  # learnable scale (scalar, shared across features)
        self.bias = nn.Parameter(torch.zeros(1))  # learnable shift

    def forward(self, x):
        # Normalize over the last (feature) dimension.
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * ((x - mean) / (std + self.eps)) + self.bias


class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, d_ff) -> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))


class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.h = h

        assert d_model % h == 0, "d_model must be divisible by h"
        self.d_k = d_model // h

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, dropout: nn.Dropout):
        # Fused causal scaled dot-product attention. Attention dropout is only
        # applied while the module is in training mode.
        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=None,
            dropout_p=dropout.p if dropout is not None and dropout.training else 0.0,
            is_causal=True
        )
        # The fused kernel never materializes the attention weights, so None is
        # returned in their place.
        return attn_output, None

    def forward(self, q, k, v):
        query = self.w_q(q)  # (batch, seq_len, d_model)
        key = self.w_k(k)
        value = self.w_v(v)

        # Split into h heads: (batch, seq_len, d_model) -> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        x, attention_scores = MultiHeadAttentionBlock.attention(query, key, value, self.dropout)

        # Merge the heads back: (batch, h, seq_len, d_k) -> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        return self.w_o(x)


class ResidualConnection(nn.Module):
    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        # Pre-norm residual connection: x + dropout(sublayer(norm(x)))
        return x + self.dropout(sublayer(self.norm(x)))


class DecoderBlock(nn.Module):
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x


class Decoder(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size) log-probabilities
        return torch.log_softmax(self.proj(x), dim=-1)


class GPT(nn.Module):
    def __init__(self, decoder: Decoder, tgt_embed: InputEmbeddings, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer):
        super().__init__()
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def decode(self, tgt):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt)

    def project(self, x):
        return self.projection_layer(x)


def build_gpt(vocab_size: int, seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, d_ff: int = 2048, dropout: float = 0.3) -> GPT:
    tgt_embed = InputEmbeddings(d_model, vocab_size)
    tgt_pos = PositionalEncoding(d_model, seq_len, dropout)

    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    decoder = Decoder(nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, vocab_size)

    gpt = GPT(decoder, tgt_embed, tgt_pos, projection_layer)

    # Initialize all weight matrices with Xavier/Glorot uniform.
    for p in gpt.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return gpt


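# A minimal training-step sketch, not part of the original model code: since
# ProjectionLayer returns log-probabilities, the negative log-likelihood loss is
# the matching criterion. The next-token shift, the `batch` tensor of token ids,
# and the caller-supplied optimizer are illustrative assumptions.
def example_training_step(gpt: GPT, batch: torch.Tensor, optimizer: torch.optim.Optimizer) -> float:
    """Run one next-token-prediction step on a (batch, seq_len) tensor of token ids."""
    gpt.train()
    inputs, targets = batch[:, :-1], batch[:, 1:]  # predict token t+1 from tokens <= t
    log_probs = gpt.project(gpt.decode(inputs))    # (batch, seq_len - 1, vocab_size)
    loss = F.nll_loss(log_probs.reshape(-1, log_probs.size(-1)), targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

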
def colorize(value):
    """Return a string with a grayscale ANSI background for a value in [0, 1]."""
    brightness = int(232 + value * 23)  # 256-color grayscale ramp spans codes 232-255
    return f"\033[48;5;{brightness}m {value:.2f} \033[0m"


def print_attention(attention_scores):
    """Print per-head heatmaps for attention scores of shape (batch, heads, seq_len, seq_len)."""
    batch_size, num_heads, seq_len, _ = attention_scores.shape

    for head in range(num_heads):
        print(f"\nAttention Head {head + 1}:\n" + "-" * 30)

        attn = attention_scores[0, head]
        attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)  # normalize to [0, 1]

        for i in range(seq_len):
            for j in range(seq_len):
                print(colorize(attn[i, j].item()), end=" ")
            print()


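# The fused F.scaled_dot_product_attention call used by the model never exposes
# the attention matrix, so print_attention cannot be fed from
# MultiHeadAttentionBlock directly. The helper below is a sketch of the
# equivalent explicit computation, softmax(Q K^T / sqrt(d_k)) with a causal
# mask, that could produce scores for visualization; it is illustrative and not
# part of the original model.
def compute_attention_scores(query, key):
    """Return causal attention weights of shape (batch, h, seq_len, seq_len)."""
    d_k = query.shape[-1]
    scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
    causal_mask = torch.triu(
        torch.ones(scores.shape[-2], scores.shape[-1], dtype=torch.bool, device=query.device),
        diagonal=1,
    )
    scores = scores.masked_fill(causal_mask, float('-inf'))
    return scores.softmax(dim=-1)

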
if __name__ == "__main__":
    gpt = build_gpt(10000, 350)
    print(gpt)
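    # Quick shape check (a sketch: the token ids are random and the sequence
    # length of 16 is arbitrary, within the configured maximum of 350).
    tokens = torch.randint(0, 10000, (1, 16))
    log_probs = gpt.project(gpt.decode(tokens))
    print(log_probs.shape)  # expected: torch.Size([1, 16, 10000])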