# Solving for residual std scaling issue
import os
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchtune.modules import RotaryPositionalEmbeddings
import logging
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_dim = 1536  # Expand dimension to 1536
        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        self.act_fn = nn.SiLU()  # Activation function
        self.down_proj.NANOGPT_SCALE_INIT = 1  # Flag picked up by _init_weights for residual-scaled init

    def forward(self, x):
        gate = self.gate_proj(x)  # Gate projection
        up = self.up_proj(x)      # Up projection
        return self.down_proj(self.act_fn(gate) * up)  # Activate gate, multiply, down-project
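# The block above is a SwiGLU-style gated MLP. As a rough sketch of the shapes involved
# (using the GPTConfig defaults below, n_embd=576, hidden_dim=1536):
#   x:    (B, T, 576)
#   gate: (B, T, 1536) = gate_proj(x)
#   up:   (B, T, 1536) = up_proj(x)
#   out:  (B, T, 576)  = down_proj(SiLU(gate) * up)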
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = CausalSelfAttention(config)  # Self-attention block
        self.input_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5)  # RMSNorm for inputs
        self.post_attention_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5)  # RMSNorm post-attention
        self.mlp = LlamaMLP(config)  # Llama-style MLP

    def forward(self, x, attention_mask):
        # Use activation checkpointing for this memory-intensive layer
        return checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)

    def _forward_impl(self, x, attention_mask):
        # Pre-norm residual block: self-attention
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x, attention_mask) + residual
        # Pre-norm residual block: MLP
        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x) + residual
        return x
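# The decoder layer follows the standard Llama pre-norm layout:
#   x = x + Attn(RMSNorm(x));  x = x + MLP(RMSNorm(x))
# torch.utils.checkpoint recomputes the layer's activations during backward instead of
# storing them, which lowers peak memory at the cost of extra forward compute.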
@dataclass
class GPTConfig:
    block_size: int = 2048   # max sequence length
    vocab_size: int = 49152  # vocabulary size of the cosmo2 tokenizer
    n_layer: int = 30        # number of layers
    n_head: int = 9          # number of attention (query) heads
    n_embd: int = 576        # embedding dimension
    num_key_value_heads: int = 3  # number of key/value heads (grouped-query attention)
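# Back-of-envelope parameter count for these defaults (a rough estimate that ignores the
# small RMSNorm weights; the exact number is logged at startup):
#   embeddings (tied with lm_head): 49152 * 576                      ~= 28.3M
#   per layer: q_proj 576*576 + kv_proj 576*384 + o_proj 576*576
#              + mlp 3 * 576 * 1536                                  ~= 3.54M
#   total:     28.3M + 30 * 3.54M                                    ~= 134.5M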
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        assert config.n_embd % config.num_key_value_heads == 0
        # Query projection for all heads
        self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=False)  # For queries
        # Key-value projection for the (fewer) grouped key/value heads
        self.ckv_attn = nn.Linear(config.n_embd, 2 * (config.n_embd // config.num_key_value_heads), bias=False)  # For keys and values
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.n_head = config.n_head
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.n_embd // config.n_head
        # Rotary positional embeddings
        self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.block_size)
        # Causal mask buffer (kept for checkpoint compatibility; scaled_dot_product_attention
        # below builds its own mask, so this buffer is not used in forward)
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x, attention_mask=None):
        B, T, C = x.size()  # batch size, sequence length, embedding dimension (n_embd)
        # Compute queries
        q = self.cq_attn(x)  # (B, T, C)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        # Compute keys and values (shared across grouped query heads)
        kv = self.ckv_attn(x)  # (B, T, 2 * (C // num_key_value_heads))
        kv_dim = C // self.num_key_value_heads
        k, v = kv.split(kv_dim, dim=2)  # Split into keys and values
        k = k.view(B, T, self.num_key_value_heads, self.head_dim).transpose(1, 2)  # (B, kvh, T, hs)
        v = v.view(B, T, self.num_key_value_heads, self.head_dim).transpose(1, 2)  # (B, kvh, T, hs)
        # Repeat key/value heads so each query head has a matching K/V head
        k = torch.repeat_interleave(k, repeats=self.n_head // self.num_key_value_heads, dim=1)
        v = torch.repeat_interleave(v, repeats=self.n_head // self.num_key_value_heads, dim=1)
        # Apply RoPE to queries and keys
        q = self.rope(q)
        k = self.rope(k)
        # Handle attention mask
        if attention_mask is not None:
            # Expand padding mask to (B, 1, 1, T)
            attention_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
            # Causal mask (lower triangular), broadcast to (1, 1, T, T)
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)).view(1, 1, T, T)
            # Combine causal mask and padding mask (both are boolean)
            attention_mask = causal_mask & attention_mask
        # Memory-efficient attention (uses Flash/efficient kernels when available)
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,  # combined causal + padding mask
            # is_causal=True would apply the causal mask automatically; here it is passed explicitly
            dropout_p=0.0
        )
        # Reshape and combine heads
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
        # Output projection
        y = self.c_proj(y)
        return y
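# Shape walkthrough for the grouped-query attention above (with the GPTConfig defaults,
# n_embd=576, n_head=9, num_key_value_heads=3, so head_dim=64):
#   q:    (B, T, 576) -> (B, 9, T, 64)
#   k, v: (B, T, 192) -> (B, 3, T, 64) -> repeat_interleave x3 -> (B, 9, T, 64)
# Queries keep all 9 heads while keys/values are projected once per group of 3 query heads.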
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Token embeddings (positions are handled by RoPE inside attention)
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        # Transformer layers
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)])
        self.final_norm = nn.RMSNorm(config.n_embd, eps=1e-5)
        # Output head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Share weights between input embedding and output head
        self.token_embedding.weight = self.lm_head.weight
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        std = 0.041666666666666664  # 1 / sqrt(n_embd) = 1 / 24
        if isinstance(module, nn.Linear):
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                # Scale down residual-path projections by 1 / sqrt(2 * n_layer)
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, idx, attention_mask=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}"
        # Token embeddings
        x = self.token_embedding(idx)
        # Pass through transformer layers
        for layer in self.layers:
            x = layer(x, attention_mask)
        # Final layer normalization
        x = self.final_norm(x)
        # Compute logits
        logits = self.lm_head(x)
        return logits
    def generate(self, input_ids, max_length=50, eos_token_id=None):
        # Greedy decoding; assumes a single prompt (batch size 1)
        generated_tokens = []
        current_ids = input_ids
        # Infer device from input_ids
        device = input_ids.device
        for _ in range(max_length):
            # Forward pass to get logits
            logits = self.forward(current_ids)  # (batch_size, seq_len, vocab_size)
            # Only take the last token's logits
            logits = logits[:, -1, :]  # (batch_size, vocab_size)
            next_token = logits.argmax(dim=-1).cpu().item()
            # Store the token on the CPU side
            generated_tokens.append(next_token)
            # Append the token to the running input
            current_ids = torch.cat([current_ids, torch.tensor([[next_token]], device=device)], dim=1)
            # Stop if the EOS token is generated
            if eos_token_id is not None and next_token == eos_token_id:
                break
        return generated_tokens
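# Minimal usage sketch for GPT.generate (the same pattern appears in the training loop
# below); `prompt` and `tokens` are illustrative variable names only:
#   prompt = tokenizer("Once upon a time", return_tensors='pt').input_ids.to(device)
#   tokens = model.generate(prompt, max_length=50, eos_token_id=tokenizer.eos_token_id)
#   print(tokenizer.decode(tokens, skip_special_tokens=True))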
# Configuration class for the optimizer and learning-rate schedule
class OptimizerConfig:
    accumulate_grad_in_fp32 = True
    clip_grad = 1.0
    learning_rate = 0.003
    lr_decay_starting_step = 1600000
    lr_decay_steps = 400000
    lr_decay_style = "linear"
    lr_warmup_steps = 2000
    lr_warmup_style = "linear"
    min_decay_lr = 0.0
    adam_beta1 = 0.9
    adam_beta2 = 0.95
    adam_eps = 1.0e-08
    weight_decay = 0.01
    zero_stage = 0
    name = "adamW"
    torch_adam_is_fused = True
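# Note: only the learning rate, Adam betas/eps, and weight decay below are passed to AdamW;
# the warmup/decay fields are not wired to a scheduler in this script. A minimal sketch of
# how they could be hooked up with torch.optim.lr_scheduler.LambdaLR (an assumption, not
# part of the original training setup):
#   def lr_lambda(step, cfg=OptimizerConfig):
#       if step < cfg.lr_warmup_steps:
#           return step / max(1, cfg.lr_warmup_steps)            # linear warmup
#       if step < cfg.lr_decay_starting_step:
#           return 1.0                                           # constant plateau
#       decayed = (step - cfg.lr_decay_starting_step) / cfg.lr_decay_steps
#       return max(cfg.min_decay_lr / cfg.learning_rate, 1.0 - decayed)  # linear decay
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)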
if __name__ == "__main__":
    logging.basicConfig(filename='/kaggle/working/training_log.txt', level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s', force=True)

    # Device setup
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

    torch.set_float32_matmul_precision('high')

    # Seed setup
    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)

    # Model initialization
    model = GPT(GPTConfig())
    model.to(device)
    # model = torch.compile(model)

    # Load checkpoint if it exists
    best_model_path = '/kaggle/working/best_model.pth'
    checkpoint_model_path = '/kaggle/working/checkpoint_model.pth'
    start_epoch = 0
    start_step = 0
    best_loss = float('inf')
    if os.path.exists(checkpoint_model_path):
        model_checkpoint = torch.load(checkpoint_model_path, map_location=device, weights_only=True)
        model.load_state_dict(model_checkpoint['model_state_dict'])
        start_epoch = model_checkpoint['epoch']
        start_step = model_checkpoint['step'] + 1
        best_loss = model_checkpoint['loss']
        logging.info(f"Resuming from epoch {start_epoch}, step {start_step}, best loss {best_loss:.6f}")

    # Model parameter count
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Total Parameters: {total_params:,}")
    logging.info(f"Trainable Parameters: {trainable_params:,}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
    tokenizer.pad_token = tokenizer.eos_token

    # Load streaming dataset (only the "train" split)
    dataset = load_dataset(
        "HuggingFaceTB/smollm-corpus",
        "cosmopedia-v2",
        streaming=True
    )['train']
    # Define the encode function
    def encode(examples):
        # Tokenize the text, truncating/padding to the model's block size
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=2048, return_tensors=None)

    # Tokenize on the fly while streaming
    dataset = dataset.map(encode, batched=True, remove_columns=dataset.column_names)

    def collate_fn(batch):
        input_ids = torch.tensor([example['input_ids'] for example in batch], dtype=torch.long)
        attention_mask = torch.tensor([example['attention_mask'] for example in batch], dtype=torch.long)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
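    # Each batch from train_loader is a dict of two (8, 2048) LongTensors:
    # "input_ids" holds token ids padded with EOS (pad token == EOS token above),
    # and "attention_mask" is 1 for real tokens and 0 for padding.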
    # Optimizer setup
    optimizer_config = OptimizerConfig()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=optimizer_config.learning_rate,
        betas=(optimizer_config.adam_beta1, optimizer_config.adam_beta2),
        eps=optimizer_config.adam_eps,
        weight_decay=optimizer_config.weight_decay
    )

    # Training loop settings
    target_loss = 0.099999
    max_iterations = 6000
    optimizer.zero_grad()
    scaler = torch.GradScaler()  # AMP gradient scaler
    autocast_device = "cuda" if "cuda" in device else "cpu"  # valid autocast device type
    if os.path.exists(checkpoint_model_path):
        optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(model_checkpoint['scaler_state_dict'])

    sample_text = "Once upon a time"  # Prompt for tracking generation quality
    sample_tokens = tokenizer(sample_text, return_tensors='pt').input_ids.to(device)
    for epoch in range(start_epoch, 100):
        for i, batch in enumerate(train_loader, start=start_step):
            x = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # Next-token targets: shift inputs left by one and append EOS at the end
            y = torch.cat([x.clone()[:, 1:], torch.full((x.size(0), 1), tokenizer.eos_token_id, device=device)], dim=1)

            with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
                logits = model(x, attention_mask=attention_mask)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=tokenizer.eos_token_id  # exclude padding (pad token == EOS token)
                )
            scaler.scale(loss).backward()  # backward on the scaled loss

            # Gradient accumulation: step the optimizer every 16 micro-batches
            if (i + 1) % 16 == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Save best model
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save({
                    'epoch': epoch,
                    'step': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'loss': best_loss,
                }, best_model_path)

            logging.info(f"Epoch {epoch}, Step {i}, Loss: {loss.item():.6f}, Best Loss: {best_loss:.6f}")

            # Generate a sample every 500 steps to track progress
            if (i + 1) % 500 == 0:
                model.eval()
                with torch.no_grad():
                    generated_tokens = model.generate(sample_tokens, max_length=50, eos_token_id=tokenizer.eos_token_id)
                    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
                    logging.info(f"Step {i + 1} Prompt: {sample_text} \n Generated Tokens: {generated_tokens} \n Prediction: {generated_text}")
                model.train()

            if loss.item() <= target_loss:
                logging.info(f"Target loss reached at step {i}. Training completed!")
                break
            if i >= max_iterations:
                torch.save({
                    'epoch': epoch,
                    'step': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'loss': best_loss,
                }, checkpoint_model_path)
                logging.info("Max iterations reached. Training stopped.")
                break
        else:
            # Inner loop finished without a break: continue with the next epoch
            continue
        # Inner loop broke out (target loss or max iterations reached): stop training
        break

    logging.info("Training completed!")
    logging.info(f"Final Loss: {loss.item():.6f}")
    logging.info(f"Best Loss Achieved: {best_loss:.6f}")
    logging.info(f"Best Model Saved To: {best_model_path}")
    logging.info(f"Checkpoint Model Saved To: {checkpoint_model_path}")