import os import torch from torch import nn from tqdm import tqdm import numpy as np from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator from transformers import get_scheduler from accelerate import Accelerator from accelerate.utils import set_seed from collections import defaultdict from torch.utils.data import DataLoader import torch.optim as optim from ..utils.data_utils import load_lm_dataset, extract_new_words_from_dataset, get_group_texts_func, get_tokenize_func class EmbeddingCalibrator(nn.Module): def __init__(self, hidden_size, lora_r=None, lora_alpha=None, dtype=torch.bfloat16): super().__init__() self.use_lora = lora_r is not None if not self.use_lora: self.weight = nn.Parameter(torch.zeros(hidden_size, hidden_size, dtype=dtype)) else: self.lora_scaling = lora_alpha / lora_r if lora_alpha is not None else 1.0 self.lora_A = nn.Parameter(torch.randn(lora_rank, hidden_size, dtype=dtype) * (1/lora_r)) self.lora_B = nn.Parameter(torch.zeros(hidden_size, lora_rank, dtype=dtype)) def forward(self, x): if not self.use_lora: return x + torch.matmul(x, self.weight.t()) else: # Low-rank adaptation lora_out = torch.matmul(x, self.lora_A.t()) lora_out = torch.matmul(lora_out, self.lora_B.t()) return x + self.lora_scaling * lora_out class CalibrationModel(nn.Module): def __init__( self, base_model, lm_head, original_vocab_size, num_new_tokens, calibrate_embedding=True, calibrate_lm_head=True, empty_init=False, lora_alpha=None, lora_r=None, target_loss_weight=0.15, subsequent_loss_weight=0.15, ): super().__init__() self.base_model = base_model self.lm_head = lm_head self.new_tokens_start = original_vocab_size self.new_tokens_end = original_vocab_size + num_new_tokens self.calibrate_lm_head = calibrate_lm_head self.calibrate_embedding = calibrate_embedding if not empty_init: self.lm_head_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha) self.embedding_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha) self.loss_fct = nn.CrossEntropyLoss(reduction="none") self.subsequent_tokens_loss_alpha = subsequent_loss_weight self.new_tokens_loss_alpha = target_loss_weight self.original_tokens_loss_alpha = 1 - self.new_tokens_loss_alpha - self.subsequent_tokens_loss_alpha def forward(self, input_ids, labels, attention_mask=None): # shift labels by 1 for CLM labels = labels[:, 1:].contiguous() input_ids = input_ids[:, :-1].contiguous() if self.calibrate_embedding: E_weights = self.base_model.get_input_embeddings().weight.data E_weights = torch.cat((E_weights[:self.new_tokens_start], self.embedding_calibrator(E_weights[self.new_tokens_start:]))) input_embeddings = E_weights[input_ids] if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.long) outputs = self.base_model(inputs_embeds=input_embeddings, attention_mask=attention_mask) else: with torch.no_grad(): # Forward pass through the base model outputs = self.base_model(input_ids, attention_mask=attention_mask) if self.calibrate_lm_head: with torch.no_grad(): lm_head_weights = self.lm_head.weight normed_weights = lm_head_weights.clone() normed_weights[self.new_tokens_start:self.new_tokens_end] = self.lm_head_calibrator(lm_head_weights[self.new_tokens_start:self.new_tokens_end]) logits = torch.matmul(outputs['last_hidden_state'], normed_weights.T) else: if self.calibrate_embedding: logits = self.lm_head(outputs['last_hidden_state']) else: with torch.no_grad(): logits = self.lm_head(outputs['last_hidden_state']) per_example_loss = self.loss_fct(logits.transpose(1,2), labels) original_tokens_mask = labels < self.new_tokens_start new_tokens_mask = ~original_tokens_mask loss = 0.0 if self.original_tokens_loss_alpha > 0.0: loss += self.original_tokens_loss_alpha * per_example_loss[original_tokens_mask].mean() if self.new_tokens_loss_alpha > 0.0: loss += self.new_tokens_loss_alpha * per_example_loss[new_tokens_mask].mean() if self.subsequent_tokens_loss_alpha > 0.0: subsequent_tokens_mask = torch.zeros_like(original_tokens_mask, dtype=torch.bool) subsequent_tokens_mask[:, 1:][new_tokens_mask[:, :-1]] = True loss += self.subsequent_tokens_loss_alpha * per_example_loss[subsequent_tokens_mask].mean() return {'loss': loss, 'logits': logits} def get_calibrators(self): embedding_calibrator = self.embedding_calibrator if self.calibrate_embedding else None lm_head_calibrator = self.lm_head_calibrator if self.calibrate_lm_head else None return { "embedding_calibrator": embedding_calibrator, "lm_head_calibrator": lm_head_calibrator, "new_tokens_start": self.new_tokens_start, "new_tokens_end": self.new_tokens_end, } def set_calibrators(self, embedding_calibrator=None, lm_head_calibrator=None): self.embedding_calibrator = embedding_calibrator self.lm_head_calibrator = lm_head_calibrator def save_calibrators(self, save_dir): os.makedirs(save_dir, exist_ok=True) if self.calibrate_embedding: torch.save(self.embedding_calibrator, os.path.join(save_dir, "embedding_calibrator.pt")) if self.calibrate_lm_head: torch.save(self.lm_head_calibrator, os.path.join(save_dir, "lm_head_calibrator.pt")) def load_calibrators(self, load_dir, fail_ok=False): """Loads the model's state dictionary from a file.""" try: if self.calibrate_embedding: self.embedding_calibrator = torch.load(os.path.join(load_dir, "embedding_calibrator.pt")) if self.calibrate_lm_head: self.lm_head_calibrator = torch.load(os.path.join(load_dir, "lm_head_calibrator.pt")) return True except: if fail_ok: return False raise FileNotFoundError(f"Loading calibrators from '{load_dir}' failed") def get_calibration_model(model, original_vocab_size, num_new_tokens, target_loss_weight=0.15, subsequent_loss_weight=0.15): calibrated_model = CalibrationModel(model.model, model.lm_head, original_vocab_size, num_new_tokens, target_loss_weight=target_loss_weight, subsequent_loss_weight=subsequent_loss_weight) calibrated_model.base_model.eval() calibrated_model.lm_head.eval() for param in calibrated_model.base_model.parameters(): param.requires_grad = False for param in calibrated_model.lm_head.parameters(): param.requires_grad = False for param in calibrated_model.lm_head_calibrator.parameters(): param.requires_grad = True for param in calibrated_model.embedding_calibrator.parameters(): param.requires_grad = True return calibrated_model def train_calibration_model(calibrated_model: CalibrationModel, tokenizer, dataset, save_dir=None, max_samples=None, filter_examples_without_new_tokens=True, lr=1e-4, lr_schedule="linear", num_epochs=1, batch_size=8, max_length=256, n_warmup_steps=0, text_col_name="text", clip_grad_norm=1.0, mixed_precision=None): accelerator = Accelerator(mixed_precision=mixed_precision) # Optimizer optimizer = optim.AdamW(calibrated_model.parameters(), lr=lr) # Tokenize data if tokenizer.bos_token is not None and max_length: add_start_token = True # leave room for token to be added: max_tokenized_len = max_length - 1 else: add_start_token = False max_tokenized_len = max_length def _add_start_token(batch): bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device) batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1) batch["attention_mask"] = torch.cat( [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1) return batch tokenize_function = get_tokenize_func(tokenizer, text_col_name) column_names = dataset.column_names with accelerator.main_process_first(): tokenized_dataset = dataset.map( tokenize_function, batched=True, remove_columns=column_names, load_from_cache_file=False, desc="Running tokenizer on dataset", ) group_texts = get_group_texts_func(block_size=max_tokenized_len) lm_dataset = tokenized_dataset.map( group_texts, batched=True, ) if filter_examples_without_new_tokens: examples_w_new_token = np.arange(len(lm_dataset))[np.any(np.array(lm_dataset['input_ids']) >= calibrated_model.new_tokens_start, axis=1)] lm_dataset = lm_dataset.select(examples_w_new_token) if max_samples is not None: lm_dataset = lm_dataset.select(np.arange(max_samples)) data_collator = default_data_collator # Create data loaders dataloader = DataLoader( lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=True, shuffle=True, ) # Learning rate scheduler if isinstance(n_warmup_steps, float): n_warmup_steps = n_warmup_steps * len(dataloader) scheduler = get_scheduler(lr_schedule, optimizer=optimizer, num_warmup_steps=n_warmup_steps, num_training_steps=len(dataloader) * num_epochs) calibrated_model, dataloader = accelerator.prepare(calibrated_model, dataloader) # Freeze the original lm_head weights for param in calibrated_model.lm_head.parameters(): param.requires_grad = False calibrated_model.train() for epoch in tqdm(range(num_epochs), unit="epochs", desc="Fitting calibration"): total_loss = 0.0 for step, batch in tqdm(enumerate(dataloader), total=len(dataloader), miniters=10, unit="batches"): if add_start_token: batch = _add_start_token(batch) batch["labels"] = batch["input_ids"] optimizer.zero_grad() outputs = calibrated_model(**batch) loss = outputs['loss'] loss.backward() torch.nn.utils.clip_grad_norm_(calibrated_model.parameters(), max_norm=clip_grad_norm) optimizer.step() scheduler.step() total_loss += loss.item() # # Log loss # if step % 10 == 0: # print(f"Epoch {epoch + 1}, Step {step}, Loss: {loss.item()}") avg_loss = total_loss / len(dataloader) print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss}") if save_dir is not None: calibrated_model.save_calibrators(save_dir) return calibrated_model def merge_calibrators_to_hf_model(hf_model, new_tokens_start, new_tokens_end=None, embedding_calibrator=None, lm_head_calibrator=None): embedding_calibrator.to(hf_model.device) lm_head_calibrator.to(hf_model.device) if embedding_calibrator is not None: embedding_weights = hf_model.get_input_embeddings().weight with torch.no_grad(): calibrated_weights = embedding_calibrator(embedding_weights[new_tokens_start:new_tokens_end]) hf_model.model.embed_tokens.weight.data[ new_tokens_start:new_tokens_end] = calibrated_weights if lm_head_calibrator is not None: lm_head_weights = hf_model.get_output_embeddings().weight with torch.no_grad(): calibrated_weights = lm_head_calibrator(lm_head_weights[new_tokens_start:new_tokens_end]) hf_model.lm_head.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights return hf_model def merge_calibration_model_to_hf_model(hf_model, calibrated_model): calibrated_model.to(hf_model.device) if calibrated_model.calibrate_lm_head: lm_head_weights = calibrated_model.lm_head.weight normed_weights = calibrated_model.lm_head_calibrator(lm_head_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end]) with torch.no_grad(): hf_model.lm_head.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights if calibrated_model.calibrate_embedding: embedding_weights = calibrated_model.base_model.get_input_embeddings().weight normed_weights = calibrated_model.embedding_calibrator(embedding_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end]) with torch.no_grad(): hf_model.model.embed_tokens.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights return hf_model