#%%
import time

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer

from text_processing import split_into_words, Word

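# This script flags likely typos in a text: it scores every token with a causal
# LM, groups token log-probabilities into words, and samples replacement
# candidates for words whose log-probability falls below a threshold.
#
# text_processing is a local module that is not shown here. From its usage
# below, it is expected to expose roughly this interface (a hypothetical sketch
# inferred from the call sites, not the actual implementation):
#
#     @dataclass
#     class Word:
#         text: str               # the word as written in the input
#         logprob: float          # log-probability score assigned to the word
#         first_token_index: int  # index of the word's first token in the token list
#
#     def split_into_words(tokens: list[str], logprobs: list[float]) -> list[Word]: ...
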
def load_model_and_tokenizer(model_name: str) -> tuple[PreTrainedModel, PreTrainedTokenizer, torch.device]:
    """Load a causal LM and its tokenizer, moving the model to GPU if one is available."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, tokenizer, device

def process_input_text(input_text: str, tokenizer: PreTrainedTokenizer, device: torch.device):
    """Process input text to obtain the encoding plus its input IDs and attention mask."""
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    return inputs, input_ids, attention_mask

def calculate_log_probabilities(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[str, float]]:
    """Score each token of the encoded text, returning (token, log-probability) pairs.

    The first token has no preceding context to condition on, so the result
    covers tokens[1:].
    """
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # Logits at position i are the model's prediction for token i + 1, so drop
    # the final position and align the predictions with the actual next tokens.
    logits = outputs.logits[0, :-1, :]
    log_probs = torch.log_softmax(logits, dim=-1)
    token_log_probs = log_probs[torch.arange(log_probs.shape[0]), input_ids[0][1:]]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return list(zip(tokens[1:], token_log_probs.tolist()))

def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix: str, device: torch.device, num_samples: int = 5) -> list[str]:
    """Sample short continuations of `prefix` and return the first word of each sample."""
    input_context = tokenizer(prefix, return_tensors="pt").to(device)
    input_ids = input_context["input_ids"]
    attention_mask = input_context["attention_mask"]
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=5,
            num_return_sequences=num_samples,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True,
        )
    new_words = []
    for i in range(num_samples):
        # Keep only the newly generated tokens, not the echoed prefix.
        generated_ids = outputs[i][input_ids.shape[-1]:]
        decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
        # Guard against samples that decode to pure whitespace.
        words = decoded.split()
        if words:
            new_words.append(words[0])
    return new_words

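# The cells below run the pipeline end to end. Mistral-7B needs substantial GPU
# memory; for a quick functional test, any Hugging Face causal LM checkpoint
# (e.g. "gpt2") should work in its place, since only the generic
# AutoModelForCausalLM and generate() APIs are used.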
#%%
model_name = "mistralai/Mistral-7B-v0.1"
model, tokenizer, device = load_model_and_tokenizer(model_name)

# "rifused" is misspelled on purpose; it should surface as a low-probability word.
input_text = "He asked me to prostrate myself before the king, but I rifused."
inputs, input_ids, attention_mask = process_input_text(input_text, tokenizer, device)
result = calculate_log_probabilities(model, tokenizer, input_ids, attention_mask)

# Group token-level scores into words and keep the words below the threshold.
words = split_into_words([token for token, _ in result], [logprob for _, logprob in result])
log_prob_threshold = -5.0
low_prob_words = [word for word in words if word.logprob < log_prob_threshold]

#%%
start_time = time.time()
for word in tqdm(low_prob_words, desc="Processing words"):
    iteration_start_time = time.time()
    # Rebuild the text up to and including the flagged word's first token,
    # then sample candidate continuations from that prefix.
    prefix_index = word.first_token_index
    prefix_tokens = [token for token, _ in result][:prefix_index + 1]
    prefix = tokenizer.convert_tokens_to_string(prefix_tokens)
    replacements = generate_replacements(model, tokenizer, prefix, device)
    print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
    print(f"Proposed replacements: {replacements}")
    print()
    iteration_end_time = time.time()
    print(f"Time taken for this iteration: {iteration_end_time - iteration_start_time:.4f} seconds")
end_time = time.time()
print(f"Total time taken for the loop: {end_time - start_time:.4f} seconds")

# %%