# inner_lexicon/calibration_utils.py
import os
import torch
from torch import nn
from tqdm import tqdm
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
from transformers import get_scheduler
from accelerate import Accelerator
from accelerate.utils import set_seed
from collections import defaultdict
from torch.utils.data import DataLoader
import torch.optim as optim
from ..utils.data_utils import load_lm_dataset, extract_new_words_from_dataset, get_group_texts_func, get_tokenize_func
class EmbeddingCalibrator(nn.Module):
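    """Learned residual correction for (new-token) embedding rows.

    With lora_r=None the correction is a full hidden_size x hidden_size matrix
    initialized to zeros (an identity mapping at the start of training); otherwise
    a LoRA-style low-rank update B @ A scaled by lora_alpha / lora_r is applied.
    """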
def __init__(self, hidden_size, lora_r=None, lora_alpha=None, dtype=torch.bfloat16):
super().__init__()
self.use_lora = lora_r is not None
if not self.use_lora:
self.weight = nn.Parameter(torch.zeros(hidden_size, hidden_size, dtype=dtype))
else:
self.lora_scaling = lora_alpha / lora_r if lora_alpha is not None else 1.0
            self.lora_A = nn.Parameter(torch.randn(lora_r, hidden_size, dtype=dtype) * (1 / lora_r))
            self.lora_B = nn.Parameter(torch.zeros(hidden_size, lora_r, dtype=dtype))
def forward(self, x):
if not self.use_lora:
return x + torch.matmul(x, self.weight.t())
else:
# Low-rank adaptation
lora_out = torch.matmul(x, self.lora_A.t())
lora_out = torch.matmul(lora_out, self.lora_B.t())
return x + self.lora_scaling * lora_out
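
# Illustrative sketch (not executed; the sizes and the names `embedding_matrix` /
# `original_vocab_size` below are placeholder assumptions, not repo values):
#
#   calibrator = EmbeddingCalibrator(hidden_size=4096, lora_r=16, lora_alpha=32)
#   new_rows = embedding_matrix[original_vocab_size:]   # rows of the newly added tokens
#   calibrated_rows = calibrator(new_rows)              # residual update, same shape as input
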
class CalibrationModel(nn.Module):
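    """Wraps a frozen base model and lm_head so that only the calibrators for the
    newly added tokens (vocabulary rows from original_vocab_size onward) are trained.

    The training loss is a weighted cross-entropy averaged separately over
    original-token, new-token and subsequent-token (the position right after a new
    token) targets.
    """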
def __init__(
self,
base_model, lm_head, original_vocab_size, num_new_tokens,
calibrate_embedding=True, calibrate_lm_head=True, empty_init=False,
lora_alpha=None, lora_r=None,
target_loss_weight=0.15, subsequent_loss_weight=0.15,
):
super().__init__()
self.base_model = base_model
self.lm_head = lm_head
self.new_tokens_start = original_vocab_size
self.new_tokens_end = original_vocab_size + num_new_tokens
self.calibrate_lm_head = calibrate_lm_head
self.calibrate_embedding = calibrate_embedding
if not empty_init:
self.lm_head_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)
self.embedding_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)
self.loss_fct = nn.CrossEntropyLoss(reduction="none")
self.subsequent_tokens_loss_alpha = subsequent_loss_weight
self.new_tokens_loss_alpha = target_loss_weight
self.original_tokens_loss_alpha = 1 - self.new_tokens_loss_alpha - self.subsequent_tokens_loss_alpha
def forward(self, input_ids, labels, attention_mask=None):
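        """Compute the weighted calibration loss for a batch.

        Labels are the input ids shifted by one position (standard causal LM
        objective). Per-token losses are averaged separately over original-token,
        new-token and subsequent-token targets and combined with the configured
        alpha weights.
        """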
# shift labels by 1 for CLM
labels = labels[:, 1:].contiguous()
input_ids = input_ids[:, :-1].contiguous()
if self.calibrate_embedding:
E_weights = self.base_model.get_input_embeddings().weight.data
E_weights = torch.cat((E_weights[:self.new_tokens_start], self.embedding_calibrator(E_weights[self.new_tokens_start:])))
input_embeddings = E_weights[input_ids]
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
outputs = self.base_model(inputs_embeds=input_embeddings, attention_mask=attention_mask)
else:
with torch.no_grad():
# Forward pass through the base model
outputs = self.base_model(input_ids, attention_mask=attention_mask)
if self.calibrate_lm_head:
            with torch.no_grad():
                lm_head_weights = self.lm_head.weight
            normed_weights = lm_head_weights.clone()
            # The calibrator runs outside no_grad so gradients reach its parameters.
            normed_weights[self.new_tokens_start:self.new_tokens_end] = self.lm_head_calibrator(
                lm_head_weights[self.new_tokens_start:self.new_tokens_end]
            )
            logits = torch.matmul(outputs['last_hidden_state'], normed_weights.T)
else:
if self.calibrate_embedding:
logits = self.lm_head(outputs['last_hidden_state'])
else:
with torch.no_grad():
logits = self.lm_head(outputs['last_hidden_state'])
per_example_loss = self.loss_fct(logits.transpose(1,2), labels)
original_tokens_mask = labels < self.new_tokens_start
new_tokens_mask = ~original_tokens_mask
loss = 0.0
if self.original_tokens_loss_alpha > 0.0:
loss += self.original_tokens_loss_alpha * per_example_loss[original_tokens_mask].mean()
if self.new_tokens_loss_alpha > 0.0:
loss += self.new_tokens_loss_alpha * per_example_loss[new_tokens_mask].mean()
if self.subsequent_tokens_loss_alpha > 0.0:
subsequent_tokens_mask = torch.zeros_like(original_tokens_mask, dtype=torch.bool)
subsequent_tokens_mask[:, 1:][new_tokens_mask[:, :-1]] = True
loss += self.subsequent_tokens_loss_alpha * per_example_loss[subsequent_tokens_mask].mean()
return {'loss': loss, 'logits': logits}
def get_calibrators(self):
embedding_calibrator = self.embedding_calibrator if self.calibrate_embedding else None
lm_head_calibrator = self.lm_head_calibrator if self.calibrate_lm_head else None
return {
"embedding_calibrator": embedding_calibrator,
"lm_head_calibrator": lm_head_calibrator,
"new_tokens_start": self.new_tokens_start,
"new_tokens_end": self.new_tokens_end,
}
def set_calibrators(self, embedding_calibrator=None, lm_head_calibrator=None):
self.embedding_calibrator = embedding_calibrator
self.lm_head_calibrator = lm_head_calibrator
def save_calibrators(self, save_dir):
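        """Serialize the active calibrator modules to save_dir via torch.save."""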
os.makedirs(save_dir, exist_ok=True)
if self.calibrate_embedding:
torch.save(self.embedding_calibrator, os.path.join(save_dir, "embedding_calibrator.pt"))
if self.calibrate_lm_head:
torch.save(self.lm_head_calibrator, os.path.join(save_dir, "lm_head_calibrator.pt"))
def load_calibrators(self, load_dir, fail_ok=False):
"""Loads the model's state dictionary from a file."""
try:
if self.calibrate_embedding:
self.embedding_calibrator = torch.load(os.path.join(load_dir, "embedding_calibrator.pt"))
if self.calibrate_lm_head:
self.lm_head_calibrator = torch.load(os.path.join(load_dir, "lm_head_calibrator.pt"))
return True
        except Exception:
if fail_ok:
return False
raise FileNotFoundError(f"Loading calibrators from '{load_dir}' failed")
def get_calibration_model(model, original_vocab_size, num_new_tokens, target_loss_weight=0.15, subsequent_loss_weight=0.15):
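    """Wrap a HF causal LM (expects model.model and model.lm_head) in a
    CalibrationModel, freezing all base parameters so that only the two
    calibrators remain trainable."""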
    calibrated_model = CalibrationModel(
        model.model, model.lm_head, original_vocab_size, num_new_tokens,
        target_loss_weight=target_loss_weight, subsequent_loss_weight=subsequent_loss_weight,
    )
calibrated_model.base_model.eval()
calibrated_model.lm_head.eval()
for param in calibrated_model.base_model.parameters():
param.requires_grad = False
for param in calibrated_model.lm_head.parameters():
param.requires_grad = False
for param in calibrated_model.lm_head_calibrator.parameters():
param.requires_grad = True
for param in calibrated_model.embedding_calibrator.parameters():
param.requires_grad = True
return calibrated_model
def train_calibration_model(
    calibrated_model: CalibrationModel, tokenizer, dataset, save_dir=None,
    max_samples=None, filter_examples_without_new_tokens=True,
    lr=1e-4, lr_schedule="linear", num_epochs=1, batch_size=8, max_length=256,
    n_warmup_steps=0, text_col_name="text", clip_grad_norm=1.0, mixed_precision=None,
):
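    """Fit the calibrators with a standard causal-LM training loop.

    The dataset is tokenized, packed into blocks of max_length, optionally filtered
    to examples containing at least one new token, and trained with AdamW under the
    given learning-rate schedule. Calibrators are saved to save_dir if provided, and
    the (accelerator-prepared) CalibrationModel is returned.
    """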
accelerator = Accelerator(mixed_precision=mixed_precision)
# Optimizer
optimizer = optim.AdamW(calibrated_model.parameters(), lr=lr)
# Tokenize data
if tokenizer.bos_token is not None and max_length:
add_start_token = True
# leave room for <BOS> token to be added:
max_tokenized_len = max_length - 1
else:
add_start_token = False
max_tokenized_len = max_length
def _add_start_token(batch):
bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
batch["attention_mask"] = torch.cat(
[torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1)
return batch
tokenize_function = get_tokenize_func(tokenizer, text_col_name)
column_names = dataset.column_names
with accelerator.main_process_first():
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=column_names,
load_from_cache_file=False,
desc="Running tokenizer on dataset",
)
group_texts = get_group_texts_func(block_size=max_tokenized_len)
lm_dataset = tokenized_dataset.map(
group_texts,
batched=True,
)
if filter_examples_without_new_tokens:
examples_w_new_token = np.arange(len(lm_dataset))[np.any(np.array(lm_dataset['input_ids']) >= calibrated_model.new_tokens_start, axis=1)]
lm_dataset = lm_dataset.select(examples_w_new_token)
if max_samples is not None:
        lm_dataset = lm_dataset.select(np.arange(min(max_samples, len(lm_dataset))))
data_collator = default_data_collator
# Create data loaders
dataloader = DataLoader(
lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=True, shuffle=True,
)
# Learning rate scheduler
if isinstance(n_warmup_steps, float):
        n_warmup_steps = int(n_warmup_steps * len(dataloader))
scheduler = get_scheduler(lr_schedule, optimizer=optimizer, num_warmup_steps=n_warmup_steps, num_training_steps=len(dataloader) * num_epochs)
    calibrated_model, optimizer, dataloader, scheduler = accelerator.prepare(
        calibrated_model, optimizer, dataloader, scheduler
    )
# Freeze the original lm_head weights
for param in calibrated_model.lm_head.parameters():
param.requires_grad = False
calibrated_model.train()
for epoch in tqdm(range(num_epochs), unit="epochs", desc="Fitting calibration"):
total_loss = 0.0
for step, batch in tqdm(enumerate(dataloader), total=len(dataloader), miniters=10, unit="batches"):
if add_start_token:
batch = _add_start_token(batch)
batch["labels"] = batch["input_ids"]
optimizer.zero_grad()
outputs = calibrated_model(**batch)
loss = outputs['loss']
            accelerator.backward(loss)
            accelerator.clip_grad_norm_(calibrated_model.parameters(), clip_grad_norm)
optimizer.step()
scheduler.step()
total_loss += loss.item()
# # Log loss
# if step % 10 == 0:
# print(f"Epoch {epoch + 1}, Step {step}, Loss: {loss.item()}")
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss}")
if save_dir is not None:
calibrated_model.save_calibrators(save_dir)
return calibrated_model
def merge_calibrators_to_hf_model(hf_model, new_tokens_start, new_tokens_end=None, embedding_calibrator=None, lm_head_calibrator=None):
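    """Copy calibrated new-token rows into a HF model using standalone calibrator
    modules; either calibrator may be None, in which case that side is skipped."""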
    if embedding_calibrator is not None:
        embedding_calibrator.to(hf_model.device)
        embedding_weights = hf_model.get_input_embeddings().weight
        with torch.no_grad():
            calibrated_weights = embedding_calibrator(embedding_weights[new_tokens_start:new_tokens_end])
        hf_model.model.embed_tokens.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights
    if lm_head_calibrator is not None:
        lm_head_calibrator.to(hf_model.device)
        lm_head_weights = hf_model.get_output_embeddings().weight
        with torch.no_grad():
            calibrated_weights = lm_head_calibrator(lm_head_weights[new_tokens_start:new_tokens_end])
        hf_model.lm_head.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights
return hf_model
def merge_calibration_model_to_hf_model(hf_model, calibrated_model):
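    """Copy the calibrated new-token embedding and lm_head rows from a trained
    CalibrationModel into the corresponding weights of a HF model."""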
calibrated_model.to(hf_model.device)
if calibrated_model.calibrate_lm_head:
lm_head_weights = calibrated_model.lm_head.weight
normed_weights = calibrated_model.lm_head_calibrator(lm_head_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
with torch.no_grad():
hf_model.lm_head.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
if calibrated_model.calibrate_embedding:
embedding_weights = calibrated_model.base_model.get_input_embeddings().weight
normed_weights = calibrated_model.embedding_calibrator(embedding_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
with torch.no_grad():
hf_model.model.embed_tokens.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
return hf_model
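
# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (assumptions: a Llama-style checkpoint and a
# `new_words` list; the model name and `dataset` below are placeholders, not
# values used elsewhere in this repo):
#
#   model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
#   original_vocab_size = len(tokenizer)
#   tokenizer.add_tokens(new_words)
#   model.resize_token_embeddings(len(tokenizer))
#
#   calib_model = get_calibration_model(model, original_vocab_size, len(new_words))
#   calib_model = train_calibration_model(calib_model, tokenizer, dataset, save_dir="calibrators/")
#   model = merge_calibration_model_to_hf_model(model, calib_model)
# ---------------------------------------------------------------------------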