import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.amp import autocast
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaAttention



hf_token = os.getenv("HF_TOKEN")

class BidirectionalLlamaAttention(LlamaAttention):
    def __init__(self, original_layer, masking='unidirectional'):
        super().__init__(original_layer.config, layer_idx=original_layer.layer_idx)
        self.masking = masking

        # Copy weights from original layer
        self.q_proj.weight = original_layer.q_proj.weight
        self.k_proj.weight = original_layer.k_proj.weight
        self.v_proj.weight = original_layer.v_proj.weight
        self.o_proj.weight = original_layer.o_proj.weight

    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
        """
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)

        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    def eager_attention_forward(self,
        module: nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        scaling: float,
        dropout: float = 0.0,
        **kwargs,
        ):
        key_states = self.repeat_kv(key, module.num_key_value_groups)
        value_states = self.repeat_kv(value, module.num_key_value_groups)

        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()

        return attn_output, attn_weights

    def rotate_half(self, x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]

        return torch.cat((-x2, x1), dim=-1)


    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`, *optional*):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)

        return q_embed, k_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Apply rotary embeddings
        cos, sin = position_embeddings
        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # 🔄 **Modify the Attention Mask**
        seq_len = hidden_states.shape[1]
        batch_size = hidden_states.shape[0]
        if self.masking == 'bidirectional':
            # Every position may attend to every other position
            base_mask = torch.ones((seq_len, seq_len), device=hidden_states.device, dtype=torch.bool)
        elif self.masking == 'bidirectional_masked':
            # Bidirectional, except each position is blocked from attending to the next position
            base_mask = torch.ones((seq_len, seq_len), device=hidden_states.device, dtype=torch.bool)
            base_mask[:, 1:].fill_diagonal_(False)  # ✅ Apply diagonal masking only in 2D
        else:  # unidirectional
            # 🚀 Standard autoregressive (causal) mask
            base_mask = torch.tril(torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool))

        # Expand to (batch, 1, seq_len, seq_len) and convert the boolean mask to the additive
        # form expected by eager_attention_forward: 0 where attention is allowed, -inf elsewhere
        bool_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len)
        attn_mask = torch.zeros_like(bool_mask, dtype=hidden_states.dtype).masked_fill(
            ~bool_mask, torch.finfo(hidden_states.dtype).min
        )


        # Call the default attention function
        attn_output, attn_weights = self.eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attn_mask,  # ✅ Custom mask is applied here
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights


    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)
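
# Illustrative sketch (not used by the classes in this file): builds the boolean attention
# mask for each `masking` mode in isolation, mirroring the logic in
# BidirectionalLlamaAttention.forward, so the three patterns can be inspected directly.
def demo_attention_mask(seq_len: int, masking: str = 'unidirectional') -> torch.Tensor:
    if masking == 'bidirectional':
        # Every position may attend to every other position
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
    elif masking == 'bidirectional_masked':
        # Bidirectional, except each position is blocked from attending to the next position
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
        mask[:, 1:].fill_diagonal_(False)
    else:  # unidirectional (causal)
        mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask
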

class CustomTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0, max_position_embeddings=4096, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.prediction_chunk = prediction_chunk
        self.max_position_embeddings = max_position_embeddings
        self.input_size = prediction_chunk

class CustomTransformerModel(PreTrainedModel):
    config_class = CustomTransformerConfig

    def __init__(self, config):
        super().__init__(config)

        # Load the pre-trained Llama model (its lm_head is reused and fine-tuned below)
        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)

        self.llama.resize_token_embeddings(config.vocab_size)

        for i, layer in enumerate(self.llama.model.layers):
            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='bidirectional')

        # Freeze Llama to retain pre-trained knowledge
        for param in self.llama.parameters():
            param.requires_grad = False

        for param in self.llama.lm_head.parameters():
            param.requires_grad = True

        lora_config = LoraConfig(
            r=256,
            lora_alpha=256,
            lora_dropout=0.0,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Llama-3 uses these attention modules
            bias="none",
            task_type=None
        )

        self.llama = get_peft_model(self.llama, lora_config)
        self.llama.print_trainable_parameters()  # Print number of trainable parameters
        self.llama = self.llama.to(torch.float16)

    def forward(self, input_ids, labels=None, **kwargs):
        batch_size, seq_length = input_ids.shape
        assert seq_length == self.config.input_size, f"Expected input length {self.config.input_size}, got {seq_length}"

        with autocast("cuda", dtype=torch.float16):  # ✅ Correct future-proof usage


            outputs = self.llama(input_ids, output_hidden_states=True, **kwargs)

            logits = outputs.logits[:, :, :self.config.vocab_size]

            # Reshape logits to (batch, input_size, vocab_size)
            logits = logits.view(batch_size, self.config.prediction_chunk, self.config.vocab_size)

            loss = None

        if labels is not None:
            assert labels.shape == (batch_size, self.config.input_size), f"Labels shape mismatch: expected ({batch_size}, {self.config.input_size}), got {labels.shape}"

            # Compute loss
            loss_fct = torch.nn.CrossEntropyLoss()

            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))


        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}


def disable_dropout(model):
    """Replace every nn.Dropout in the model with nn.Identity (no-op)."""
    dropout_names = [name for name, module in model.named_modules() if isinstance(module, nn.Dropout)]
    for name in dropout_names:
        # setattr on the root model with a dotted name does not reach nested submodules,
        # so resolve the parent module and replace the child there
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, nn.Identity())
    return model
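

# A minimal usage sketch. Assumptions not stated elsewhere in this file: a CUDA device,
# an HF_TOKEN environment variable with access to meta-llama/Llama-3.2-3B, and inputs
# already tokenised/padded to exactly `prediction_chunk` tokens. The dummy tensors below
# are illustrative placeholders, not part of the module above.
if __name__ == "__main__":
    config = CustomTransformerConfig(prediction_chunk=256)
    model = CustomTransformerModel(config)
    model = disable_dropout(model)

    # Random token ids standing in for a real tokenised batch of length prediction_chunk
    dummy_input = torch.randint(0, config.vocab_size, (1, config.prediction_chunk), device=model.device)
    dummy_labels = torch.randint(0, config.vocab_size, (1, config.prediction_chunk), device=model.device)

    outputs = model(dummy_input, labels=dummy_labels)
    print(outputs["loss"], outputs["logits"].shape)  # expect logits of shape (1, 256, vocab_size)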