File size: 3,927 Bytes
781bf2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pytorch_lightning as pl
import torch
from peft import LoraConfig, get_peft_model
from torch import nn as nn
from torchmetrics import Accuracy
from transformers import AutoTokenizer, AutoModelForCausalLM


# Hugging Face Hub identifier of the pretrained causal LM used as the backbone.
base_checkpoint = "HuggingFaceTB/SmolLM2-360M"

# Use Apple's Metal (MPS) backend when PyTorch reports it; otherwise use the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Binary cross-entropy computed directly on raw (pre-sigmoid) logits.
criterion = nn.BCEWithLogitsLoss()


class SmolLM(pl.LightningModule):
    """Binary sequence classifier on top of a LoRA/DoRA-adapted causal LM.

    The backbone's ``lm_head`` is replaced by ``nn.Identity`` so that the
    model's ``logits`` output actually carries the final hidden states. The
    hidden state of the last non-padding token of each sequence is fed to a
    small MLP that emits one logit per sequence.

    Args:
        learning_rate: Learning rate handed to AdamW in
            ``configure_optimizers``.
    """

    def __init__(self, learning_rate=3e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.criterion = criterion
        self.tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # forward() locates the last real token as ``mask.sum() - 1``, which is
        # only valid when padding is appended on the right.
        self.tokenizer.padding_side = "right"
        # Device placement is left to the Trainer / caller: moving only the
        # backbone here would leave the classifier head on a different device.
        self.base_model = AutoModelForCausalLM.from_pretrained(base_checkpoint)
        # With the LM head removed, ``out.logits`` is the last hidden state.
        self.base_model.lm_head = nn.Identity()
        # Read the width from the config instead of hard-coding it so the head
        # keeps matching the backbone if ``base_checkpoint`` changes.
        hidden_size = self.base_model.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        # Freeze the backbone parameters; only adapters and the head train.
        for param in self.base_model.parameters():
            param.requires_grad = False
        # LoRA fine-tuning (DoRA variant) on attention and MLP projections.
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
            lora_dropout=0.0,
            bias="none",
            use_dora=True
        )
        self.base_model = get_peft_model(self.base_model, lora_config)
        self.base_model.print_trainable_parameters()
        self.save_hyperparameters()
        self.val_accuracy = Accuracy(task="binary")

    def forward(self, x):
        """Return one classification logit per sequence.

        Args:
            x: Mapping with ``"input_ids"`` and ``"attention_mask"`` tensors,
                right-padded, as produced by ``self.tokenizer``.

        Returns:
            Tensor with the trailing singleton dimension squeezed away, i.e.
            one raw (pre-sigmoid) logit per input sequence.
        """
        input_ids = x["input_ids"]
        attention_mask = x["attention_mask"]

        # ``lm_head`` is Identity, so ``logits`` holds per-token hidden states.
        out = self.base_model(input_ids, attention_mask=attention_mask)
        hidden = out.logits

        # Index of the last non-padding token in every (right-padded) sequence.
        last_token_indices = attention_mask.sum(dim=1) - 1
        # Build the row index on the same device as the tensor being indexed
        # rather than on a module-level default device.
        batch_indices = torch.arange(hidden.size(0), device=hidden.device)
        last_hidden = hidden[batch_indices, last_token_indices, :]

        return self.classifier(last_hidden).squeeze(-1)

    def _shared_step(self, batch):
        """Tokenize ``batch`` and return ``(logits, labels, loss)``."""
        inputs = self.tokenizer(
            batch["sentence"], return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        logits = self(inputs)
        # BCEWithLogitsLoss needs floating-point targets on the logits' device.
        labels = batch["eos_label"].to(device=logits.device, dtype=logits.dtype)
        loss = self.criterion(logits, labels)
        return logits, labels, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, _, loss = self._shared_step(batch)
        self.log('Train Step Loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and accumulate binary accuracy."""
        logits, labels, loss = self._shared_step(batch)
        preds = (torch.sigmoid(logits) > 0.5).long()
        self.val_accuracy.update(preds, labels.long())
        self.log('Validation Step Loss', loss, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        """Log the accuracy accumulated over the epoch, then reset the metric."""
        acc = self.val_accuracy.compute()
        self.log('Validation Accuracy', acc, prog_bar=True)
        self.val_accuracy.reset()

    def configure_optimizers(self):
        """Return an AdamW optimizer over the trainable parameters only."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.learning_rate)