import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.amp import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from peft import LoraConfig, get_peft_model

# Hugging Face access token, needed to download the gated meta-llama/Llama-3.2-3B weights.
hf_token = os.getenv("HF_TOKEN")


class CustomTransformerConfig(PretrainedConfig):
    """Configuration for the Llama-based custom transformer defined below."""

    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
                 max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.prediction_chunk = prediction_chunk
        self.max_position_embeddings = max_position_embeddings
        # The model always consumes fixed-length chunks, so the input size equals the prediction chunk.
        self.input_size = prediction_chunk
        self.masking_type = masking_type


class CustomTransformerModel(PreTrainedModel):
    config_class = CustomTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # Load the pretrained Llama backbone in fp16.
        self.llama = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token
        )
        self.llama.resize_token_embeddings(config.vocab_size)

        # Freeze the backbone; only the LM head and the LoRA adapters below are trained.
        for param in self.llama.parameters():
            param.requires_grad = False
        for param in self.llama.lm_head.parameters():
            param.requires_grad = True

        lora_config = LoraConfig(
            r=512, lora_alpha=512, lora_dropout=0.0,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none", task_type=None
        )
        self.llama = get_peft_model(self.llama, lora_config)
        self.llama.print_trainable_parameters()

    def forward(self, input_ids, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        assert seq_len == self.config.prediction_chunk, f"Expected input length {self.config.prediction_chunk}, got {seq_len}"

        # Build the attention pattern: True = attend, False = blocked.
        device = input_ids.device
        masking_type = getattr(self.config, "masking_type", "bidirectional")
        if masking_type == 'bidirectional':
            # Every token attends to every other token.
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        elif masking_type == 'bidirectional_masked':
            # Bidirectional, except a token cannot attend to itself.
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
            base_mask.fill_diagonal_(False)
        elif masking_type == 'unidirectional':
            # Standard causal (lower-triangular) mask.
            base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        else:
            raise ValueError(f"Unknown masking type: {masking_type}")

        # Expand to the 4-D shape (batch, 1, seq_len, seq_len) and convert to the additive form
        # expected for a custom 4-D mask: 0.0 where attention is allowed, a large negative value
        # where it is blocked. The dtype matches the fp16 backbone so SDPA / Flash attention accept it.
        attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.float16, device=device)
        attention_mask = attention_mask.masked_fill(~base_mask.unsqueeze(0).unsqueeze(1), torch.finfo(torch.float16).min)
        # Run the frozen backbone (plus LoRA adapters) under fp16 autocast; caching is disabled
        # because full fixed-length chunks are always processed in one pass.
        with autocast("cuda", dtype=torch.float16):
            outputs = self.llama(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
                **kwargs
            )

        logits = outputs.logits[:, :, :self.config.vocab_size].reshape(batch_size, seq_len, self.config.vocab_size)

        loss = None
        if labels is not None:
            assert labels.shape == (batch_size, seq_len), f"Labels shape mismatch: expected ({batch_size}, {seq_len}), got {labels.shape}"
            # Token-level cross-entropy over the whole chunk.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.config.vocab_size), labels.reshape(-1))
        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

def disable_dropout(model):
    """Replace every nn.Dropout in the model with nn.Identity (disables dropout entirely)."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            # named_modules() yields dotted paths, so the replacement must be set
            # on the immediate parent module rather than on the top-level model.
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, nn.Identity())
    return model
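

# Minimal usage sketch, assuming a CUDA device, a valid HF_TOKEN with access to
# meta-llama/Llama-3.2-3B, and enough GPU memory for the 3B backbone. The batch size
# and masking_type below are illustrative choices, not values prescribed by this module.
if __name__ == "__main__":
    config = CustomTransformerConfig(masking_type="bidirectional")
    model = CustomTransformerModel(config)
    model = disable_dropout(model)

    # Dummy fixed-length chunk of token ids: shape (batch, prediction_chunk).
    input_ids = torch.randint(0, config.vocab_size, (2, config.prediction_chunk), device=model.device)
    labels = torch.randint(0, config.vocab_size, (2, config.prediction_chunk), device=model.device)

    out = model(input_ids, labels=labels)
    print(out["loss"], out["logits"].shape)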