{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30840,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"# This Python 3 environment comes with many helpful analytics libraries installed\n# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n# For example, here's several helpful packages to load\n\nimport numpy as np # linear algebra\nimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n\n# Input data files are available in the read-only \"../input/\" directory\n# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n\nimport os\nfor dirname, _, filenames in os.walk('/kaggle/input'):\n    for filename in filenames:\n        print(os.path.join(dirname, filename))\n\n# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"!export CUDA_LAUNCH_BLOCKING=1","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:36.728095Z","iopub.execute_input":"2025-02-01T14:56:36.728438Z","iopub.status.idle":"2025-02-01T14:56:36.843265Z","shell.execute_reply.started":"2025-02-01T14:56:36.728407Z","shell.execute_reply":"2025-02-01T14:56:36.842447Z"}},"outputs":[],"execution_count":1},{"cell_type":"code","source":"# !rm /kaggle/working/best_model.pth\n# !rm /kaggle/working/training_log.txt\n# !rm /kaggle/working/checkpoint_model.pth","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:37.188794Z","iopub.execute_input":"2025-02-01T14:56:37.189068Z","iopub.status.idle":"2025-02-01T14:56:37.192229Z","shell.execute_reply.started":"2025-02-01T14:56:37.189041Z","shell.execute_reply":"2025-02-01T14:56:37.191535Z"}},"outputs":[],"execution_count":2},{"cell_type":"code","source":"!pip install torchao\n!pip install triton","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:38.269595Z","iopub.execute_input":"2025-02-01T14:56:38.269864Z","iopub.status.idle":"2025-02-01T14:56:54.484791Z","shell.execute_reply.started":"2025-02-01T14:56:38.269842Z","shell.execute_reply":"2025-02-01T14:56:54.483971Z"}},"outputs":[{"name":"stdout","text":"Collecting torchao\n  Downloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (14 kB)\nDownloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (4.7 MB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m42.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hInstalling collected packages: torchao\nSuccessfully installed torchao-0.8.0\nCollecting triton\n  Downloading 
triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)\nDownloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.1 MB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m253.1/253.1 MB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hInstalling collected packages: triton\nSuccessfully installed triton-3.2.0\n","output_type":"stream"}],"execution_count":3},{"cell_type":"code","source":"import os\nimport math\nimport time\nimport inspect\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torchtune.modules import RotaryPositionalEmbeddings\nimport logging\nfrom transformers import AutoTokenizer\nfrom datasets import load_dataset\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:54.485978Z","iopub.execute_input":"2025-02-01T14:56:54.486228Z","iopub.status.idle":"2025-02-01T14:57:04.296970Z","shell.execute_reply.started":"2025-02-01T14:56:54.486192Z","shell.execute_reply":"2025-02-01T14:57:04.296075Z"}},"outputs":[],"execution_count":4},{"cell_type":"code","source":"\nclass LlamaMLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        hidden_dim = 1536  # Expand dimension to 1536\n        self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)\n        self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)\n        self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)\n        self.act_fn = nn.SiLU()  # Activation function\n        self.down_proj.NANOGPT_SCALE_INIT = 1\n        \n    def forward(self, x):\n        gate = self.gate_proj(x)  # Gate projection\n        up = self.up_proj(x)     # Up projection\n        return self.down_proj(self.act_fn(gate) * up)  # Apply activation and down-project\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:27.167394Z","iopub.execute_input":"2025-02-01T14:57:27.167929Z","iopub.status.idle":"2025-02-01T14:57:27.173065Z","shell.execute_reply.started":"2025-02-01T14:57:27.167902Z","shell.execute_reply":"2025-02-01T14:57:27.172323Z"}},"outputs":[],"execution_count":5},{"cell_type":"code","source":"from torch.utils.checkpoint import checkpoint\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.self_attn = CausalSelfAttention(config)  # Self-attention block\n        self.input_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5)  # RMSNorm for inputs\n        self.post_attention_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5)  # RMSNorm post-attention\n        self.mlp = LlamaMLP(config)  # Llama-style MLP\n\n    def forward(self, x, attention_mask):\n        # Use checkpointing for memory-intensive layers\n        return checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)\n        # return checkpoint.checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)\n    \n    def _forward_impl(self, x, attention_mask):\n        # Apply self-attention with normalization\n        residual = x\n        x = self.input_layernorm(x)\n        x = self.self_attn(x, attention_mask) + residual\n\n        # Apply MLP with post-attention normalization\n        residual = x\n        x = self.post_attention_layernorm(x)\n        x = self.mlp(x) + residual\n        return 
x","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:30.369518Z","iopub.execute_input":"2025-02-01T14:57:30.369808Z","iopub.status.idle":"2025-02-01T14:57:30.375285Z","shell.execute_reply.started":"2025-02-01T14:57:30.369785Z","shell.execute_reply":"2025-02-01T14:57:30.374378Z"}},"outputs":[],"execution_count":6},{"cell_type":"code","source":"@dataclass\nclass GPTConfig:\n    block_size: int = 2048 # max sequence length\n    vocab_size: int = 49152 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token\n    n_layer: int = 30 # number of layers\n    n_head: int = 9 # number of heads\n    n_embd: int = 576 # embedding dimension\n    num_key_value_heads: int = 3","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:32.533603Z","iopub.execute_input":"2025-02-01T14:57:32.533898Z","iopub.status.idle":"2025-02-01T14:57:32.538680Z","shell.execute_reply.started":"2025-02-01T14:57:32.533877Z","shell.execute_reply":"2025-02-01T14:57:32.537832Z"}},"outputs":[],"execution_count":7},{"cell_type":"code","source":"\nclass CausalSelfAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        assert config.n_embd % config.n_head == 0\n        assert config.n_embd % config.num_key_value_heads == 0\n\n        # Query projection for all heads\n        self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=False)  # For queries\n        # Key-Value projection for grouped heads\n        self.ckv_attn = nn.Linear(config.n_embd, 2 * (config.n_embd // config.num_key_value_heads), bias=False)  # For keys and values\n        \n        # Output projection\n        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)\n        self.n_head = config.n_head\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.n_embd // config.n_head\n\n        # Rotary Positional Embedding\n        self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.block_size)\n\n\n        # Bias for causal mask\n        self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n                             .view(1, 1, config.block_size, config.block_size))\n\n    def forward(self, x, attention_mask=None):\n        B, T, C = x.size()  # Batch size, sequence length, embedding dimension (n_embd)\n        \n        # Compute queries\n        q = self.cq_attn(x)  # (B, T, C)\n        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)\n        \n        # Compute keys and values (shared across grouped heads)\n        kv = self.ckv_attn(x)  # (B, T, 2 * (C / num_key_value_heads))\n        kv_dim = C // self.num_key_value_heads\n        k, v = kv.split(kv_dim, dim=2)  # Split into keys and values\n        k = k.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2)  # (B, kvh, T, hs)\n        v = v.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2)  # (B, kvh, T, hs)\n    \n        # k = k.repeat(1, self.n_head // self.num_key_value_heads, 1, 1)  # Repeat along the second dimension (B, 3, T, 64) -> (B, 9, T, 64)\n        # v = v.repeat(1, self.n_head // self.num_key_value_heads, 1, 1)  # Repeat along the second dimension (B, 3, T, 64) -> (B, 9, T, 64)\n\n        k = torch.repeat_interleave(k, repeats=self.n_head // self.num_key_value_heads, dim=1)\n        v = torch.repeat_interleave(v, repeats=self.n_head // self.num_key_value_heads, 
dim=1)\n        \n        # Apply RoPE to queries and keys\n        q = self.rope(q)\n        k = self.rope(k)\n    \n        # Scale dot-product attention\n        #att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))  # (B, nh, T, T)\n        \n        # Apply causal mask\n        # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))\n        \n        # # If attention_mask is provided, apply it\n        # if attention_mask is not None:\n        #     # Expand attention_mask from (B, T) -> (B, 1, 1, T) to match attention scores (B, nh, T, T)\n        #     attention_mask = attention_mask[:, None, None, :]  # Add dimensions for heads and query positions\n        #     att = att.masked_fill(attention_mask == 0, float('-inf'))\n\n        # att = F.softmax(att, dim=-1)    \n        # Weighted sum of values\n        #y = att @ v  # (B, nh, T, T) x (B, kvh, T, hs) -> (B, nh, T, hs)\n\n        # Handle attention mask\n        if attention_mask is not None:\n            # Expand attention_mask to (B, 1, 1, T)\n            attention_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)\n            \n            # Create causal mask (lower triangular) and convert to bool\n            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)).view(1, 1, T, T)\n            \n            # Combine causal mask and padding mask\n            attention_mask = causal_mask & attention_mask  # ✅ Now both are torch.bool\n\n\n        #print(f\"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}, attention_mask.shape: {attention_mask.shape}\")\n        # Replace with Flash Attention (memory efficient)\n        y = F.scaled_dot_product_attention(\n            q, k, v, \n            attn_mask=attention_mask,  # Combines padding mask\n            #is_causal=True,  # Auto-applies causal mask\n            dropout_p=0.0\n        )\n\n\n    \n        # Reshape and combine heads\n        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)\n    \n        # Output projection\n        y = self.c_proj(y)\n        return y\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:32.876787Z","iopub.execute_input":"2025-02-01T14:57:32.877010Z","iopub.status.idle":"2025-02-01T14:57:32.888047Z","shell.execute_reply.started":"2025-02-01T14:57:32.876993Z","shell.execute_reply":"2025-02-01T14:57:32.887239Z"}},"outputs":[],"execution_count":8},{"cell_type":"code","source":"class GPT(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n\n        # Embeddings\n        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)\n\n        # Transformer layers\n        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)])\n        self.final_norm = nn.RMSNorm(config.n_embd, eps=1e-5)\n\n        # Output head\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n\n        # Share weights between input embedding and output head\n        self.token_embedding.weight = self.lm_head.weight\n\n        # Initialize weights\n        self.apply(self._init_weights)\n\n    def _init_weights(self, module):\n        std = 0.041666666666666664\n        if isinstance(module, nn.Linear):\n            if hasattr(module, 'NANGPT_SCALE_INIT'):\n                std *= (2 * self.config.n_layer) ** -0.5\n            torch.nn.init.normal_(module.weight, mean = 0.0, std = std)\n            if module.bias is not None:\n                
torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0.0, std = std)\n\n    def forward(self, idx, attention_mask=None):\n        B, T = idx.size()\n        assert T <= self.config.block_size, f\"Sequence length {T} exceeds block size {self.config.block_size}\"\n\n        # Token and positional embeddings\n        token_embeddings = self.token_embedding(idx)\n        #position_ids = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)\n        #position_embeddings = self.position_embedding(position_ids)\n\n        # Combine embeddings\n        x = token_embeddings \n\n        # Pass through transformer layers\n        for layer in self.layers:\n            x = layer(x, attention_mask)\n\n        # Final layer normalization\n        x = self.final_norm(x)\n\n        # Compute logits\n        logits = self.lm_head(x)\n        \n        # if targets is None:\n        #     loss = None\n        # else:\n        #     # Mask padding tokens in loss calculation\n        #     loss_mask = attention_mask.reshape(-1) == 1\n        #     logits = logits.view(-1, logits.size(-1))\n        #     targets = targets.view(-1)\n            \n        #     # Only compute loss for non-padded tokens\n        #     loss = F.cross_entropy(\n        #         logits[loss_mask],\n        #         targets[loss_mask]\n        #     )\n        \n        return logits\n    \n    # def generate(self, input_ids, max_length=50):\n    #     generated_tokens = []\n    #     current_ids = input_ids\n    \n    #     for _ in range(max_length):\n    #         # Forward pass to get logits\n    #         logits = self.forward(current_ids)  # Shape: (batch_size, seq_len, vocab_size)\n    \n    #         # 🔥 Only take the last token's logits\n    #         logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)\n    \n    #         # Ensure logits are within a reasonable range\n    #         #logits = torch.clamp(logits, min=-100, max=100)\n\n    #         next_token =logits.argmax(dim=-1).cpu().item()\n            \n    #         # Store token (avoid GPU-CPU issues)\n    #         generated_tokens.append(next_token)\n    #         # print(\"next token: \", next_token)\n    \n    #         # Append token to input\n    #         current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)\n    \n    #     return generated_tokens\n\n    def generate(self, input_ids, max_length=50,eos_token_id=None):\n        generated_tokens = []\n        current_ids = input_ids\n    \n        for _ in range(max_length):\n            # Forward pass to get logits\n            logits = self.forward(current_ids)  # Shape: (batch_size, seq_len, vocab_size)\n    \n            # 🔥 Only take the last token's logits\n            logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)\n    \n            # Ensure logits are within a reasonable range\n            #logits = torch.clamp(logits, min=-100, max=100)\n\n            next_token =logits.argmax(dim=-1).cpu().item()\n            \n            # Store token (avoid GPU-CPU issues)\n            generated_tokens.append(next_token)\n            # print(\"next token: \", next_token)\n    \n            # Append token to input\n            current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)\n\n            # Stop if EOS token is generated\n            if eos_token_id is not None and next_token == eos_token_id:\n                break\n    \n    
    return generated_tokens","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:36.257089Z","iopub.execute_input":"2025-02-01T14:57:36.257403Z","iopub.status.idle":"2025-02-01T14:57:36.267105Z","shell.execute_reply.started":"2025-02-01T14:57:36.257377Z","shell.execute_reply":"2025-02-01T14:57:36.266261Z"}},"outputs":[],"execution_count":9},{"cell_type":"code","source":"\n# Configuration Class\nclass OptimizerConfig:\n    accumulate_grad_in_fp32 = True\n    clip_grad = 1.0\n    learning_rate = 0.003\n    lr_decay_starting_step = 1600000\n    lr_decay_steps = 400000\n    lr_decay_style = \"linear\"\n    lr_warmup_steps = 2000\n    lr_warmup_style = \"linear\"\n    min_decay_lr = 0.0\n    adam_beta1 = 0.9\n    adam_beta2 = 0.95\n    adam_eps = 1.0e-08\n    weight_decay = 0.01\n    zero_stage = 0\n    name = \"adamW\"\n    torch_adam_is_fused = True","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:40.542357Z","iopub.execute_input":"2025-02-01T14:57:40.542665Z","iopub.status.idle":"2025-02-01T14:57:40.547087Z","shell.execute_reply.started":"2025-02-01T14:57:40.542636Z","shell.execute_reply":"2025-02-01T14:57:40.546330Z"}},"outputs":[],"execution_count":10},{"cell_type":"code","source":"import logging\nfrom transformers import AutoTokenizer\nimport torch\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import Dataset\n\nif __name__ == \"__main__\":\n    logging.basicConfig(filename='/kaggle/working/training_log.txt', level=logging.INFO, \n                        format='%(asctime)s - %(levelname)s - %(message)s', force=True)\n    # Device setup\n    device = 'cpu'\n    if torch.cuda.is_available():\n        device = 'cuda'\n    elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n        device = \"mps\"\n    print(f\"Using device: {device}\")\n\n    torch.set_float32_matmul_precision('high')\n    \n    # Seed setup\n    torch.manual_seed(1337)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(1337)\n    \n    # Model initialization\n    model = GPT(GPTConfig())\n    model.to(device)\n    #model = torch.compile(model)\n\n    # Load checkpoint if exists\n    best_model_path = '/kaggle/working/best_model.pth'\n    checkpoint_model_path = '/kaggle/working/checkpoint_model.pth'\n    start_epoch = 0\n    start_step = 0\n    best_loss = float('inf')\n    \n    if os.path.exists(checkpoint_model_path):\n        model_checkpoint = torch.load(checkpoint_model_path, map_location=device, weights_only=True)\n        model.load_state_dict(model_checkpoint['model_state_dict'])\n        start_epoch = model_checkpoint['epoch']\n        start_step = model_checkpoint['step']+1\n        best_loss = model_checkpoint['loss']\n        logging.info(f\"Resuming from epoch {start_epoch}, step {start_step}, best loss {best_loss:.6f}\")\n        \n    # Model parameter count\n    total_params = sum(p.numel() for p in model.parameters())\n    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    logging.info(f\"Total Parameters: {total_params:,}\")\n    logging.info(f\"Trainable Parameters: {trainable_params:,}\")\n\n    # Load tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/cosmo2-tokenizer\")\n    tokenizer.pad_token = tokenizer.eos_token\n    \n    # Load streaming dataset\n    dataset = load_dataset(\n        \"HuggingFaceTB/smollm-corpus\",\n        \"cosmopedia-v2\",\n        streaming=True\n    )['train']  # 
