import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import default_data_collator
from collections import defaultdict
from tqdm import tqdm
import numpy as np


def is_not_number(s):
    try:
        float(s)  # Try converting the string to a float
        return False  # If conversion is successful, it's a number
    except ValueError:
        return True  # If conversion fails, it's not a number


def get_contexts_ending_with_word(word, dataset):
    result_contexts = []
    word_len = len(word)
    # Iterate over the dataset
    for example in dataset:
        text = example["text"]
        # Find all occurrences of the word in the text
        start = 0
        while True:
            idx = text.find(word, start)
            if idx == -1:
                break
            # Ensure that the word is isolated (not a substring of another word)
            if (idx == 0 or not text[idx - 1].isalnum()) and (
                    idx + word_len == len(text) or not text[idx + word_len].isalnum()):
                # Keep the context up to and including the word
                result_contexts.append(text[:idx + word_len].strip())
            start = idx + word_len
    return result_contexts


def get_texts_containing_word(words, dataset):
    result_texts = []
    words_set = set(words)
    # Iterate over the dataset
    for example in dataset:
        if words_set.intersection(set(example["text"].split())):
            result_texts.append(example["text"])
    return result_texts


def compute_topk_token_rank(logits, labels, k=1000):
    # Get the top-k predicted logits and their indices
    topk_logits, topk_indices = torch.topk(logits, k, dim=-1)

    # Expand the labels for comparison
    labels_expanded = labels.unsqueeze(-1).expand_as(topk_indices)

    # Check if the label token is within the top-k predictions
    rank_in_topk = (topk_indices == labels_expanded).nonzero(as_tuple=False)

    # Create a rank tensor initialized with k (max rank is k)
    ranks = torch.full(labels.shape, k, dtype=torch.long, device=logits.device)

    # For labels in top-k, set the rank accordingly
    ranks[rank_in_topk[:, 0], rank_in_topk[:, 1]] = rank_in_topk[:, 2] + 1

    return ranks


def count_tokens_in_dataset(dataset, tokenizer, text_column='text'):
    def tokenize_and_count(examples):
        return {'num_tokens': [len(tokenizer(ex).input_ids) for ex in examples[text_column]]}

    tokenized_dataset = dataset.map(tokenize_and_count, batched=True, remove_columns=dataset.column_names)
    total_tokens = sum(tokenized_dataset['num_tokens'])
    return total_tokens


def filter_single_token_words(array, tokenizer, add_space_prefix_for_lower=True):
    def _token_count(word):
        if add_space_prefix_for_lower and word[0].islower():
            word = " " + word
        return len(tokenizer.encode(word, add_special_tokens=False))

    token_counts = array.apply(_token_count)
    mask = token_counts > 1
    return array[mask], token_counts


def get_last_zero_in_every_seq_mask(tensor):
    """Return a mask that is 1 only at the last zero of every consecutive run of zeros
    in `tensor`, and 0 everywhere else."""
    # Find where consecutive zeros end
    zero_mask = (tensor == 0)
    diff = torch.diff(zero_mask.int(), dim=1)
    last_zero_mask = torch.cat(
        [diff, torch.ones(tensor.size(0), 1, dtype=diff.dtype).to(tensor.device)], dim=1) == -1

    # Create the output
    output = 1 - tensor
    output[zero_mask & ~last_zero_mask] = 0
    return output


def get_first_zero_in_every_seq_mask(tensor):
    """Return a mask that is 1 only at the first zero of every consecutive run of zeros
    in `tensor`, and 0 everywhere else."""
    # Identify where consecutive zeros begin
    zero_mask = (tensor == 0)
    diff = torch.diff(zero_mask.int(), dim=1,
                      prepend=torch.zeros(tensor.size(0), 1, dtype=torch.int).to(tensor.device))
    first_zero_mask = diff == 1  # Marks the beginning of each sequence of zeros

    # Create the output
    output = 1 - tensor
    output[zero_mask & ~first_zero_mask] = 0
    return output
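

# Illustrative sketch (added for clarity, not part of the original module): shows what the two
# mask helpers above return for a toy attention mask. The function name and example values are
# arbitrary assumptions.
def _example_zero_run_masks():
    attention_mask = torch.tensor([[1, 1, 0, 0, 1, 0, 1]])
    # Marks only the last zero of each zero-run (positions 3 and 5):
    last = get_last_zero_in_every_seq_mask(attention_mask)    # -> [[0, 0, 0, 1, 0, 1, 0]]
    # Marks only the first zero of each zero-run (positions 2 and 5):
    first = get_first_zero_in_every_seq_mask(attention_mask)  # -> [[0, 0, 1, 0, 0, 1, 0]]
    return last, first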


def _add_start_token(batch, tokenizer):
    bos_tokens_tensor = torch.tensor(
        [[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
    batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
    batch["attention_mask"] = torch.cat(
        [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device),
         batch["attention_mask"]], dim=1)
    return batch


def _ignore_new_words_in_attention_mask(shift_attention_mask_batch, shift_labels,
                                        new_token_ids=None, replaced_token_seqs_by_len=None):
    ignore_mask = None
    # Ignore token_ids of new vocabulary words in shift_labels and shift_logits
    if new_token_ids is not None:
        ignore_mask = torch.isin(shift_labels, new_token_ids)
        shift_attention_mask_batch = shift_attention_mask_batch * (~ignore_mask).long()
    # Ignore multi-token sequences that were replaced with a single token
    if replaced_token_seqs_by_len is not None:
        # Create a mask that will be updated where sequences match
        ignore_mask = shift_attention_mask_batch.clone()  # Clone the attention mask to modify it
        # Loop over sequences in skip_token_seqs
        for seq_len, seqs in replaced_token_seqs_by_len.items():
            # Create a sliding window of the same size as the skip_seq and check for matches
            for i in range(shift_labels.size(1) - seq_len + 1):
                # Check if the sequence matches at position i
                window = shift_labels[:, i:i + seq_len]
                curr_mask = torch.all(window.unsqueeze(1) == seqs.unsqueeze(0), dim=-1)
                if curr_mask.any():
                    # Zero out the ignore mask for the length of the sequence
                    ignore_mask[curr_mask.any(dim=-1), i:i + seq_len] = 0
        # Apply the ignore mask to the attention mask
        shift_attention_mask_batch *= ignore_mask
    return shift_attention_mask_batch, ignore_mask
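

# Illustrative sketch (added for clarity, not part of the original module): a toy call to
# _ignore_new_words_in_attention_mask. Token id 9 plays the role of a newly added vocabulary
# token and the pair (5, 7) the multi-token sequence it replaced; the ids are arbitrary assumptions.
def _example_ignore_new_words():
    shift_labels = torch.tensor([[5, 7, 9, 3]])
    shift_attention_mask = torch.tensor([[1, 1, 1, 1]])
    new_token_ids = torch.tensor([9])
    replaced_token_seqs_by_len = {2: torch.tensor([[5, 7]])}
    masked, _ = _ignore_new_words_in_attention_mask(
        shift_attention_mask, shift_labels, new_token_ids, replaced_token_seqs_by_len)
    # Positions holding the new token (9) and the replaced pair (5, 7) are zeroed out,
    # so only the final position is still evaluated: masked == [[0, 0, 0, 1]]
    return masked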


# TODO consider not aggregating results here, to enable metrics for specific words
def compute_metrics(
        logits, labels, attention_mask,
        compute_target_metrics=True, compute_subsequent_metrics=True,
        compute_perplexity=False, return_successful_targets=False,
        original_labels=None, original_logits=None, debug=False):
    target_results = dict()  # will hold metrics for all the new words we add or their original tokenization
    background_results = dict()  # will hold metrics for all background tokens, i.e., not the ones we add or replace
    overall_results = dict()  # will hold metrics for all tokens
    successful_targets = None  # will hold list of target tokens successfully predicted

    if compute_subsequent_metrics:
        # prepare labels and attention masks for computing metrics only for the 1st tokens following the new words
        subsequent_labels = labels[:, 1:]
        subsequent_attention_mask = get_last_zero_in_every_seq_mask(attention_mask[..., :-1].contiguous())
        subsequent_attention_mask_bool = subsequent_attention_mask == 1

    attention_mask_bool = attention_mask == 1
    overall_mask_bool = attention_mask_bool
    if compute_target_metrics:
        target_mask = get_first_zero_in_every_seq_mask(attention_mask)
        target_mask_bool = target_mask == 1
        overall_mask_bool = attention_mask_bool | target_mask_bool

    if compute_perplexity:
        background_results["perplexity"] = torch.exp(
            (F.cross_entropy(logits.transpose(1, 2), labels, reduction="none") * attention_mask).sum(1)
            / attention_mask.sum(1)
        ).mean().detach().cpu().numpy()

    top1 = logits.argmax(dim=-1)
    if original_logits is not None:
        orig_top1 = original_logits.argmax(dim=-1)

    if compute_target_metrics:
        target_results["top1_acc"] = ((labels == top1)[target_mask_bool]).detach().cpu().numpy()
        if original_labels is not None:
            target_results["sum_top1_acc"] = (
                ((original_labels == top1) | (labels == top1))[target_mask_bool]).detach().cpu().numpy()
            if original_logits is not None:
                target_results["orig_top1_acc"] = (
                    (original_labels == orig_top1)[target_mask_bool]).detach().cpu().numpy()
        if return_successful_targets:
            successful_targets = (labels[(labels == top1) & target_mask_bool]).detach().cpu().numpy()

    background_results["top1_acc"] = ((labels == top1)[attention_mask_bool]).detach().cpu().numpy()
    if compute_subsequent_metrics:
        background_results["subsequent_top1_acc"] = (
            (subsequent_labels == top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()
    if original_logits is not None:
        if original_labels is not None:
            background_results["orig_top1_acc"] = (
                (original_labels == orig_top1)[attention_mask_bool]).detach().cpu().numpy()
        if compute_subsequent_metrics:
            background_results["orig_subsequent_top1_acc"] = (
                (subsequent_labels == orig_top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()

    overall_results["top1_acc"] = (labels == top1)[overall_mask_bool].detach().cpu().numpy()
    if original_labels is not None:
        overall_results["sum_top1_acc"] = (
            (original_labels == top1) | (labels == top1))[overall_mask_bool].detach().cpu().numpy()
        if original_logits is not None:
            overall_results["orig_top1_acc"] = (
                (original_labels == orig_top1)[overall_mask_bool]).detach().cpu().numpy()

    if debug:
        import pdb; pdb.set_trace()

    return background_results, target_results, overall_results, successful_targets
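

# Illustrative sketch (added for clarity, not part of the original module): a toy call to
# compute_metrics. The attention mask below is assumed to have already been processed so that
# the "new word" span (positions 1-2) is zeroed out; all values are arbitrary assumptions.
def _example_compute_metrics():
    labels = torch.tensor([[2, 4, 7, 1, 3]])
    top1_predictions = torch.tensor([[2, 9, 7, 1, 0]])
    logits = F.one_hot(top1_predictions, num_classes=10).float()  # argmax recovers top1_predictions
    attention_mask = torch.tensor([[1, 0, 0, 1, 1]])
    background, target, overall, _ = compute_metrics(logits, labels, attention_mask)
    # target["top1_acc"]: accuracy at the first token of the masked span (position 1) -> [False]
    # background["top1_acc"]: accuracy at unmasked positions 0, 3, 4 -> [True, True, False]
    # background["subsequent_top1_acc"]: accuracy at the token right after the span -> [True]
    # overall["top1_acc"]: union of background and target positions -> [True, False, True, False]
    return background, target, overall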


def eval_next_word_prediction(
        model, tokenizer, lm_dataset, accelerator=None,
        batch_size: int = 4,
        new_token_ids=None, replaced_token_seqs_by_len=None, new_token_to_original_first_token=None,
        max_length: int = 256,
        drop_last: bool = True,
        eval_max_samples: int = None,
        eval_shuffle_samples: bool = False,
        reduction="none",
):
    if accelerator is None:
        accelerator = Accelerator()
    model.eval()

    add_start_token = tokenizer.bos_token is not None and bool(max_length)

    data_collator = default_data_collator

    if eval_max_samples:
        eval_idx = range(min(eval_max_samples, len(lm_dataset)))
        if eval_shuffle_samples:
            eval_idx = np.random.choice(len(lm_dataset), min(eval_max_samples, len(lm_dataset)), replace=False)
        lm_dataset = lm_dataset.select(eval_idx)

    # Create data loaders
    eval_dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=drop_last, shuffle=False,
    )
    eval_dataloader = accelerator.prepare(eval_dataloader)

    if new_token_ids is not None:
        new_token_ids = torch.tensor(new_token_ids).to(model.device)
    if replaced_token_seqs_by_len is not None:
        replaced_token_seqs_by_len = {token_length: torch.tensor(skip_token_seqs).to(model.device)
                                      for token_length, skip_token_seqs in replaced_token_seqs_by_len.items()
                                      if len(skip_token_seqs) > 0}
    if new_token_to_original_first_token is not None:
        # Convert the mapping into a tensor for efficient indexing; the mapping defaults to identity
        new_token_to_orig_first_mapping_tensor = torch.arange(len(tokenizer), device=model.device)
        new_token_to_orig_first_mapping_tensor[
            torch.tensor(list(new_token_to_original_first_token.keys()), device=model.device)] = \
            torch.tensor(list(new_token_to_original_first_token.values()), device=model.device)

    target_metrics = defaultdict(list)
    background_metrics = defaultdict(list)
    overall_metrics = defaultdict(list)

    # run eval and compute metrics
    for batch_i, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader),
                               miniters=10, desc="Evaluating vocabulary..."):
        if add_start_token:
            batch = _add_start_token(batch, tokenizer)

        labels = batch["input_ids"]
        attn_mask = batch["attention_mask"]
        batch.pop("labels", None)

        with torch.no_grad():
            outputs = model(**batch)
        out_logits = outputs.logits

        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        shift_attention_mask_batch, ignore_mask = \
            _ignore_new_words_in_attention_mask(
                shift_attention_mask_batch, shift_labels, new_token_ids, replaced_token_seqs_by_len)

        original_labels = None if new_token_to_original_first_token is None \
            else new_token_to_orig_first_mapping_tensor[shift_labels]
        original_logits = None if new_token_ids is None \
            else torch.cat([shift_logits[:, :, :min(new_token_ids)],
                            shift_logits[:, :, max(new_token_ids) + 1:]], dim=-1)

        background_results, target_results, overall_results, successful_targets = \
            compute_metrics(
                shift_logits, shift_labels, shift_attention_mask_batch,
                original_labels=original_labels, original_logits=original_logits,
                compute_perplexity=True)

        for metric_name, metric_value in target_results.items():
            target_metrics[metric_name].append(np.array(metric_value))
        for metric_name, metric_value in background_results.items():
            background_metrics[metric_name].append(metric_value)
        for metric_name, metric_value in overall_results.items():
            overall_metrics[metric_name].append(metric_value)

    eval_dataloader = accelerator.free_memory(eval_dataloader)

    def _concat_func(x):
        if isinstance(x, np.ndarray) and len(x.shape) > 1:
            x = np.concatenate(x)
        elif isinstance(x, (list, tuple)) and len(x) > 1:
            if isinstance(x[0], np.ndarray) and len(x[0].shape) == 0:
                x = np.array(x)
            else:
                x = np.concatenate(x)
        return x

    # apply reduction
    reduce_func = _concat_func
    if reduction == 'mean':
        reduce_func = lambda x: np.mean(_concat_func(x)).item()

    for metric_name, metric_value in target_metrics.items():
        target_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in background_metrics.items():
        background_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in overall_metrics.items():
        overall_metrics[metric_name] = reduce_func(metric_value)

    return background_metrics, target_metrics, overall_metrics
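

if __name__ == "__main__":
    # Illustrative end-to-end usage sketch (added for clarity, not part of the original module).
    # It assumes a small causal LM checkpoint ("gpt2") and builds a tiny tokenized dataset in
    # memory; the checkpoint name, texts and lengths are arbitrary assumptions. Without
    # new_token_ids / replaced_token_seqs_by_len, the target-specific metrics come back empty.
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    texts = [
        "The quick brown fox jumps over the lazy dog and keeps running.",
        "Language models assign probabilities to the next token in a sequence.",
    ]
    # Truncate to a fixed length so that no padding (and no pad token) is needed.
    enc = tokenizer(texts, truncation=True, max_length=6)
    lm_dataset = Dataset.from_dict(
        {"input_ids": enc["input_ids"],
         "attention_mask": enc["attention_mask"],
         "labels": enc["input_ids"]})
    lm_dataset.set_format(type="torch")

    background_metrics, target_metrics, overall_metrics = eval_next_word_prediction(
        model, tokenizer, lm_dataset, batch_size=2, drop_last=False, reduction="none")
    print("background top-1 acc:", float(np.concatenate(background_metrics["top1_acc"]).mean()))
    print("background perplexity:", float(np.mean(background_metrics["perplexity"])))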