#%%
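# Flag low-probability (likely misspelled) words in a text with a causal LM
# and sample replacement candidates for each flagged word.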
import time
from tqdm import tqdm
from text_processing import split_into_words, Word
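# split_into_words (defined in text_processing, not shown here) is expected to group
# (token, logprob) pairs into Word objects exposing .text, .logprob, and .first_token_index.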
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
from tokenizers import Encoding
from typing import cast
type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
def load_model_and_tokenizer(model_name: str, device: torch.device) -> tuple[PreTrainedModel, Tokenizer]:
    tokenizer: Tokenizer = AutoTokenizer.from_pretrained(model_name)
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    return model, tokenizer
def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    inputs: BatchEncoding = tokenizer(input_text, return_tensors="pt").to(device)
    input_ids = cast(torch.Tensor, inputs["input_ids"])
    attention_mask = cast(torch.Tensor, inputs["attention_mask"])
    return input_ids, attention_mask
def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[str, float]]:
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # logits at position i predict token i + 1, so drop the last position: B x (T-1) x V
    logits: torch.Tensor = outputs.logits[:, :-1, :]
    # B x (T-1) x V
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    # log probability assigned to each actual next token: (T-1,)
    positions = torch.arange(log_probs.shape[1], device=log_probs.device)
    token_log_probs: torch.Tensor = log_probs[0, positions, input_ids[0][1:]]
    # token strings aligned with those log probabilities (callers index into these)
    tokens: list[str] = tokenizer.convert_ids_to_tokens(input_ids[0][1:].tolist())
    return list(zip(tokens, token_log_probs.tolist()))
def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, prefix_tokens: list[int], device: torch.device, num_samples: int = 5) -> list[str]:
    input_ids = torch.tensor([prefix_tokens], device=device)
    # every prefix token is real input (no padding), so the mask is all ones
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=input_ids.shape[-1] + 5,
            num_return_sequences=num_samples,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True,
        )
    new_words = []
    for i in range(num_samples):
        generated_ids = outputs[i][input_ids.shape[-1]:]
        # keep only the first whitespace-delimited word of each sampled continuation
        parts = tokenizer.decode(generated_ids, skip_special_tokens=True).split()
        new_words.append(parts[0] if parts else "")
    return new_words
#%%
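# note: this loads the full Mistral-7B checkpoint; expect a multi-gigabyte
# download and correspondingly large memory use (a smaller model works too)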
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "mistralai/Mistral-7B-v0.1"
model, tokenizer = load_model_and_tokenizer(model_name, device)
#%%
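# "rifused" is an intentional misspelling that the model should flag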
input_text = "He asked me to prostrate myself before the king, but I rifused."
input_ids, attention_mask = tokenize(input_text, tokenizer, device)
#%%
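# score every token of the input: (token string, log probability) pairs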
token_probs: list[tuple[str, float]] = calculate_log_probabilities(model, tokenizer, input_ids, attention_mask)
#%%
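# group token-level scores into words and keep those below the log-prob threshold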
words = split_into_words(token_probs)
log_prob_threshold = -5.0
low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
#%%
start_time = time.time()
# for each flagged word, sample replacement candidates conditioned on the preceding tokens
for word in tqdm(low_prob_words, desc="Processing words"):
    iteration_start_time = time.time()
    prefix_index = word.first_token_index
    prefix_tokens = tokenizer.convert_tokens_to_ids([token for token, _ in token_probs][:prefix_index + 1])
    replacements = generate_replacements(model, tokenizer, prefix_tokens, device)
    print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
    print(f"Proposed replacements: {replacements}")
    print()
    iteration_end_time = time.time()
    print(f"Time taken for this iteration: {iteration_end_time - iteration_start_time:.4f} seconds")
end_time = time.time()
print(f"Total time taken for the loop: {end_time - start_time:.4f} seconds")
# %%