#%%
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
from transformers.generation.utils import GenerateOutput
from typing import cast
from dataclasses import dataclass
from models import ApiWord, Word
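# Word and ApiWord come from the local `models` module; from their use below,
# Word is assumed to carry (token_ids, text, logprob, context) and ApiWord
# the fields text / logprob / replacements.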
type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
def starts_with_space(token: str) -> bool:
    # chr(9601) is "▁" (the SentencePiece word-boundary marker) and chr(288)
    # is "Ġ" (the GPT-2/BPE space marker); either prefix means the token
    # begins a new whitespace-separated word.
    return token.startswith(chr(9601)) or token.startswith(chr(288))

def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
    """Group consecutive tokens into whitespace-separated words, summing their log-probs."""
    words: list[Word] = []
    current_word: list[int] = []
    current_log_probs: list[float] = []
    current_word_first_token_index: int = 0
    all_tokens: list[int] = [token_id for token_id, _ in token_probs]

    def append_current_word():
        if current_word:
            words.append(Word(current_word,
                              tokenizer.decode(current_word),
                              sum(current_log_probs),
                              all_tokens[:current_word_first_token_index]))

    for i, (token_id, logprob) in enumerate(token_probs):
        token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
        if not starts_with_space(token) and token.isalpha():
            # Continuation of the current word.
            current_word.append(token_id)
            current_log_probs.append(logprob)
        else:
            # This token starts a new word: flush the previous one first.
            append_current_word()
            current_word = [token_id]
            current_log_probs = [logprob]
            current_word_first_token_index = i
    append_current_word()
    return words
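# Worked example (hypothetical BPE tokenization): for tokens
# ["The", "Ġquick", "Ġfo", "xes"], "xes" has no space marker and is
# alphabetic, so it merges into the previous word, yielding the words
# "The", " quick", " foxes". Each Word's logprob is the sum over its
# tokens, and its context is every token that precedes the word.
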
def load_model_and_tokenizer(model_name: str, device: torch.device) -> tuple[PreTrainedModel, Tokenizer]:
    tokenizer: Tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    # Decoder-only models usually ship without a pad token; reuse EOS so that
    # batched generation with left padding works.
    tokenizer.pad_token = tokenizer.eos_token
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    return model, tokenizer

def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
    return tokenizer(input_text, return_tensors="pt").to(device)

def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, inputs: BatchEncoding) -> list[tuple[int, float]]:
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    # B x (T - 1) x V: logits at position t predict token t + 1, so drop the last position
    logits: torch.Tensor = outputs.logits[:, :-1, :]
    # B x (T - 1) x V
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    # T - 1: log-probability the model assigned to each actual next token
    token_log_probs: torch.Tensor = log_probs[0, torch.arange(log_probs.shape[1]), input_ids[0][1:]]
    # T - 1: the tokens those log-probs belong to (everything but the first)
    tokens: torch.Tensor = input_ids[0][1:]
    return list(zip(tokens.tolist(), token_log_probs.tolist()))

def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
    return tokenizer(texts, return_tensors="pt", padding=True).to(device)

def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples: int = 5) -> GenerateOutput | torch.LongTensor:
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=4,
            num_return_sequences=num_samples,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True
            # num_beams=num_samples
        )
    return outputs

def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
    all_new_words: list[list[str]] = []
    for i in range(num_inputs):
        replacements: set[str] = set()
        for j in range(num_samples):
            # Samples for input i are laid out contiguously in the batch.
            generated_ids = outputs[i * num_samples + j][input_len:]
            new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
            # Keep only samples whose first token begins a fresh word.
            if starts_with_space(new_word):
                replacements.add(" " + new_word[1:])
        all_new_words.append(sorted(replacements))
    return all_new_words
#%%
def load_model() -> tuple[PreTrainedModel, Tokenizer, torch.device]:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model_name = "mistralai/Mistral-7B-v0.1"
    model_name = "unsloth/Llama-3.2-1B"
    model, tokenizer = load_model_and_tokenizer(model_name, device)
    return model, tokenizer, device

def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, device: torch.device) -> list[ApiWord]:
    # Score every token of the input under the model.
    inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
    token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)

    # Group tokens into words and keep the improbable ones.
    words = split_into_words(token_probs, tokenizer)
    log_prob_threshold = -5.0
    low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]

    # Batch one context per low-probability word and sample continuations.
    contexts = [word.context for _, word in low_prob_words]
    inputs = prepare_inputs(contexts, tokenizer, device)
    input_ids = inputs["input_ids"]

    num_samples = 10
    start_time = time.time()
    outputs = generate_outputs(model, inputs, num_samples)
    end_time = time.time()
    print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")

    # Attach the sampled replacements to their words.
    replacements = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)
    replacements_by_index = {i: r for (i, _), r in zip(low_prob_words, replacements)}

    result: list[ApiWord] = []
    for i, word in enumerate(words):
        result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=replacements_by_index.get(i, [])))
    return result
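#%%
# Minimal usage sketch. Assumes this file is importable alongside `models`;
# the sample sentence is illustrative, not from the original project.
if __name__ == "__main__":
    model, tokenizer, device = load_model()
    for word in check_text("I enjoy walking my dog on the beech.", model, tokenizer, device):
        flag = " <-- " + ", ".join(word.replacements) if word.replacements else ""
        print(f"{word.text!r}: {word.logprob:.2f}{flag}")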