Access only the \"train\" split\n    \n    # Define the encode function\n    def encode(examples):\n        # Tokenize the text\n        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=2048,return_tensors=None)\n\n    # Stream mapping\n    dataset = dataset.map(encode, batched=True,remove_columns=dataset.column_names)\n\n    def collate_fn(batch):\n        input_ids = torch.tensor([example['input_ids'] for example in batch], dtype=torch.long)\n        attention_mask = torch.tensor([example['attention_mask'] for example in batch], dtype=torch.long)\n    \n        return {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n\n    from torch.utils.data import DataLoader, IterableDataset\n    train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)\n\n    # Optimizer setup\n    optimizer_config = OptimizerConfig()\n    optimizer = torch.optim.AdamW(\n        model.parameters(),\n        betas=(optimizer_config.adam_beta1, optimizer_config.adam_beta2),\n        eps=optimizer_config.adam_eps,\n        weight_decay=optimizer_config.weight_decay\n    )\n\n    # Training loop\n    target_loss = 0.099999\n    max_iterations = 6000\n    optimizer.zero_grad()\n\n    scaler = torch.GradScaler()  # ✅ Use AMP GradScaler\n    autocast_device = \"cuda\" if \"cuda\" in device else \"cpu\"  # ✅ Ensure valid autocast device\n\n    \n    if os.path.exists(checkpoint_model_path):\n        optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])\n        scaler.load_state_dict(model_checkpoint['scaler_state_dict'])\n    \n    sample_text = \"Once upon a time\"  # Text for tracking improvements\n\n    sample_tokens = tokenizer(sample_text, return_tensors='pt').input_ids.to(device)\n    #sample_tokens = torch.tensor(sample_tokens).unsqueeze(0)  # Add batch dimension\n    \n    \n    for epoch in range(start_epoch, 100):\n        for i, batch in enumerate(train_loader, start=start_step):\n            x = batch[\"input_ids\"].to(device)\n            attention_mask = batch[\"attention_mask\"].to(device)\n            # PROPER TARGET SETUP\n            y = torch.cat([x.clone()[:, 1:], torch.full((x.size(0), 1), tokenizer.eos_token_id, device=device)], dim=1)\n\n\n            with torch.autocast(device_type=device, dtype=torch.bfloat16):\n                  logits = model(x, attention_mask=attention_mask)\n                  loss = F.cross_entropy(\n                      logits.view(-1, logits.size(-1)),\n                      y.view(-1),\n                      ignore_index=tokenizer.eos_token_id  # Exclude padding\n                  )\n\n            scaler.scale(loss).backward()  # ✅ Apply scaled gradient\n    \n            # Gradient accumulation (effective batch size = 4)\n            if (i+1) % 16 == 0:  # ✅ Ensure last batch updates\n                scaler.step(optimizer)\n                scaler.update()\n                optimizer.zero_grad()\n                \n            # Save best model\n            if loss.item() < best_loss:\n                best_loss = loss.item()\n                torch.save({\n                    'epoch': epoch,\n                    'step': i,\n                    'model_state_dict': model.state_dict(),\n                    'optimizer_state_dict': optimizer.state_dict(),\n                    'scaler_state_dict': scaler.state_dict(),\n                    'loss': best_loss,\n                }, best_model_path)\n                \n\n            logging.info(f\"Epoch {epoch}, Step {i}, Loss: {loss.item():.6f}, Best Loss: 
{best_loss:.6f}\")\n\n            # Perform prediction every 500 steps\n            if (i + 1) % 500 == 0:\n                model.eval()\n                with torch.no_grad():\n            \n                    generated_tokens = model.generate(sample_tokens, max_length=50,eos_token_id = tokenizer.eos_token_id)\n                    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)\n            \n                    logging.info(f\"Step {i + 1} Prompt: {sample_text} \\n Generated Token: {generated_tokens} \\n Prediction: {generated_text}\")\n            \n                model.train()\n                \n            if loss.item() <= target_loss:\n                logging.info(f\"Target loss reached at step {i}. Training completed!\")\n                break\n\n            if i >= max_iterations:\n                torch.save({\n                    'epoch': epoch,\n                    'step': i,\n                    'model_state_dict': model.state_dict(),\n                    'optimizer_state_dict': optimizer.state_dict(),\n                    'scaler_state_dict': scaler.state_dict(),\n                    'loss': best_loss,\n                }, checkpoint_model_path)\n                logging.info(\"Max iterations reached. Training stopped.\")\n                break\n\n        else:\n            continue\n        break\n\n    logging.info(\"Training completed!\")\n    logging.info(f\"Final Loss: {loss.item():.6f}\")\n    logging.info(f\"Best Loss Achieved: {best_loss:.6f}\")\n    logging.info(f\"Best Model Saved To: {best_model_path}\")\n    logging.info(f\"Checpoint Model Saved To: {checkpoint_model_path}\")\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:58:20.220862Z","iopub.execute_input":"2025-02-01T14:58:20.221186Z","iopub.status.idle":"2025-02-01T17:06:18.942728Z","shell.execute_reply.started":"2025-02-01T14:58:20.221164Z","shell.execute_reply":"2025-02-01T17:06:18.941276Z"}},"outputs":[{"name":"stdout","text":"Using device: cuda\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json:   0%|          | 0.00/3.91k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e20cf1e86bb9459e8140176c7d2ac7c5"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"c0f53ea2c8734454a3ac8e95d3cbc2bc"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"745e47221b6e44e1b397b97c7baa90f9"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer.json:   0%|          | 0.00/2.10M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"3e3bc1425f6e423988769200b9676bff"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"special_tokens_map.json:   0%|          | 0.00/489 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"67ed9fcf5ec24d109415ee701d326093"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"README.md:   0%|          | 0.00/7.05k [00:00<?, 
?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9403297bfa9b42cc8d42883044d1ed6f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"c6888f15fb7148b3bb57024e822289d9"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"97c5eeb430e740dc9a128043f79e65be"}},"metadata":{}}],"execution_count":11},{"cell_type":"code","source":"del model  # If you no longer need the model\ntorch.cuda.empty_cache()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:52:31.509162Z","iopub.execute_input":"2025-02-01T14:52:31.509508Z","iopub.status.idle":"2025-02-01T14:52:31.513649Z","shell.execute_reply.started":"2025-02-01T14:52:31.509478Z","shell.execute_reply":"2025-02-01T14:52:31.512773Z"}},"outputs":[],"execution_count":13},{"cell_type":"code","source":"torch.cuda.reset_max_memory_allocated()\ntorch.cuda.reset_max_memory_cached()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:53:57.801944Z","iopub.execute_input":"2025-02-01T14:53:57.802238Z","iopub.status.idle":"2025-02-01T14:53:57.807199Z","shell.execute_reply.started":"2025-02-01T14:53:57.802218Z","shell.execute_reply":"2025-02-01T14:53:57.806479Z"}},"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n  warnings.warn(\n/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:391: FutureWarning: torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n  warnings.warn(\n","output_type":"stream"}],"execution_count":15},{"cell_type":"code","source":"torch.cuda.memory_stats(device=None)  # Get current memory stats\ntorch.cuda.reset_peak_memory_stats()  # Reset memory tracking stats","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:54:40.435595Z","iopub.execute_input":"2025-02-01T14:54:40.435913Z","iopub.status.idle":"2025-02-01T14:54:40.440121Z","shell.execute_reply.started":"2025-02-01T14:54:40.435885Z","shell.execute_reply":"2025-02-01T14:54:40.439251Z"}},"outputs":[],"execution_count":17},{"cell_type":"code","source":"import torch\n\n# Check allocated memory\nprint(f\"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB\")\n\n# Check reserved (cached) memory\nprint(f\"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB\")\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:54:44.037499Z","iopub.execute_input":"2025-02-01T14:54:44.037869Z","iopub.status.idle":"2025-02-01T14:54:44.042882Z","shell.execute_reply.started":"2025-02-01T14:54:44.037840Z","shell.execute_reply":"2025-02-01T14:54:44.041975Z"}},"outputs":[{"name":"stdout","text":"Allocated: 11621.08 MB\nReserved: 14858.00 
MB\n","output_type":"stream"}],"execution_count":18},{"cell_type":"code","source":"model","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:27:44.142744Z","iopub.execute_input":"2025-02-01T17:27:44.143057Z","iopub.status.idle":"2025-02-01T17:27:44.149775Z","shell.execute_reply.started":"2025-02-01T17:27:44.143033Z","shell.execute_reply":"2025-02-01T17:27:44.149094Z"}},"outputs":[{"execution_count":12,"output_type":"execute_result","data":{"text/plain":"GPT(\n  (token_embedding): Embedding(49152, 576)\n  (layers): ModuleList(\n    (0-29): 30 x LlamaDecoderLayer(\n      (self_attn): CausalSelfAttention(\n        (cq_attn): Linear(in_features=576, out_features=576, bias=False)\n        (ckv_attn): Linear(in_features=576, out_features=384, bias=False)\n        (c_proj): Linear(in_features=576, out_features=576, bias=False)\n        (rope): RotaryPositionalEmbeddings()\n      )\n      (input_layernorm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n      (post_attention_layernorm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n      (mlp): LlamaMLP(\n        (gate_proj): Linear(in_features=576, out_features=1536, bias=False)\n        (up_proj): Linear(in_features=576, out_features=1536, bias=False)\n        (down_proj): Linear(in_features=1536, out_features=576, bias=False)\n        (act_fn): SiLU()\n      )\n    )\n  )\n  (final_norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n  (lm_head): Linear(in_features=576, out_features=49152, bias=False)\n)"},"metadata":{}}],"execution_count":12},{"cell_type":"code","source":"model.config","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:28:00.622591Z","iopub.execute_input":"2025-02-01T17:28:00.622987Z","iopub.status.idle":"2025-02-01T17:28:00.628489Z","shell.execute_reply.started":"2025-02-01T17:28:00.622953Z","shell.execute_reply":"2025-02-01T17:28:00.627499Z"}},"outputs":[{"execution_count":13,"output_type":"execute_result","data":{"text/plain":"GPTConfig(block_size=2048, vocab_size=49152, n_layer=30, n_head=9, n_embd=576, num_key_value_heads=3)"},"metadata":{}}],"execution_count":13},{"cell_type":"code","source":"from torchinfo import summary\n\nsummary(model, input_size=(1, 2048),dtypes=[torch.long],)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:33:18.647477Z","iopub.execute_input":"2025-02-01T17:33:18.647826Z","iopub.status.idle":"2025-02-01T17:33:18.764100Z","shell.execute_reply.started":"2025-02-01T17:33:18.647799Z","shell.execute_reply":"2025-02-01T17:33:18.763376Z"}},"outputs":[{"execution_count":18,"output_type":"execute_result","data":{"text/plain":"=========================================================================================================\nLayer (type:depth-idx)                                  Output Shape              Param #\n=========================================================================================================\nGPT                                                     [1, 2048, 49152]          --\n├─Embedding: 1-1                                        [1, 2048, 576]            28,311,552\n├─ModuleList: 1-2                                       --                        --\n│    └─LlamaDecoderLayer: 2-1                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-1                                [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-2                    [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-3                                [1, 2048, 
576]            576\n│    │    └─LlamaMLP: 3-4                               [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-2                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-5                                [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-6                    [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-7                                [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-8                               [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-3                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-9                                [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-10                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-11                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-12                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-4                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-13                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-14                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-15                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-16                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-5                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-17                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-18                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-19                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-20                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-6                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-21                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-22                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-23                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-24                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-7                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-25                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-26                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-27                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-28                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-8                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-29                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-30                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-31                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-32                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-9                           [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-33                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-34                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-35                               [1, 2048, 
576]            576\n│    │    └─LlamaMLP: 3-36                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-10                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-37                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-38                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-39                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-40                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-11                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-41                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-42                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-43                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-44                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-12                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-45                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-46                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-47                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-48                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-13                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-49                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-50                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-51                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-52                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-14                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-53                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-54                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-55                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-56                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-15                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-57                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-58                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-59                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-60                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-16                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-61                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-62                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-63                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-64                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-17                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-65                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-66                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-67                               [1, 2048, 
576]            576\n│    │    └─LlamaMLP: 3-68                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-18                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-69                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-70                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-71                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-72                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-19                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-73                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-74                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-75                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-76                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-20                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-77                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-78                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-79                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-80                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-21                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-81                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-82                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-83                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-84                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-22                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-85                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-86                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-87                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-88                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-23                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-89                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-90                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-91                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-92                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-24                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-93                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-94                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-95                               [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-96                              [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-25                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-97                               [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-98                   [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-99                               [1, 2048, 
576]            576\n│    │    └─LlamaMLP: 3-100                             [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-26                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-101                              [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-102                  [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-103                              [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-104                             [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-27                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-105                              [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-106                  [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-107                              [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-108                             [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-28                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-109                              [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-110                  [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-111                              [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-112                             [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-29                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-113                              [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-114                  [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-115                              [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-116                             [1, 2048, 576]            2,654,208\n│    └─LlamaDecoderLayer: 2-30                          [1, 2048, 576]            --\n│    │    └─RMSNorm: 3-117                              [1, 2048, 576]            576\n│    │    └─CausalSelfAttention: 3-118                  [1, 2048, 576]            884,736\n│    │    └─RMSNorm: 3-119                              [1, 2048, 576]            576\n│    │    └─LlamaMLP: 3-120                             [1, 2048, 576]            2,654,208\n├─RMSNorm: 1-3                                          [1, 2048, 576]            576\n├─Linear: 1-4                                           [1, 2048, 49152]          28,311,552\n=========================================================================================================\nTotal params: 162,826,560\nTrainable params: 162,826,560\nNon-trainable params: 0\nTotal mult-adds (M): 162.83\n=========================================================================================================\nInput size (MB): 0.02\nForward/backward pass size (MB): 3938.45\nParams size (MB): 651.31\nEstimated Total Size (MB): 4589.77\n========================================================================================================="},"metadata":{}}],"execution_count":18},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}