#%%
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
from transformers.generation.utils import GenerateOutput
from typing import cast
from dataclasses import dataclass
type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
@dataclass
class Word:
    """A word, the token ids that form it, its total log probability,
    and the token ids of everything that precedes it."""
    tokens: list[int]
    text: str
    logprob: float
    context: list[int]
def starts_with_space(token: str) -> bool:
    # Word-initial tokens carry a marker: "▁" (U+2581, SentencePiece) or
    # "Ġ" (U+0120, byte-level BPE as used by GPT-2 and Llama 3).
    return token.startswith(chr(9601)) or token.startswith(chr(288))
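# Illustrative only (the exact split depends on the tokenizer's vocabulary):
# a byte-level BPE vocabulary might split " refused" into something like
# ["Ġref", "used"], where the first token opens a word and the second,
# carrying no marker, continues it.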
def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
    words: list[Word] = []
    current_word: list[int] = []
    current_log_probs: list[float] = []
    current_word_first_token_index: int = 0
    all_tokens: list[int] = [token_id for token_id, _ in token_probs]

    def append_current_word():
        if current_word:
            words.append(Word(current_word,
                              tokenizer.decode(current_word),
                              sum(current_log_probs),
                              all_tokens[:current_word_first_token_index]))

    for i, (token_id, logprob) in enumerate(token_probs):
        token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
        if not starts_with_space(token) and token.isalpha():
            # Continuation token: extend the current word.
            current_word.append(token_id)
            current_log_probs.append(logprob)
        else:
            # Word-initial token (or punctuation): flush the previous word
            # and start a new one.
            append_current_word()
            current_word = [token_id]
            current_log_probs = [logprob]
            current_word_first_token_index = i
    append_current_word()
    return words
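# Worked example (hypothetical token split): if the tokenizer happens to return
# " rifused." as ["Ġrif", "used", "."], then "Ġrif" opens a new word, "used" is
# merged into it, and "." is flushed as its own single-token "word". The Word
# for "rifused" gets logprob = logP("Ġrif") + logP("used") and, as context,
# every token id that precedes "Ġrif".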
def load_model_and_tokenizer(model_name: str, device: torch.device) -> tuple[PreTrainedModel, Tokenizer]:
    # Left padding keeps all prompts right-aligned so batched generation appends
    # new tokens at a common offset; the pad token falls back to EOS because
    # many causal LMs define none.
    tokenizer: Tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    return model, tokenizer
def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
    return tokenizer(input_text, return_tensors="pt").to(device)
def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, inputs: BatchEncoding) -> list[tuple[int, float]]:
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # B x (T - 1) x V: logits at position t predict token t + 1, so drop the last position
    logits: torch.Tensor = outputs.logits[:, :-1, :]
    # B x (T - 1) x V
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    # T - 1: log probability the model assigned to each actual next token
    token_log_probs: torch.Tensor = log_probs[0, torch.arange(log_probs.shape[1]), input_ids[0][1:]]
    # T - 1: the first token has no preceding context, so it gets no score
    tokens: torch.Tensor = input_ids[0][1:]
    return list(zip(tokens.tolist(), token_log_probs.tolist()))
def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
    # Round-trip each context through text so the batch can be re-tokenized
    # together with padding.
    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
    return tokenizer(texts, return_tensors="pt", padding=True).to(device)
def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples: int = 5) -> GenerateOutput | torch.LongTensor:
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=4,
            num_return_sequences=num_samples,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True
            # num_beams=num_samples
        )
    return outputs
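# The sampling settings above are illustrative defaults: temperature 1.0 with
# top-k 50 / top-p 0.95 trades diversity against plausibility. Enabling the
# commented-out num_beams (and dropping do_sample) would instead return the
# num_samples highest-scoring continuations deterministically.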
def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
    all_new_words = []
    for i in range(num_inputs):
        replacements = []
        for j in range(num_samples):
            # Left padding means every prompt occupies the first input_len
            # positions, so slicing there isolates the generated tokens.
            generated_ids = outputs[i * num_samples + j][input_len:]
            # Keep only the first generated token, stripped of its space
            # marker; samples that do not start a fresh word are discarded.
            new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
            if starts_with_space(new_word):
                replacements.append(new_word[1:])
        all_new_words.append(replacements)
    return all_new_words
#%%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_name = "mistralai/Mistral-7B-v0.1"
model_name = "unsloth/Llama-3.2-1B"
model, tokenizer = load_model_and_tokenizer(model_name, device)
#%%
input_text = "He asked me to prostrate myself before the king, but I rifused."  # "rifused" is a deliberate misspelling
inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
#%%
token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)
#%%
words = split_into_words(token_probs, tokenizer)
log_prob_threshold = -5.0
low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
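#%%
# Optional sanity check (assumes the cells above have run): the deliberately
# misspelled word should appear among the low-probability words.
for w in low_prob_words:
    print(f"{w.text!r}: {w.logprob:.2f}")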
#%%
contexts = [word.context for word in low_prob_words]
inputs = prepare_inputs(contexts, tokenizer, device)
input_ids = inputs["input_ids"]
#%%
num_samples = 5
start_time = time.time()
outputs = generate_outputs(model, inputs, num_samples)
end_time = time.time()
print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
#%%
replacements_batch = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)
#%%
for word, replacements in zip(low_prob_words, replacements_batch):
    print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
    print(f"Proposed replacements: {replacements}")
# %%
# Inspect the raw generations token by token.
generated_ids = outputs[:, input_ids.shape[-1]:]
for g in generated_ids:
    print(tokenizer.convert_ids_to_tokens(g.tolist()))
# %%