# 10M-LLM / model.py
# abancp — revision 82f9e44 (openwebbook2)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class InputEmbeddings(nn.Module):
    """Token-id -> vector lookup, scaled by sqrt(d_model).

    Args:
        d_model: width of each embedding vector.
        vocab_size: number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Look up the ids in ``x`` and multiply the resulting vectors by sqrt(d_model)."""
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings, then apply dropout.

    pe[pos, 2i]   = sin(pos / 10000 ** (2i / d_model))
    pe[pos, 2i+1] = cos(pos / 10000 ** (2i / d_model))

    Args:
        d_model: embedding width; may be odd (the last column then holds only a sine term).
        seq_len: number of positions precomputed; inputs longer than this cannot be encoded.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, seq_len: int, dropout):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # (seq_len, d_model) table of encodings.
        pe = torch.zeros(seq_len, d_model)
        # Column vector of positions, shape (seq_len, 1).
        pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000 ** (2i / d_model), computed in log space; ceil(d_model / 2) entries.
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Sine on the even columns (there are ceil(d_model / 2) of them).
        pe[:, 0::2] = torch.sin(pos * div)
        # Cosine on the odd columns. There are only floor(d_model / 2) of them, so trim
        # `div` to match; without the trim an odd d_model raises a shape mismatch.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        # (seq_len, d_model) -> (1, seq_len, d_model) so it broadcasts over the batch.
        pe = pe.unsqueeze(0)
        # Buffer: saved in state_dict and moved by .to(), but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.shape[1]`` position encodings to ``x`` and apply dropout.

        Assumes ``x`` is (batch, seq, d_model) with seq <= seq_len.
        """
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)
class LayerNormalization(nn.Module):
    """Normalize over the last dimension with a single learnable scale and shift.

    Note: ``alpha`` and ``bias`` each hold one scalar shared by every feature, and
    ``eps`` is added to the (unbiased) standard deviation rather than to the variance.
    """

    def __init__(self, eps=10**-6):
        super().__init__()
        self.eps = eps
        # Scalar multiplier (starts at 1) and scalar offset (starts at 0).
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return alpha * (x - mean) / (std + eps) + bias, statistics taken along the last axis."""
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * (centered / denom) + self.bias
class FeedForwardBlock(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Shapes: (..., d_model) -> (..., d_ff) -> (..., d_model).
    """

    def __init__(self, d_model: int, d_ff: int, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear_1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff (w1, b1)
        self.linear_2 = nn.Linear(d_ff, d_model)  # contract: d_ff -> d_model (w2, b2)

    def forward(self, x):
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head causal attention built on ``F.scaled_dot_product_attention``.

    Args:
        d_model: model width; must be divisible by ``h``.
        h: number of attention heads, each of width d_k = d_model // h.
        dropout: dropout probability on the attention weights (training mode only).
    """

    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.h = h
        # Each head gets an equal slice of the model width.
        assert d_model % h == 0, "d_model must divisible by h"
        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, dropout: nn.Dropout):
        """Causal scaled dot-product attention over per-head tensors.

        Args:
            query, key, value: per-head projections, shape (batch, h, seq, d_k).
            dropout: module whose ``p`` is used as attention-weight dropout. It is
                applied only while that module is in training mode; ``None`` disables it.

        Returns:
            ``(output, None)``. ``output`` has the shape of ``value``. The fused kernel
            does not expose the attention weights, so the second element is always None.
        """
        # F.scaled_dot_product_attention applies dropout_p whenever it is > 0 and has no
        # notion of train/eval mode. Gate on the Dropout module's own flag so that
        # model.eval() also disables attention dropout.
        dropout_p = dropout.p if dropout is not None and dropout.training else 0.0
        attn_output = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=None,
            dropout_p=dropout_p,
            is_causal=True,  # decoder-only: position i attends only to positions <= i
        )
        return attn_output, None

    def forward(self, q, k, v):
        """Project q/k/v, attend causally per head, merge heads, and project out.

        Inputs are (batch, seq, d_model); output is (batch, seq, d_model).
        """
        # (batch, seq, d_model) -> (batch, seq, d_model)
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
        # (batch, seq, d_model) -> (batch, seq, h, d_k) -> (batch, h, seq, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        x, _ = MultiHeadAttentionBlock.attention(query, key, value, self.dropout)
        # (batch, h, seq, d_k) -> (batch, seq, h, d_k) -> (batch, seq, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        # (batch, seq, d_model) -> (batch, seq, d_model)
        return self.w_o(x)
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalized input and add the result back onto ``x``."""
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, then feed-forward, each wrapped in a ResidualConnection."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # One residual wrapper per sub-layer: [0] attention, [1] feed-forward.
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout), ResidualConnection(dropout)])

    def forward(self, x):
        attn_residual, ff_residual = self.residual_connections

        def self_attend(t):
            # Self-attention: the same tensor supplies queries, keys and values.
            return self.self_attention_block(t, t, t)

        x = attn_residual(x, self_attend)
        return ff_residual(x, self.feed_forward_block)
class Decoder(nn.Module):
    """Run a stack of decoder layers in order, then apply a final LayerNormalization."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)
