import torch
import torch.nn as nn
from torch.amp import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from peft import LoraConfig, get_peft_model
import os
from typing import Optional, Tuple

hf_token = os.getenv("HF_TOKEN")

class CustomTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
                 max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.prediction_chunk = prediction_chunk
        self.max_position_embeddings = max_position_embeddings
        self.input_size = prediction_chunk
        self.masking_type = masking_type

class CustomTransformerModel(PreTrainedModel):
    config_class = CustomTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
        self.llama.resize_token_embeddings(config.vocab_size)

        for param in self.llama.parameters():
            param.requires_grad = False
        for param in self.llama.lm_head.parameters():
            param.requires_grad = True

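        # High-rank LoRA adapters on the attention projections; together with
        # lm_head these are the only trainable parameters.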
        lora_config = LoraConfig(
            r=512, lora_alpha=512, lora_dropout=0.0,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none", task_type=None
        )

        self.llama = get_peft_model(self.llama, lora_config)
        self.llama.print_trainable_parameters()

    def forward(self, input_ids, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        assert seq_len == self.config.prediction_chunk, f"Expected input length {self.config.prediction_chunk}, got {seq_len}"

        # Build attention mask
        device = input_ids.device

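        # Supported masking schemes:
        #   bidirectional        - every position attends to every position
        #   bidirectional_masked - full attention except a token's own position
        #   unidirectional       - standard causal (lower-triangular) attention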
        masking_type = getattr(self.config, "masking_type", "bidirectional")
        if masking_type == 'bidirectional':
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        elif masking_type == 'bidirectional_masked':
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
            base_mask.fill_diagonal_(False)
        elif masking_type == 'unidirectional':
            base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        else:
            raise ValueError(f"Unknown masking type: {self.config.masking_type}")

        attention_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()
        # Custom 4D masks are treated as additive attention biases, so use 0.0
        # where attention is allowed and the dtype minimum where it is blocked;
        # float16 matches the autocast compute dtype below.
        attention_mask = torch.zeros_like(attention_mask, dtype=torch.float16).masked_fill(
            ~attention_mask, torch.finfo(torch.float16).min
        )

        with autocast("cuda", dtype=torch.float16):
            outputs = self.llama(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
                **kwargs
            )

        logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, seq_len, self.config.vocab_size)

        loss = None
        if labels is not None:
            assert labels.shape == (batch_size, seq_len), f"Labels shape mismatch: expected ({batch_size}, {seq_len}), got {labels.shape}"
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

def disable_dropout(model):
    # Replace every nn.Dropout with nn.Identity. named_modules() yields dotted
    # paths, so resolve the parent module first; setattr on the root model with
    # a dotted name would not swap the nested submodule.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Dropout):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, nn.Identity())
    return model
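
# Minimal usage sketch, assuming HF_TOKEN grants access to meta-llama/Llama-3.2-3B
# and a CUDA device is available for the float16/autocast path. The prediction
# chunk length and masking type below are illustrative, not prescribed.
if __name__ == "__main__":
    config = CustomTransformerConfig(prediction_chunk=256, masking_type="bidirectional_masked")
    model = CustomTransformerModel(config)
    model = disable_dropout(model)

    # Dummy batch: one sequence of prediction_chunk random token ids, reused as labels.
    input_ids = torch.randint(0, config.vocab_size, (1, config.prediction_chunk), device=model.device)
    outputs = model(input_ids, labels=input_ids.clone())
    print(outputs["loss"].item(), outputs["logits"].shape)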