# inner_lexicon/eval_utils.py
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import default_data_collator
from collections import defaultdict
from tqdm import tqdm
import numpy as np
def is_not_number(s):
    try:
        float(s)  # Try converting the string to a float
        return False  # If conversion is successful, it's a number
    except ValueError:
        return True  # If conversion fails, it's not a number

def get_contexts_ending_with_word(word, dataset):
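    """Return every context in `dataset` that ends with an isolated occurrence of `word`.

    For each occurrence of `word` in an example's "text" field that is not part of a
    longer word, the text up to and including that occurrence is collected.
    """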
    result_contexts = []
    word_len = len(word)
    # Iterate over the dataset
    for example in dataset:
        text = example["text"]
        # Find all occurrences of the word in the text
        start = 0
        while True:
            idx = text.find(word, start)
            if idx == -1:
                break
            # Ensure that the word is isolated (not a substring of another word)
            if (idx == 0 or not text[idx - 1].isalnum()) and (
                    idx + word_len == len(text) or not text[idx + word_len].isalnum()):
                # Text ends with the word
                result_contexts.append(text[:idx + word_len].strip())
            start = idx + word_len
    return result_contexts

def get_texts_containing_word(words, dataset):
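    """Return the texts in `dataset` that contain at least one of `words` as a whitespace-separated token."""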
    result_texts = []
    words_set = set(words)
    # Iterate over the dataset
    for example in dataset:
        if words_set.intersection(set(example["text"].split())):
            result_texts.append(example["text"])
    return result_texts

def compute_topk_token_rank(logits, labels, k=1000):
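    """Return the 1-based rank of each label within the top-k logits, or k if the label is outside the top-k.

    Expects `logits` of shape (batch, seq_len, vocab_size) and `labels` of shape (batch, seq_len).
    """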
    # Get the top-k predicted logits and their indices
    topk_logits, topk_indices = torch.topk(logits, k, dim=-1)
    # Expand the labels for comparison
    labels_expanded = labels.unsqueeze(-1).expand_as(topk_indices)
    # Check if the label token is within the top-k predictions
    rank_in_topk = (topk_indices == labels_expanded).nonzero(as_tuple=False)
    # Create a rank tensor initialized with k (max rank is k)
    ranks = torch.full(labels.shape, k, dtype=torch.long, device=logits.device)
    # For labels in top-k, set the rank accordingly
    ranks[rank_in_topk[:, 0], rank_in_topk[:, 1]] = rank_in_topk[:, 2] + 1
    return ranks

def count_tokens_in_dataset(dataset, tokenizer, text_column='text'):
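    """Return the total number of tokens produced by `tokenizer` over `text_column` of `dataset`."""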
    def tokenize_and_count(examples):
        return {'num_tokens': [len(tokenizer(ex).input_ids) for ex in examples[text_column]]}

    tokenized_dataset = dataset.map(tokenize_and_count, batched=True, remove_columns=dataset.column_names)
    total_tokens = sum(tokenized_dataset['num_tokens'])
    return total_tokens

def filter_single_token_words(array, tokenizer, add_space_prefix_for_lower=True):
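    """Return the entries of `array` that the tokenizer splits into more than one token, plus all token counts.

    `array` is assumed to support `.apply` and boolean masking (e.g., a pandas Series).
    """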
    def _is_multi_token(word):
        # Returns the number of tokens the word is split into
        if add_space_prefix_for_lower and word[0].islower():
            word = " " + word
        return len(tokenizer.encode(word, add_special_tokens=False))

    token_counts = array.apply(_is_multi_token)
    mask = token_counts > 1
    return array[mask], token_counts

# TODO: make it clearer what this is used for
def get_last_zero_in_every_seq_mask(tensor):
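    """Return a mask that is 1 only at the last position of every run of zeros in a binary mask.

    Runs of zeros that reach the final position are not marked. Used to select the positions
    from which the first token after each masked-out (target) span is predicted.
    """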
    # Find where consecutive zeros end
    zero_mask = (tensor == 0)
    diff = torch.diff(zero_mask.int(), dim=1)
    last_zero_mask = torch.cat([diff, torch.ones(tensor.size(0), 1, dtype=diff.dtype).to(tensor.device)], dim=1) == -1
    # Create the output
    output = 1 - tensor
    output[zero_mask & ~last_zero_mask] = 0
    return output

def get_first_zero_in_every_seq_mask(tensor):
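    """Return a mask that is 1 only at the first position of every run of zeros in a binary mask.

    Used to select the first token of each masked-out (target) span.
    """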
    # Identify where consecutive zeros begin
    zero_mask = (tensor == 0)
    diff = torch.diff(zero_mask.int(), dim=1, prepend=torch.zeros(tensor.size(0), 1, dtype=torch.int).to(tensor.device))
    first_zero_mask = diff == 1  # Marks the beginning of each sequence of zeros
    # Create the output
    output = 1 - tensor
    output[zero_mask & ~first_zero_mask] = 0
    return output

def _add_start_token(batch, tokenizer):
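    """Prepend the tokenizer's BOS token to every sequence in the batch and extend the attention mask to match."""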
    bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
    batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
    batch["attention_mask"] = torch.cat(
        [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]],
        dim=1)
    return batch

def _ignore_new_words_in_attention_mask(shift_attention_mask_batch, shift_labels, new_token_ids=None, replaced_token_seqs_by_len=None):
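    """Zero out attention-mask positions whose labels are new vocabulary tokens, or belong to the
    original multi-token sequences those tokens replaced, so they are excluded from the metrics.

    Returns the updated attention mask and the last intermediate ignore mask.
    """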
    ignore_mask = torch.ones_like(shift_attention_mask_batch)  # default, in case no filtering is requested
    # Ignore token_ids of new vocabulary words in shift_labels and shift_logits
    if new_token_ids is not None:
        ignore_mask = torch.isin(shift_labels, new_token_ids)
        shift_attention_mask_batch = shift_attention_mask_batch * (~ignore_mask).long()
    # Ignore multi-token sequences that were replaced with a single token
    if replaced_token_seqs_by_len is not None:
        # Create a mask that will be updated where sequences match
        ignore_mask = shift_attention_mask_batch.clone()  # Clone the attention mask to modify it
        # Loop over sequences in skip_token_seqs
        for seq_len, seqs in replaced_token_seqs_by_len.items():
            # Create a sliding window of the same size as the skip_seq and check for matches
            for i in range(shift_labels.size(1) - seq_len + 1):
                # Check if the sequence matches at position i
                window = shift_labels[:, i:i + seq_len]
                curr_mask = torch.all(window.unsqueeze(1) == seqs.unsqueeze(0), dim=-1)
                if curr_mask.any():
                    # Zero out the ignore mask for the length of the sequence
                    ignore_mask[curr_mask.any(dim=-1), i:i + seq_len] = 0
        # Apply the ignore mask to the attention mask
        shift_attention_mask_batch *= ignore_mask
    return shift_attention_mask_batch, ignore_mask

# TODO consider not aggregating results here, to enable metrics for specific words
def compute_metrics(
        logits, labels, attention_mask,
        compute_target_metrics=True, compute_subsequent_metrics=True, compute_perplexity=False,
        return_successful_targets=False,
        original_labels=None, original_logits=None,
        debug=False):
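    """Compute top-1 accuracy (and optionally perplexity) for target, background, and all tokens.

    `attention_mask` is expected to be zeroed at the positions of target (new or replaced) tokens:
    background metrics are computed where the mask is 1, target metrics at the first position of
    every zero run, and "subsequent" metrics at the first token following every zero run.
    Returns (background_results, target_results, overall_results, successful_targets).
    """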
    target_results = dict()  # will hold metrics for all the new words we add or their original tokenization
    background_results = dict()  # will hold metrics for all background tokens, i.e., not the ones we add or replace
    overall_results = dict()  # will hold metrics for all tokens
    successful_targets = None  # will hold the list of target tokens successfully predicted

    if compute_subsequent_metrics:
        # prepare labels and attention masks for computing metrics only for the 1st tokens following the new words
        subsequent_labels = labels[:, 1:]
        subsequent_attention_mask = get_last_zero_in_every_seq_mask(attention_mask[..., :-1].contiguous())
        subsequent_attention_mask_bool = subsequent_attention_mask == 1

    attention_mask_bool = attention_mask == 1
    overall_mask_bool = attention_mask_bool
    if compute_target_metrics:
        target_mask = get_first_zero_in_every_seq_mask(attention_mask)
        target_mask_bool = target_mask == 1
        overall_mask_bool = attention_mask_bool | target_mask_bool

    if compute_perplexity:
        background_results["perplexity"] = torch.exp(
            (F.cross_entropy(logits.transpose(1, 2), labels, reduction="none") * attention_mask).sum(1)
            / attention_mask.sum(1)
        ).mean().detach().cpu().numpy()

    top1 = logits.argmax(dim=-1)
    if original_logits is not None:
        orig_top1 = original_logits.argmax(dim=-1)

    if compute_target_metrics:
        target_results["top1_acc"] = ((labels == top1)[target_mask_bool]).detach().cpu().numpy()
        if original_labels is not None:
            target_results["sum_top1_acc"] = (
                ((original_labels == top1) | (labels == top1))[target_mask_bool]).detach().cpu().numpy()
            if original_logits is not None:
                target_results["orig_top1_acc"] = (
                    (original_labels == orig_top1)[target_mask_bool]).detach().cpu().numpy()
        if return_successful_targets:
            successful_targets = (labels[(labels == top1) & target_mask_bool]).detach().cpu().numpy()

    background_results["top1_acc"] = ((labels == top1)[attention_mask_bool]).detach().cpu().numpy()
    if compute_subsequent_metrics:
        background_results["subsequent_top1_acc"] = (
            (subsequent_labels == top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()
    if original_logits is not None:
        background_results["orig_top1_acc"] = (
            (original_labels == orig_top1)[attention_mask_bool]).detach().cpu().numpy()
        if compute_subsequent_metrics:
            background_results["orig_subsequent_top1_acc"] = (
                (subsequent_labels == orig_top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()

    overall_results["top1_acc"] = (labels == top1)[overall_mask_bool].detach().cpu().numpy()
    if original_labels is not None:
        overall_results["sum_top1_acc"] = (
            ((original_labels == top1) | (labels == top1))[overall_mask_bool]).detach().cpu().numpy()
    if original_logits is not None:
        overall_results["orig_top1_acc"] = (
            (original_labels == orig_top1)[overall_mask_bool]).detach().cpu().numpy()

    if debug:
        import pdb; pdb.set_trace()

    return background_results, target_results, overall_results, successful_targets

def eval_next_word_prediction(
        model, tokenizer, lm_dataset, accelerator=None,
        batch_size: int = 4,
        new_token_ids=None, replaced_token_seqs_by_len=None,
        new_token_to_original_first_token=None,
        max_length: int = 256,
        drop_last: bool = True,
        eval_max_samples: int = None,
        eval_shuffle_samples: bool = False,
        reduction="none",
):
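    """Run next-word prediction over `lm_dataset` and aggregate per-token metrics.

    `new_token_ids`, `replaced_token_seqs_by_len`, and `new_token_to_original_first_token`
    describe newly added vocabulary items and how they map back to the original tokenization;
    when provided, they are used to separate target metrics from background metrics.
    Returns (background_metrics, target_metrics, overall_metrics), each reduced according to
    `reduction` ("none" concatenates per-token values, "mean" averages them).
    """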
    if accelerator is None:
        accelerator = Accelerator()
    model.eval()

    if tokenizer.bos_token is not None and max_length:
        add_start_token = True
    else:
        add_start_token = False

    data_collator = default_data_collator
    if eval_max_samples:
        # Take the first eval_max_samples examples, or a random subset of that size if shuffling
        eval_idx = range(min(eval_max_samples, len(lm_dataset)))
        if eval_shuffle_samples:
            eval_idx = np.random.choice(len(lm_dataset), min(eval_max_samples, len(lm_dataset)), replace=False)
        lm_dataset = lm_dataset.select(eval_idx)
    # Create the data loader
    eval_dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=drop_last, shuffle=False,
    )
    eval_dataloader = accelerator.prepare(eval_dataloader)
    model.eval()

    if new_token_ids is not None:
        new_token_ids = torch.tensor(new_token_ids).to(model.device)
    if replaced_token_seqs_by_len is not None:
        replaced_token_seqs_by_len = {
            token_length: torch.tensor(skip_token_seqs).to(model.device)
            for token_length, skip_token_seqs in replaced_token_seqs_by_len.items()
            if len(skip_token_seqs) > 0}
    if new_token_to_original_first_token is not None:
        # Convert the mapping into a tensor for efficient indexing; it defaults to the identity mapping
        new_token_to_orig_first_mapping_tensor = torch.arange(len(tokenizer), device=model.device)
        new_token_to_orig_first_mapping_tensor[torch.tensor(list(new_token_to_original_first_token.keys()), device=model.device)] = \
            torch.tensor(list(new_token_to_original_first_token.values()), device=model.device)
    target_metrics = defaultdict(list)
    background_metrics = defaultdict(list)
    overall_metrics = defaultdict(list)

    # run eval and compute metrics
    for batch_i, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), miniters=10,
                               desc="Evaluating vocabulary..."):
        if add_start_token:
            batch = _add_start_token(batch, tokenizer)

        labels = batch["input_ids"]
        attn_mask = batch["attention_mask"]
        batch.pop("labels")
        with torch.no_grad():
            outputs = model(**batch)
        out_logits = outputs.logits

        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        shift_attention_mask_batch, ignore_mask = \
            _ignore_new_words_in_attention_mask(
                shift_attention_mask_batch, shift_labels, new_token_ids, replaced_token_seqs_by_len)

        original_labels = None if new_token_to_original_first_token is None \
            else new_token_to_orig_first_mapping_tensor[shift_labels]
        original_logits = None if new_token_ids is None else torch.cat(
            [shift_logits[:, :, :min(new_token_ids)], shift_logits[:, :, max(new_token_ids) + 1:]], dim=-1)

        background_results, target_results, overall_results, successful_targets = \
            compute_metrics(
                shift_logits, shift_labels, shift_attention_mask_batch,
                original_labels=original_labels, original_logits=original_logits, compute_perplexity=True)

        for metric_name, metric_value in target_results.items():
            target_metrics[metric_name].append(np.array(metric_value))
        for metric_name, metric_value in background_results.items():
            background_metrics[metric_name].append(metric_value)
        for metric_name, metric_value in overall_results.items():
            overall_metrics[metric_name].append(metric_value)
    eval_dataloader = accelerator.free_memory(eval_dataloader)

    def _concat_func(x):
        if isinstance(x, np.ndarray) and len(x.shape) > 1:
            x = np.concatenate(x)
        elif isinstance(x, (list, tuple)) and len(x) > 1:
            if isinstance(x[0], np.ndarray) and len(x[0].shape) == 0:
                x = np.array(x)
            else:
                x = np.concatenate(x)
        return x

    # apply reduction
    reduce_func = _concat_func
    if reduction == 'mean':
        reduce_func = lambda x: np.mean(_concat_func(x)).item()

    for metric_name, metric_value in target_metrics.items():
        target_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in background_metrics.items():
        background_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in overall_metrics.items():
        overall_metrics[metric_name] = reduce_func(metric_value)

    return background_metrics, target_metrics, overall_metrics
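

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above): shows how
# eval_next_word_prediction can be called end to end. The checkpoint "gpt2"
# and the "wikitext-2-raw-v1" validation split are illustrative assumptions;
# substitute your own model, tokenizer, and tokenized dataset. Without
# new_token_ids / replaced_token_seqs_by_len, only the background metrics
# (perplexity, top-1 accuracy) are meaningful.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # assumption: any causal LM checkpoint works here
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    accelerator = Accelerator()
    model = accelerator.prepare(model)  # keep the model on the same device as the prepared batches

    raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
    raw = raw.filter(lambda ex: len(ex["text"].split()) > 1)  # drop empty and single-token lines

    def tokenize(examples):
        enc = tokenizer(examples["text"], truncation=True, max_length=256)
        # eval_next_word_prediction pops "labels" before the forward pass, so provide a copy of input_ids
        enc["labels"] = [ids.copy() for ids in enc["input_ids"]]
        return enc

    lm_dataset = raw.map(tokenize, batched=True, remove_columns=raw.column_names)

    background, target, overall = eval_next_word_prediction(
        model, tokenizer, lm_dataset, accelerator=accelerator,
        batch_size=1,  # unpadded, variable-length sequences, so keep one example per batch
        eval_max_samples=64, reduction="mean",
    )
    print("background:", background)
    print("target:", target)
    print("overall:", overall)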