#%%
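# Score each word of an input sentence with a causal LM, flag words with low total
# log probability, and sample candidate replacements for each flagged word from its
# preceding context.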
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
from transformers.generation.utils import GenerateOutput
from typing import cast
from dataclasses import dataclass
type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
@dataclass
class Word:
    tokens: list[int]
    text: str
    logprob: float
    context: list[int]
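# Group per-token log probabilities into words: a word's logprob is the sum of its
# tokens' log probabilities, and its context is every token that precedes it.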
def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
    words: list[Word] = []
    current_word: list[int] = []
    current_log_probs: list[float] = []
    current_word_first_token_index: int = 0
    all_tokens: list[int] = [token_id for token_id, _ in token_probs]

    def append_current_word():
        if current_word:
            words.append(Word(current_word,
                              tokenizer.decode(current_word),
                              sum(current_log_probs),
                              all_tokens[:current_word_first_token_index]))

    for i, (token_id, logprob) in enumerate(token_probs):
        token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
        # SentencePiece marks word-initial pieces with "▁" (U+2581); an alphabetic
        # token without that prefix continues the current word.
        if not token.startswith(chr(9601)) and token.isalpha():
            current_word.append(token_id)
            current_log_probs.append(logprob)
        else:
            append_current_word()
            current_word = [token_id]
            current_log_probs = [logprob]
            current_word_first_token_index = i
    append_current_word()
    return words
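# Left padding is required so batched generation continues from the end of each prompt;
# the pad token is set to EOS because the tokenizer does not define one.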
def load_model_and_tokenizer(model_name: str, device: torch.device) -> tuple[PreTrainedModel, Tokenizer]:
    tokenizer: Tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    return model, tokenizer
def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    inputs: BatchEncoding = tokenizer(input_text, return_tensors="pt").to(device)
    input_ids = cast(torch.Tensor, inputs["input_ids"])
    attention_mask = cast(torch.Tensor, inputs["attention_mask"])
    return input_ids, attention_mask
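# One forward pass over the full sequence gives, at each position, the log probability
# the model assigned to the token that actually follows it.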
def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[int, float]]:
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    # B x (T-1) x V: drop the last position, which predicts beyond the sequence
    logits: torch.Tensor = outputs.logits[:, :-1, :]
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    # T-1: log probability of each actual next token under the model
    token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]
    # T-1: the tokens those log probabilities correspond to (everything but the first)
    tokens: torch.Tensor = input_ids[0][1:]
    return list(zip(tokens.tolist(), token_log_probs.tolist()))
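# Decode each context back to text and re-tokenize as a padded batch so contexts of
# different lengths can be fed to generate() together.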
def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(device)
    input_ids = cast(torch.Tensor, inputs["input_ids"])
    attention_mask = cast(torch.Tensor, inputs["attention_mask"])
    return input_ids, attention_mask
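# For each context, sample num_samples short continuations and keep the first
# whitespace-delimited word of each as a candidate replacement.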
def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, contexts: list[list[int]],
                          device: torch.device, num_samples: int = 5) -> tuple[GenerateOutput | torch.LongTensor, list[list[str]]]:
    input_ids, attention_mask = prepare_inputs(contexts, tokenizer, device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=input_ids.shape[-1] + 5,
            num_return_sequences=num_samples,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True
        )
    all_new_words = []
    for i in range(len(contexts)):
        replacements = []
        for j in range(num_samples):
            # generate() groups return sequences by input: sample j for context i
            # sits at row i * num_samples + j; slice off the (left-padded) prompt.
            generated_ids = outputs[i * num_samples + j][input_ids.shape[-1]:]
            new_word = tokenizer.decode(generated_ids, skip_special_tokens=False).split()[0]
            replacements.append(new_word)
        all_new_words.append(replacements)
    return outputs, all_new_words
#%%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "mistralai/Mistral-7B-v0.1"
model, tokenizer = load_model_and_tokenizer(model_name, device)
#%%
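# Example input with a deliberate misspelling ("rifused") for the model to flag.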
input_text = "He asked me to prostrate myself before the king, but I rifused."
input_ids, attention_mask = tokenize(input_text, tokenizer, device)
#%%
token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, input_ids, attention_mask)
#%%
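# Words whose summed log probability falls below the threshold are treated as suspect.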
words = split_into_words(token_probs, tokenizer)
log_prob_threshold = -5.0
low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
#%%
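# The context for each flagged word is every token that precedes it in the input.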
contexts = [word.context for word in low_prob_words]
#%%
start_time = time.time()
_, replacements_batch = generate_replacements(model, tokenizer, contexts, device, num_samples=5)
end_time = time.time()
print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
#%%
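# Report each flagged word alongside its sampled replacement candidates.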
for word, replacements in zip(low_prob_words, replacements_batch):
    print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
    print(f"Proposed replacements: {replacements}")
# %%