import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import default_data_collator
from collections import defaultdict
from tqdm import tqdm
import numpy as np

def is_not_number(s):
    try:
        float(s)  # Try converting the string to a float
        return False  # If conversion is successful, it's a number
    except ValueError:
        return True  # If conversion fails, it's not a number

def get_contexts_ending_with_word(word, dataset):
    result_contexts = []
    word_len = len(word)
    # Iterate over the dataset
    for example in dataset:
        text = example["text"]
        # Find all occurrences of the word in the text
        start = 0
        while True:
            idx = text.find(word, start)
            if idx == -1:
                break
            # Ensure that the word is isolated (not a substring of another word)
            if (idx == 0 or not text[idx - 1].isalnum()) and (
                    idx + word_len == len(text) or not text[idx + word_len].isalnum()):
                # Keep the context truncated so that it ends with this occurrence of the word
                result_contexts.append(text[:idx + word_len].strip())
            start = idx + word_len
    return result_contexts

def get_texts_containing_word(words, dataset):
    result_texts = []
    words_set = set(words)
    # Iterate over the dataset
    for example in dataset:
        if words_set.intersection(set(example["text"].split())):
            result_texts.append(example["text"])
    return result_texts

def compute_topk_token_rank(logits, labels, k=1000):
    # Get the top-k predicted logits and their indices
    topk_logits, topk_indices = torch.topk(logits, k, dim=-1)
    # Expand the labels for comparison
    labels_expanded = labels.unsqueeze(-1).expand_as(topk_indices)
    # Check if the label token is within the top-k predictions
    rank_in_topk = (topk_indices == labels_expanded).nonzero(as_tuple=False)
    # Create a rank tensor initialized with k (labels outside the top-k keep this default rank)
    ranks = torch.full(labels.shape, k, dtype=torch.long, device=logits.device)
    # For labels in the top-k, set the 1-based rank accordingly
    ranks[rank_in_topk[:, 0], rank_in_topk[:, 1]] = rank_in_topk[:, 2] + 1
    return ranks

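
# Illustrative usage (sketch, not part of the original code): for logits of shape
# (batch, seq_len, vocab_size) and labels of shape (batch, seq_len), the result has shape
# (batch, seq_len) with 1-based ranks in [1, k].
#   logits = torch.randn(2, 8, 32000)
#   labels = torch.randint(0, 32000, (2, 8))
#   ranks = compute_topk_token_rank(logits, labels, k=1000)
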
def count_tokens_in_dataset(dataset, tokenizer, text_column='text'):
    def tokenize_and_count(examples):
        return {'num_tokens': [len(tokenizer(ex).input_ids) for ex in examples[text_column]]}

    tokenized_dataset = dataset.map(tokenize_and_count, batched=True, remove_columns=dataset.column_names)
    total_tokens = sum(tokenized_dataset['num_tokens'])
    return total_tokens

def filter_single_token_words(array, tokenizer, add_space_prefix_for_lower=True):
    # `array` is expected to be a pandas Series of words; returns the multi-token words and the per-word token counts
    def _count_tokens(word):
        if add_space_prefix_for_lower and word[0].islower():
            word = " " + word
        return len(tokenizer.encode(word, add_special_tokens=False))

    token_counts = array.apply(_count_tokens)
    mask = token_counts > 1
    return array[mask], token_counts

def get_last_zero_in_every_seq_mask(tensor):
    """Given a binary mask, return a mask that is 1 only at the last 0 of every run of 0s.

    In compute_metrics this is used (together with a one-position shift) to evaluate only the
    first token following each new/replaced word, whose positions are marked with 0 in the
    attention mask. Note that a run of 0s that reaches the end of the sequence is not marked.
    """
    # Find where consecutive zeros end
    zero_mask = (tensor == 0)
    diff = torch.diff(zero_mask.int(), dim=1)
    last_zero_mask = torch.cat([diff, torch.ones(tensor.size(0), 1, dtype=diff.dtype).to(tensor.device)], dim=1) == -1
    # Create the output
    output = 1 - tensor
    output[zero_mask & ~last_zero_mask] = 0
    return output

def get_first_zero_in_every_seq_mask(tensor):
    """Given a binary mask, return a mask that is 1 only at the first 0 of every run of 0s.

    In compute_metrics this is used to evaluate only the first token of each new/replaced word,
    whose positions are marked with 0 in the attention mask.
    """
    # Identify where consecutive zeros begin
    zero_mask = (tensor == 0)
    diff = torch.diff(zero_mask.int(), dim=1, prepend=torch.zeros(tensor.size(0), 1, dtype=torch.int).to(tensor.device))
    first_zero_mask = diff == 1  # Marks the beginning of each run of zeros
    # Create the output
    output = 1 - tensor
    output[zero_mask & ~first_zero_mask] = 0
    return output

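
# Illustrative example (sketch, not part of the original code) of the two mask helpers above,
# for a mask in which 0 marks the positions of new/replaced tokens:
#   t = torch.tensor([[1, 0, 0, 1, 0, 1]])
#   get_first_zero_in_every_seq_mask(t)  ->  [[0, 1, 0, 0, 1, 0]]   (first 0 of each zero-run)
#   get_last_zero_in_every_seq_mask(t)   ->  [[0, 0, 1, 0, 1, 0]]   (last 0 of each zero-run)
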
def _add_start_token(batch, tokenizer):
    bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
    batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
    batch["attention_mask"] = torch.cat(
        [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1)
    return batch

def _ignore_new_words_in_attention_mask(shift_attention_mask_batch, shift_labels, new_token_ids=None, replaced_token_seqs_by_len=None):
    # Ignore token_ids of new vocabulary words in shift_labels and shift_logits
    ignore_mask = shift_attention_mask_batch  # fallback so ignore_mask is defined even if no filtering applies
    if new_token_ids is not None:
        ignore_mask = torch.isin(shift_labels, new_token_ids)
        shift_attention_mask_batch = shift_attention_mask_batch * (~ignore_mask).long()
    # Ignore multi-token sequences that were replaced with a single token
    if replaced_token_seqs_by_len is not None:
        # Create a mask that will be updated where sequences match
        ignore_mask = shift_attention_mask_batch.clone()  # Clone the attention mask to modify it
        # Loop over the replaced sequences, grouped by length
        for seq_len, seqs in replaced_token_seqs_by_len.items():
            # Slide a window of the same size as the sequences and check for matches
            for i in range(shift_labels.size(1) - seq_len + 1):
                # Check if the sequence matches at position i
                window = shift_labels[:, i:i + seq_len]
                curr_mask = torch.all(window.unsqueeze(1) == seqs.unsqueeze(0), dim=-1)
                if curr_mask.any():
                    # Zero out the ignore mask for the length of the sequence
                    ignore_mask[curr_mask.any(dim=-1), i:i + seq_len] = 0
        # Apply the ignore mask to the attention mask
        shift_attention_mask_batch *= ignore_mask
    return shift_attention_mask_batch, ignore_mask

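
# Illustrative example (sketch, not part of the original code): with
#   new_token_ids         = torch.tensor([50257])
#   shift_labels          = [[  17, 50257,   42]]
#   shift_attention_mask  = [[   1,     1,    1]]
# the returned attention mask is [[1, 0, 1]], i.e. positions whose label is a newly added token
# (or part of a replaced multi-token sequence) are excluded from the metrics computed below.
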
# TODO consider not aggregating results here, to enable metrics for specific words
def compute_metrics(
        logits, labels, attention_mask,
        compute_target_metrics=True, compute_subsequent_metrics=True, compute_perplexity=False,
        return_successful_targets=False,
        original_labels=None, original_logits=None,
        debug=False):
    target_results = dict()  # will hold metrics for the new words we add, or for their original tokenization
    background_results = dict()  # will hold metrics for all background tokens, i.e., not the ones we add or replace
    overall_results = dict()  # will hold metrics for all tokens
    successful_targets = None  # will hold the list of target tokens that were successfully predicted
    if compute_subsequent_metrics:
        # prepare labels and attention masks for computing metrics only on the first tokens following the new words
        subsequent_labels = labels[:, 1:]
        subsequent_attention_mask = get_last_zero_in_every_seq_mask(attention_mask[..., :-1].contiguous())
        subsequent_attention_mask_bool = subsequent_attention_mask == 1
    attention_mask_bool = attention_mask == 1
    overall_mask_bool = attention_mask_bool
    if compute_target_metrics:
        target_mask = get_first_zero_in_every_seq_mask(attention_mask)
        target_mask_bool = target_mask == 1
        overall_mask_bool = attention_mask_bool | target_mask_bool
    if compute_perplexity:
        background_results["perplexity"] = torch.exp(
            (F.cross_entropy(logits.transpose(1, 2), labels, reduction="none") * attention_mask).sum(1)
            / attention_mask.sum(1)
        ).mean().detach().cpu().numpy()

    top1 = logits.argmax(dim=-1)
    if original_logits is not None:
        orig_top1 = original_logits.argmax(dim=-1)

    if compute_target_metrics:
        target_results["top1_acc"] = ((labels == top1)[target_mask_bool]).detach().cpu().numpy()
        if original_labels is not None:
            target_results["sum_top1_acc"] = (
                ((original_labels == top1) | (labels == top1))[target_mask_bool]).detach().cpu().numpy()
        if original_logits is not None:
            target_results["orig_top1_acc"] = (
                (original_labels == orig_top1)[target_mask_bool]).detach().cpu().numpy()
        if return_successful_targets:
            successful_targets = (labels[(labels == top1) & target_mask_bool]).detach().cpu().numpy()

    background_results["top1_acc"] = ((labels == top1)[attention_mask_bool]).detach().cpu().numpy()
    if compute_subsequent_metrics:
        background_results["subsequent_top1_acc"] = ((subsequent_labels == top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()
    if original_logits is not None:
        background_results["orig_top1_acc"] = (
            (original_labels == orig_top1)[attention_mask_bool]).detach().cpu().numpy()
        if compute_subsequent_metrics:
            background_results["orig_subsequent_top1_acc"] = (
                (subsequent_labels == orig_top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()

    overall_results["top1_acc"] = ((labels == top1)[overall_mask_bool]).detach().cpu().numpy()
    if original_labels is not None:
        overall_results["sum_top1_acc"] = (
            ((original_labels == top1) | (labels == top1))[overall_mask_bool]).detach().cpu().numpy()
    if original_logits is not None:
        overall_results["orig_top1_acc"] = (
            (original_labels == orig_top1)[overall_mask_bool]).detach().cpu().numpy()

    if debug:
        import pdb; pdb.set_trace()
    return background_results, target_results, overall_results, successful_targets

def eval_next_word_prediction(
        model, tokenizer, lm_dataset, accelerator=None,
        batch_size: int = 4,
        new_token_ids=None, replaced_token_seqs_by_len=None,
        new_token_to_original_first_token=None,
        max_length: int = 256,
        drop_last: bool = True,
        eval_max_samples: int = None,
        eval_shuffle_samples: bool = False,
        reduction="none",
):
    if accelerator is None:
        accelerator = Accelerator()
    model.eval()
    # Prepend the BOS token to every sequence if the tokenizer defines one
    if tokenizer.bos_token is not None and max_length:
        add_start_token = True
    else:
        add_start_token = False
    data_collator = default_data_collator
    # Optionally evaluate on a subset of the dataset
    if eval_max_samples:
        eval_idx = range(min(eval_max_samples, len(lm_dataset)))
        if eval_shuffle_samples:
            eval_idx = np.random.choice(len(lm_dataset), min(eval_max_samples, len(lm_dataset)), replace=False)
        lm_dataset = lm_dataset.select(eval_idx)
    # Create data loaders
    eval_dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=drop_last, shuffle=False,
    )
    eval_dataloader = accelerator.prepare(eval_dataloader)
    if new_token_ids is not None:
        new_token_ids = torch.tensor(new_token_ids).to(model.device)
    if replaced_token_seqs_by_len is not None:
        replaced_token_seqs_by_len = {
            token_length: torch.tensor(skip_token_seqs).to(model.device)
            for token_length, skip_token_seqs in replaced_token_seqs_by_len.items() if len(skip_token_seqs) > 0}
    if new_token_to_original_first_token is not None:
        # Convert the mapping into a tensor for efficient indexing; it defaults to the identity mapping
        new_token_to_orig_first_mapping_tensor = torch.arange(len(tokenizer), device=model.device)
        new_token_to_orig_first_mapping_tensor[torch.tensor(list(new_token_to_original_first_token.keys()), device=model.device)] = \
            torch.tensor(list(new_token_to_original_first_token.values()), device=model.device)

    target_metrics = defaultdict(list)
    background_metrics = defaultdict(list)
    overall_metrics = defaultdict(list)
    # run eval and compute metrics
    for batch_i, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), miniters=10, desc="Evaluating vocabulary..."):
        if add_start_token:
            batch = _add_start_token(batch, tokenizer)
        labels = batch["input_ids"]
        attn_mask = batch["attention_mask"]
        batch.pop("labels")
        with torch.no_grad():
            outputs = model(**batch)
        out_logits = outputs.logits
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
        shift_attention_mask_batch, ignore_mask = \
            _ignore_new_words_in_attention_mask(
                shift_attention_mask_batch, shift_labels, new_token_ids, replaced_token_seqs_by_len)
        original_labels = None if new_token_to_original_first_token is None \
            else new_token_to_orig_first_mapping_tensor[shift_labels]
        # Logits restricted to the original vocabulary; this slicing assumes the new token ids form a contiguous range
        original_logits = None if new_token_ids is None \
            else torch.cat([shift_logits[:, :, :new_token_ids.min()], shift_logits[:, :, new_token_ids.max() + 1:]], dim=-1)
        background_results, target_results, overall_results, successful_targets = \
            compute_metrics(
                shift_logits, shift_labels, shift_attention_mask_batch,
                original_labels=original_labels, original_logits=original_logits, compute_perplexity=True)
        for metric_name, metric_value in target_results.items():
            target_metrics[metric_name].append(np.array(metric_value))
        for metric_name, metric_value in background_results.items():
            background_metrics[metric_name].append(metric_value)
        for metric_name, metric_value in overall_results.items():
            overall_metrics[metric_name].append(metric_value)

    eval_dataloader = accelerator.free_memory(eval_dataloader)
    def _concat_func(x):
        if isinstance(x, np.ndarray) and len(x.shape) > 1:
            x = np.concatenate(x)
        elif isinstance(x, (list, tuple)) and len(x) > 1:
            if isinstance(x[0], np.ndarray) and len(x[0].shape) == 0:
                x = np.array(x)
            else:
                x = np.concatenate(x)
        return x

    # apply reduction
    reduce_func = _concat_func
    if reduction == 'mean':
        reduce_func = lambda x: np.mean(_concat_func(x)).item()
    for metric_name, metric_value in target_metrics.items():
        target_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in background_metrics.items():
        background_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in overall_metrics.items():
        overall_metrics[metric_name] = reduce_func(metric_value)
    return background_metrics, target_metrics, overall_metrics
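

# Minimal usage sketch (not part of the original module). The model name, dataset, and
# tokenization settings below are illustrative assumptions; any causal LM and any tokenized
# dataset with input_ids/attention_mask/labels columns should work.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from datasets import load_dataset

    model_name = "gpt2"  # assumption: a small causal LM for a quick smoke test
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)

    accelerator = Accelerator()
    model = accelerator.prepare(model)

    # Assumption: a small text corpus, tokenized into fixed-length examples
    raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation[:1%]")
    raw = raw.filter(lambda ex: len(ex["text"].strip()) > 0)  # drop empty lines

    def tokenize(examples):
        enc = tokenizer(examples["text"], truncation=True, max_length=128, padding="max_length")
        enc["labels"] = [ids.copy() for ids in enc["input_ids"]]
        return enc

    lm_dataset = raw.map(tokenize, batched=True, remove_columns=raw.column_names)

    background, target, overall = eval_next_word_prediction(
        model, tokenizer, lm_dataset,
        accelerator=accelerator, batch_size=4, reduction="mean",
    )
    print("background metrics:", background)
    print("target metrics:", target)
    print("overall metrics:", overall)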