# llama_diffusion_model.py
import torch
import torch.nn as nn
from torch.amp import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from peft import LoraConfig, get_peft_model
import os
from typing import Optional, Tuple
hf_token = os.getenv("HF_TOKEN")
class BidirectionalLlamaAttention(LlamaAttention):
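    """LlamaAttention variant that reuses the pretrained projection weights but
    leaves mask construction to the caller, so the same layer can run with
    bidirectional, bidirectional-masked, or unidirectional attention."""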
def __init__(self, original_layer, masking='unidirectional'):
super().__init__(original_layer.config, layer_idx=original_layer.layer_idx)
self.masking = masking
self.q_proj.weight = original_layer.q_proj.weight
self.k_proj.weight = original_layer.k_proj.weight
self.v_proj.weight = original_layer.v_proj.weight
self.o_proj.weight = original_layer.o_proj.weight
def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
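        """Expand grouped key/value heads so each query head has a matching KV head (GQA)."""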
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(self, module: nn.Module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
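        """Plain (eager) scaled dot-product attention with an optional additive mask."""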
key_states = self.repeat_kv(key, module.num_key_value_groups)
value_states = self.repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
        if attention_mask is not None:
            # Convert a boolean mask (True = may attend) into an additive float mask:
            # 0.0 where attention is allowed, a large negative value where it is blocked.
            if attention_mask.dtype == torch.bool:
                attn_mask = torch.zeros_like(attention_mask, dtype=query.dtype).masked_fill(~attention_mask, torch.finfo(query.dtype).min)
            else:
                attn_mask = attention_mask.to(query.dtype)
            attn_weights = attn_weights + attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
return attn_output, attn_weights
def rotate_half(self, x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
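        """Apply rotary position embeddings (RoPE) to the query and key states."""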
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed
def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
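        """Project to Q/K/V, apply RoPE, update the KV cache if present, and run
        eager attention; the attention_mask passed in decides directionality."""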
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_output, attn_weights = self.eager_attention_forward(
self, query_states, key_states, value_states, attention_mask,
dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return self.o_proj(attn_output), attn_weights
class CustomTransformerConfig(PretrainedConfig):
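    """Configuration for the diffusion-style LLaMA wrapper. `prediction_chunk` fixes
    the sequence length processed per forward pass, and `masking_type` selects
    bidirectional, bidirectional_masked, or unidirectional attention."""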
def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
max_position_embeddings=4096, masking_type="bidirectional_masked", **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.dropout = dropout
self.prediction_chunk = prediction_chunk
self.max_position_embeddings = max_position_embeddings
self.input_size = prediction_chunk
self.masking_type = masking_type
class CustomTransformerModel(PreTrainedModel):
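    """Wraps a frozen Llama-3.2-3B (float16) whose self-attention layers are swapped
    for BidirectionalLlamaAttention, with a trainable lm_head and LoRA adapters on
    the q/k/v/o projections."""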
config_class = CustomTransformerConfig
def __init__(self, config):
super().__init__(config)
self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
self.llama.resize_token_embeddings(config.vocab_size)
for i, layer in enumerate(self.llama.model.layers):
layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)
for param in self.llama.parameters():
param.requires_grad = False
for param in self.llama.lm_head.parameters():
param.requires_grad = True
lora_config = LoraConfig(
r=512, lora_alpha=512, lora_dropout=0.0,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
bias="none", task_type=None
)
self.llama = get_peft_model(self.llama, lora_config)
self.llama.print_trainable_parameters()
self.llama = self.llama.to(torch.float16)
def forward(self, input_ids, labels=None, **kwargs):
batch_size, seq_len = input_ids.shape
assert seq_len == self.config.prediction_chunk, f"Expected input length {self.config.prediction_chunk}, got {seq_len}"
# Build attention mask
device = input_ids.device
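        # True marks positions that may be attended to:
        #   bidirectional        -> every token attends to every token
        #   bidirectional_masked -> every token attends to every token except itself
        #   unidirectional       -> standard causal (lower-triangular) mask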
if self.config.masking_type == 'bidirectional':
base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
elif self.config.masking_type == 'bidirectional_masked':
base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
base_mask.fill_diagonal_(False)
elif self.config.masking_type == 'unidirectional':
base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
else:
raise ValueError(f"Unknown masking type: {self.config.masking_type}")
attention_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()
with autocast("cuda", dtype=torch.float16):
outputs = self.llama(
input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs
)
logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, seq_len, self.config.vocab_size)
loss = None
if labels is not None:
assert labels.shape == (batch_size, seq_len), f"Labels shape mismatch: expected ({batch_size}, {seq_len}), got {labels.shape}"
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
def disable_dropout(model):
    # Replace every nn.Dropout with nn.Identity. named_modules() yields dotted
    # paths (e.g. "llama.model.layers.0.self_attn.dropout"), so resolve the
    # parent module first instead of calling setattr with the dotted name.
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, nn.Identity())
    return model
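
# Minimal usage sketch (not part of the original training code): builds the config,
# instantiates the model, and runs one forward pass on dummy token ids. It assumes
# an HF_TOKEN with access to meta-llama/Llama-3.2-3B and a CUDA device, since the
# forward pass autocasts to float16 on "cuda".
if __name__ == "__main__":
    config = CustomTransformerConfig(masking_type="bidirectional_masked")
    model = CustomTransformerModel(config)
    model = disable_dropout(model)

    # Dummy batch of one sequence of length prediction_chunk (256 by default).
    dummy_ids = torch.randint(0, config.vocab_size, (1, config.prediction_chunk), device="cuda")
    with torch.no_grad():
        out = model(dummy_ids, labels=dummy_ids)
    print(out["loss"], out["logits"].shape)  # logits: (1, prediction_chunk, vocab_size)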