class ProjectionLayer(nn.Module):
    """Map hidden states to log-probabilities over the vocabulary."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """(..., d_model) -> (..., vocab_size) log-softmax scores along the last axis."""
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)
class GPT(nn.Module):
    """Decoder-only language model: embedding, positional encoding, decoder stack, output projection.

    ``decode`` produces hidden states from token ids; ``project`` turns them into vocabulary scores.
    """

    def __init__(self, decoder: Decoder, tgt_embed: InputEmbeddings, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer):
        super().__init__()
        # Registration order determines parameters() / state_dict() ordering; keep it stable.
        self.decoder = decoder
        self.tgt_embed = tgt_embed
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def decode(self, tgt):
        """Embed token ids, add positional encodings, and run them through the decoder."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded)

    def project(self, x):
        """Apply the output projection to decoder hidden states."""
        return self.projection_layer(x)
def build_gpt(vocab_size: int, seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, d_ff: int = 2048, dropout: float = 0.3) -> GPT:
    """Construct a GPT with N decoder layers and Xavier-uniform initialized weight matrices.

    Args:
        vocab_size: token vocabulary size (embedding rows and projection width).
        seq_len: number of positions in the positional-encoding table.
        d_model: model width.
        N: number of decoder layers.
        h: attention heads per layer.
        d_ff: hidden width of each feed-forward block.
        dropout: dropout probability used throughout.

    Returns:
        The assembled GPT module.
    """
    embedding = InputEmbeddings(d_model, vocab_size)
    positions = PositionalEncoding(d_model, seq_len, dropout)
    # Call arguments evaluate left to right, so submodules are created (and consume
    # the RNG) in the same order as building them one statement at a time.
    blocks = [
        DecoderBlock(
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]
    decoder = Decoder(nn.ModuleList(blocks))
    head = ProjectionLayer(d_model, vocab_size)
    model = GPT(decoder, embedding, positions, head)
    # Xavier-init only tensors with more than one dimension (weight matrices and the
    # embedding table); 1-D biases and the scalar norm parameters keep their defaults.
    for param in model.parameters():
        if param.dim() <= 1:
            continue
        nn.init.xavier_uniform_(param)
    return model
def colorize(value):
    """Return ``value`` formatted to two decimals on an ANSI 256-color grayscale background.

    Args:
        value: intensity expected in [0, 1]; 0 maps to shade 232 (darkest) and 1 to
            shade 255 (lightest) of the grayscale ramp.

    Returns:
        The escape sequence, the formatted number padded by spaces, and a reset code.
        Values outside [0, 1] (or NaN) are clamped for the colour only, so the escape
        code always stays inside the grayscale ramp; the printed number is unchanged.
    """
    # Clamp so out-of-range or NaN inputs cannot yield an invalid / non-grayscale code
    # (int(nan) would otherwise raise).
    shade = max(0.0, min(1.0, value))
    brightness = int(232 + shade * 23)  # 232..255 is the ANSI grayscale ramp
    return f"\033[48;5;{brightness}m {value:.2f} \033[0m"
def print_attention(attention_scores, batch_index: int = 0):
    """Print each head's attention map as a grid of shaded cells.

    Args:
        attention_scores: 4-D tensor, expected to be (batch, heads, query_len, key_len).
            Query and key lengths may differ.
        batch_index: which batch element to display (defaults to 0, the previous behavior).

    Each head is min-max normalized independently before shading, so shades are
    comparable within one head but not across heads.
    """
    # Unpacking enforces a 4-D input, as before.
    _, num_heads, _, _ = attention_scores.shape
    for head in range(num_heads):
        print(f"\nAttention Head {head + 1}:\n" + "-" * 30)
        attn = attention_scores[batch_index, head]
        # Min-max normalize to [0, 1]; the epsilon avoids division by zero for a constant map.
        attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)
        for row in attn:
            # Same layout as before: every cell followed by one space, then a newline.
            print("".join(colorize(cell.item()) + " " for cell in row))
if __name__ == "__main__":
    # Smoke run: build a model with a 10k-token vocabulary and 350 positions, then show its structure.
    model = build_gpt(vocab_size=10000, seq_len=350)
    print(model)