|
|
|
import os |
|
from dataclasses import dataclass |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
from torchtune.modules import RotaryPositionalEmbeddings |
|
import logging |
|
from transformers import AutoTokenizer |
|
from datasets import load_dataset |
|
from torch.utils.checkpoint import checkpoint |
|
from torch.utils.data import DataLoader |
|
|
|
|
|
class LlamaMLP(nn.Module): |
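    """Llama-style SwiGLU feed-forward block: SiLU(gate(x)) * up(x), projected back down, with no biases."""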
|
def __init__(self, config): |
|
super().__init__() |
|
        hidden_dim = 1536  # fixed SwiGLU intermediate size; not derived from config.n_embd
|
self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False) |
|
self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False) |
|
self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) |
|
self.act_fn = nn.SiLU() |
|
self.down_proj.NANOGPT_SCALE_INIT = 1 |
|
|
|
def forward(self, x): |
|
gate = self.gate_proj(x) |
|
up = self.up_proj(x) |
|
return self.down_proj(self.act_fn(gate) * up) |
|
|
|
|
|
class LlamaDecoderLayer(nn.Module): |
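    """Pre-norm decoder block: RMSNorm -> self-attention -> residual, then RMSNorm -> MLP -> residual.

    The forward pass runs under activation checkpointing, trading extra compute in the
    backward pass for a much smaller activation-memory footprint.
    """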
|
def __init__(self, config): |
|
super().__init__() |
|
self.self_attn = CausalSelfAttention(config) |
|
self.input_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) |
|
self.post_attention_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) |
|
self.mlp = LlamaMLP(config) |
|
|
|
def forward(self, x, attention_mask): |
|
|
|
return checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False) |
|
|
|
|
|
def _forward_impl(self, x, attention_mask): |
|
|
|
residual = x |
|
x = self.input_layernorm(x) |
|
x = self.self_attn(x, attention_mask) + residual |
|
|
|
|
|
residual = x |
|
x = self.post_attention_layernorm(x) |
|
x = self.mlp(x) + residual |
|
return x |
|
|
|
|
|
@dataclass |
|
class GPTConfig: |
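    """Model hyperparameters (SmolLM-135M-style sizes: 576-dim, 30 layers, 9 heads, 3 KV heads, 49152-token vocab)."""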
|
block_size: int = 2048 |
|
vocab_size: int = 49152 |
|
n_layer: int = 30 |
|
n_head: int = 9 |
|
n_embd: int = 576 |
|
num_key_value_heads: int = 3 |
|
|
|
|
|
class CausalSelfAttention(nn.Module): |
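    """Causal self-attention with grouped-query attention (GQA) and rotary position embeddings.

    n_head query heads share num_key_value_heads key/value heads; K and V are repeated
    to match the query heads before scaled-dot-product attention.
    """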
|
def __init__(self, config): |
|
super().__init__() |
|
assert config.n_embd % config.n_head == 0 |
|
assert config.n_embd % config.num_key_value_heads == 0 |
|
|
|
|
|
self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=False) |
|
|
|
        # Shared key/value projection for GQA: 2 * num_key_value_heads * head_dim outputs
        self.ckv_attn = nn.Linear(config.n_embd, 2 * config.num_key_value_heads * (config.n_embd // config.n_head), bias=False)
|
|
|
|
|
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) |
|
self.n_head = config.n_head |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.head_dim = config.n_embd // config.n_head |
|
|
|
|
|
self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.block_size) |
|
|
|
|
|
|
|
|
|
|
def forward(self, x, attention_mask=None): |
|
B, T, C = x.size() |
|
|
|
|
|
        # Project queries for all attention heads: (B, T, C) -> (B, T, n_head, head_dim)
        q = self.cq_attn(x)
        q = q.view(B, T, self.n_head, self.head_dim)

        # Project shared keys/values for the smaller set of KV heads (grouped-query attention)
        kv = self.ckv_attn(x)
        kv_dim = self.num_key_value_heads * self.head_dim
        k, v = kv.split(kv_dim, dim=2)
        k = k.view(B, T, self.num_key_value_heads, self.head_dim)
        v = v.view(B, T, self.num_key_value_heads, self.head_dim)

        # torchtune's RotaryPositionalEmbeddings expects (B, T, n_heads, head_dim),
        # so apply RoPE before moving the head dimension in front of the sequence dimension
        q = self.rope(q).transpose(1, 2)
        k = self.rope(k).transpose(1, 2)
        v = v.transpose(1, 2)

        # Repeat each KV head so K and V line up with the n_head query heads
        k = torch.repeat_interleave(k, repeats=self.n_head // self.num_key_value_heads, dim=1)
        v = torch.repeat_interleave(v, repeats=self.n_head // self.num_key_value_heads, dim=1)
|
|
|
|
|
        if attention_mask is not None:
            # Padding mask: (B, T) -> (B, 1, 1, T) boolean, True where a position may be attended to
            attention_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)

            # Combine with a lower-triangular causal mask to get a (B, 1, T, T) boolean mask
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)).view(1, 1, T, T)
            attention_mask = causal_mask & attention_mask
|
|
|
|
|
|
|
|
|
        # Fused scaled-dot-product attention; when no padding mask is supplied,
        # fall back to SDPA's built-in causal masking
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=attention_mask is None
        )
|
|
|
|
|
y = y.transpose(1, 2).contiguous().view(B, T, C) |
|
|
|
|
|
y = self.c_proj(y) |
|
return y |
|
|
|
|
|
class GPT(nn.Module): |
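    """Decoder-only transformer: token embedding, a stack of Llama-style decoder layers,
    a final RMSNorm, and a language-model head whose weights are tied to the embedding."""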
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
|
|
|
|
self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd) |
|
|
|
|
|
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)]) |
|
self.final_norm = nn.RMSNorm(config.n_embd, eps=1e-5) |
|
|
|
|
|
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) |
|
|
|
|
|
self.token_embedding.weight = self.lm_head.weight |
|
|
|
|
|
self.apply(self._init_weights) |
|
|
|
    def _init_weights(self, module):
        std = 0.041666666666666664  # 1 / sqrt(n_embd) = 1 / 24 for n_embd = 576
        if isinstance(module, nn.Linear):
            # Scale down the init of residual-branch projections flagged with NANOGPT_SCALE_INIT
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
|
|
def forward(self, idx, attention_mask=None): |
|
B, T = idx.size() |
|
assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}" |
|
|
|
|
|
token_embeddings = self.token_embedding(idx) |
|
|
|
|
|
x = token_embeddings |
|
|
|
|
|
for layer in self.layers: |
|
x = layer(x, attention_mask) |
|
|
|
|
|
x = self.final_norm(x) |
|
|
|
|
|
logits = self.lm_head(x) |
|
|
|
return logits |
|
|
|
|
|
    def generate(self, input_ids, max_length=50, eos_token_id=None):
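        """Greedy decoding for a single prompt (batch size 1).

        Appends the argmax token at each step and stops after max_length new tokens
        or when eos_token_id is produced. Returns the list of generated token ids.
        """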
|
generated_tokens = [] |
|
current_ids = input_ids |
|
|
|
|
|
device = input_ids.device |
|
|
|
for _ in range(max_length): |
|
|
|
logits = self.forward(current_ids) |
|
|
|
|
|
logits = logits[:, -1, :] |
|
|
|
            next_token = logits.argmax(dim=-1).cpu().item()
|
|
|
|
|
generated_tokens.append(next_token) |
|
|
|
|
|
current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1) |
|
|
|
|
|
if eos_token_id is not None and next_token == eos_token_id: |
|
break |
|
|
|
return generated_tokens |
|
|
|
|
|
|
|
class OptimizerConfig: |
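    """Optimizer hyperparameters.

    The learning rate, AdamW betas, eps, and weight decay are read when the optimizer is
    built below; the LR-schedule and gradient-clipping fields are recorded here but are
    not wired into the training loop.
    """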
|
accumulate_grad_in_fp32 = True |
|
clip_grad = 1.0 |
|
learning_rate = 0.003 |
|
lr_decay_starting_step = 1600000 |
|
lr_decay_steps = 400000 |
|
lr_decay_style = "linear" |
|
lr_warmup_steps = 2000 |
|
lr_warmup_style = "linear" |
|
min_decay_lr = 0.0 |
|
adam_beta1 = 0.9 |
|
adam_beta2 = 0.95 |
|
adam_eps = 1.0e-08 |
|
weight_decay = 0.01 |
|
zero_stage = 0 |
|
name = "adamW" |
|
torch_adam_is_fused = True |
|
|
|
|
|
if __name__ == "__main__": |
|
logging.basicConfig(filename='/kaggle/working/training_log.txt', level=logging.INFO, |
|
format='%(asctime)s - %(levelname)s - %(message)s', force=True) |
|
|
|
device = 'cpu' |
|
if torch.cuda.is_available(): |
|
device = 'cuda' |
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
|
device = "mps" |
|
print(f"Using device: {device}") |
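
    # Allow TF32 matmul kernels on supported GPUs for faster float32 matrix multiplies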
|
|
|
torch.set_float32_matmul_precision('high') |
|
|
|
|
|
torch.manual_seed(1337) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed(1337) |
|
|
|
|
|
model = GPT(GPTConfig()) |
|
model.to(device) |
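
    # Resume bookkeeping: reload model weights, epoch/step counters, and best loss if a checkpoint exists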
|
|
|
|
|
|
|
best_model_path = '/kaggle/working/best_model.pth' |
|
checkpoint_model_path = '/kaggle/working/checkpoint_model.pth' |
|
start_epoch = 0 |
|
start_step = 0 |
|
best_loss = float('inf') |
|
|
|
if os.path.exists(checkpoint_model_path): |
|
model_checkpoint = torch.load(checkpoint_model_path, map_location=device, weights_only=True) |
|
model.load_state_dict(model_checkpoint['model_state_dict']) |
|
start_epoch = model_checkpoint['epoch'] |
|
start_step = model_checkpoint['step']+1 |
|
best_loss = model_checkpoint['loss'] |
|
logging.info(f"Resuming from epoch {start_epoch}, step {start_step}, best loss {best_loss:.6f}") |
|
|
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
logging.info(f"Total Parameters: {total_params:,}") |
|
logging.info(f"Trainable Parameters: {trainable_params:,}") |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer") |
|
tokenizer.pad_token = tokenizer.eos_token |
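
    # Stream the cosmopedia-v2 subset of the SmolLM corpus so it never has to be fully downloaded or held in memory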
|
|
|
|
|
dataset = load_dataset( |
|
"HuggingFaceTB/smollm-corpus", |
|
"cosmopedia-v2", |
|
streaming=True |
|
)['train'] |
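
    # Tokenize on the fly, padding/truncating every document to the 2048-token block size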
|
|
|
|
|
def encode(examples): |
|
|
|
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=2048, return_tensors=None)
|
|
|
|
|
    dataset = dataset.map(encode, batched=True, remove_columns=dataset.column_names)
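
    # Collate tokenized examples into dense input_ids / attention_mask tensors for the DataLoader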
|
|
|
def collate_fn(batch): |
|
input_ids = torch.tensor([example['input_ids'] for example in batch], dtype=torch.long) |
|
attention_mask = torch.tensor([example['attention_mask'] for example in batch], dtype=torch.long) |
|
|
|
return {"input_ids": input_ids, "attention_mask": attention_mask} |
|
|
|
|
train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn) |
|
|
|
|
|
optimizer_config = OptimizerConfig() |
|
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=optimizer_config.learning_rate,
        betas=(optimizer_config.adam_beta1, optimizer_config.adam_beta2),
        eps=optimizer_config.adam_eps,
        weight_decay=optimizer_config.weight_decay
    )
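
    # Stop training once the loss target is reached or after max_iterations batches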
|
|
|
|
|
target_loss = 0.099999 |
|
max_iterations = 6000 |
|
optimizer.zero_grad() |
|
|
|
    # Gradient scaler for mixed precision (disabled off-GPU); loss scaling is not strictly
    # needed with bfloat16 autocast, but it is harmless here
    scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))

    # torch.autocast is used with "cuda" or "cpu" below, so map "mps" to "cpu"
    autocast_device = "cuda" if "cuda" in device else "cpu"
|
|
|
|
|
if os.path.exists(checkpoint_model_path): |
|
optimizer.load_state_dict(model_checkpoint['optimizer_state_dict']) |
|
scaler.load_state_dict(model_checkpoint['scaler_state_dict']) |
|
|
|
sample_text = "Once upon a time" |
|
|
|
sample_tokens = tokenizer(sample_text, return_tensors='pt').input_ids.to(device) |
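
    # Training loop: bfloat16 autocast forward pass, gradient accumulation over 16 micro-batches,
    # best-loss checkpointing, and a periodic greedy generation as a qualitative check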
|
|
|
|
|
|
|
for epoch in range(start_epoch, 100): |
|
        # NOTE: start= only offsets the step counter; the streaming loader still yields batches from the beginning
        for i, batch in enumerate(train_loader, start=start_step):
|
x = batch["input_ids"].to(device) |
|
attention_mask = batch["attention_mask"].to(device) |
|
|
|
            # Next-token targets: shift the inputs left by one and append EOS at the final position
            y = torch.cat([x[:, 1:], torch.full((x.size(0), 1), tokenizer.eos_token_id, device=device)], dim=1)
|
|
|
|
|
            with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
|
logits = model(x, attention_mask=attention_mask) |
|
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    # pad == eos here, so this skips padded positions (and, as a side effect, real EOS targets)
                    ignore_index=tokenizer.eos_token_id
                )
|
|
|
            # Gradient accumulation over 16 micro-batches: scale the loss so the accumulated gradient is an average
            scaler.scale(loss / 16).backward()
|
|
|
|
|
if (i+1) % 16 == 0: |
|
scaler.step(optimizer) |
|
scaler.update() |
|
optimizer.zero_grad() |
|
|
|
|
|
if loss.item() < best_loss: |
|
best_loss = loss.item() |
|
torch.save({ |
|
'epoch': epoch, |
|
'step': i, |
|
'model_state_dict': model.state_dict(), |
|
'optimizer_state_dict': optimizer.state_dict(), |
|
'scaler_state_dict': scaler.state_dict(), |
|
'loss': best_loss, |
|
}, best_model_path) |
|
|
|
|
|
logging.info(f"Epoch {epoch}, Step {i}, Loss: {loss.item():.6f}, Best Loss: {best_loss:.6f}") |
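
            # Every 500 batches, generate a short greedy continuation of the sample prompt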
|
|
|
|
|
if (i + 1) % 500 == 0: |
|
model.eval() |
|
with torch.no_grad(): |
|
|
|
                    generated_tokens = model.generate(sample_tokens, max_length=50, eos_token_id=tokenizer.eos_token_id)
|
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) |
|
|
|
logging.info(f"Step {i + 1} Prompt: {sample_text} \n Generated Token: {generated_tokens} \n Prediction: {generated_text}") |
|
|
|
model.train() |
|
|
|
if loss.item() <= target_loss: |
|
logging.info(f"Target loss reached at step {i}. Training completed!") |
|
break |
|
|
|
if i >= max_iterations: |
|
torch.save({ |
|
'epoch': epoch, |
|
'step': i, |
|
'model_state_dict': model.state_dict(), |
|
'optimizer_state_dict': optimizer.state_dict(), |
|
'scaler_state_dict': scaler.state_dict(), |
|
'loss': best_loss, |
|
}, checkpoint_model_path) |
|
logging.info("Max iterations reached. Training stopped.") |
|
break |
|
|
|
        else:
            # The inner loop finished without a break: continue with the next epoch
            continue
        # The inner loop broke out (target loss or max iterations reached): stop training
        break
|
|
|
logging.info("Training completed!") |
|
logging.info(f"Final Loss: {loss.item():.6f}") |
|
logging.info(f"Best Loss Achieved: {best_loss:.6f}") |
|
logging.info(f"Best Model Saved To: {best_model_path}") |
|
    logging.info(f"Checkpoint Model Saved To: {checkpoint_model_path}")