#%%
from text_processing import split_into_words
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from pprint import pprint

#%%

def load_model_and_tokenizer(model_name):
    """Load a causal LM and its tokenizer, moving the model to GPU if available."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # inference only, so disable dropout
    return model, tokenizer, device
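
# Note: a 7B-parameter model in full fp32 precision needs roughly 28 GB of
# memory. If that exceeds the available hardware, a half-precision load is a
# common alternative (a sketch, not part of the pipeline above; assumes a
# CUDA GPU with enough VRAM):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16
#   )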

def process_input_text(input_text, tokenizer, device):
    """Tokenize the input text and move the tensors to the model's device."""
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    return inputs, input_ids

def calculate_log_probabilities(model, tokenizer, inputs, input_ids):
    """Return (token, log_prob) pairs for every token after the first."""
    with torch.no_grad():
        outputs = model(**inputs)
    # The logits at position i are the model's prediction for token i + 1,
    # so drop the last position and score tokens 1..n against them.
    logits = outputs.logits[0, :-1, :]
    log_probs = torch.log_softmax(logits, dim=-1)
    token_log_probs = log_probs[range(log_probs.shape[0]), input_ids[0][1:]]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return list(zip(tokens[1:], token_log_probs.tolist()))
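
# Shape of the result: for input_ids [bos, t1, t2, t3] this returns
# [(t1, log p(t1 | bos)), (t2, log p(t2 | bos, t1)),
#  (t3, log p(t3 | bos, t1, t2))], i.e. one pair per token after the
# leading BOS token, each scored under the model's next-token prediction.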

def generate_replacements(model, tokenizer, prefix, device, num_samples=5):
    """Sample short continuations of `prefix` and keep the first word of each."""
    input_context = tokenizer(prefix, return_tensors="pt").to(device)
    input_ids = input_context["input_ids"]
    new_words = []
    for _ in range(num_samples):
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=input_context["attention_mask"],
                max_new_tokens=5,
                num_return_sequences=1,
                temperature=1.0,
                top_k=50,
                top_p=0.95,
                do_sample=True,
            )
        # Decode only the newly generated tokens and keep the first
        # whitespace-delimited word; a sample can decode to an empty
        # string, so guard against that before indexing.
        generated_ids = outputs[0][input_ids.shape[-1]:]
        continuation = tokenizer.decode(generated_ids, skip_special_tokens=True)
        candidate_words = continuation.split()
        if candidate_words:
            new_words.append(candidate_words[0])
    return new_words
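
# The loop above makes one generate() call per sample. A single batched call
# with num_return_sequences is typically faster; a minimal sketch under the
# same sampling parameters:
#
#   outputs = model.generate(
#       input_ids=input_ids,
#       attention_mask=input_context["attention_mask"],
#       max_new_tokens=5,
#       num_return_sequences=num_samples,
#       temperature=1.0,
#       top_k=50,
#       top_p=0.95,
#       do_sample=True,
#   )
#   # outputs then has shape (num_samples, prompt_length + new_tokens).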

#%%
model_name = "mistralai/Mistral-7B-v0.1"
model, tokenizer, device = load_model_and_tokenizer(model_name)

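# "rifused" is deliberately misspelled: it should receive a low log
# probability and therefore be flagged for replacement below.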
input_text = "He asked me to prostrate myself before the king, but I rifused."
inputs, input_ids = process_input_text(input_text, tokenizer, device)

result = calculate_log_probabilities(model, tokenizer, inputs, input_ids)

words = split_into_words([token for token, _ in result], [logprob for _, logprob in result])
log_prob_threshold = -5.0
low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
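
# Quick look at the words that fell below the threshold:
pprint([(word.text, round(word.logprob, 4)) for word in low_prob_words])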

#%%

for word in low_prob_words:
    # Use the tokens up to the flagged word (indexing follows the
    # split_into_words convention) as the generation prefix, then sample
    # candidate replacements from the model.
    prefix_index = word.first_token_index
    prefix_tokens = [token for token, _ in result][:prefix_index + 1]
    prefix = tokenizer.convert_tokens_to_string(prefix_tokens)
    replacements = generate_replacements(model, tokenizer, prefix, device)
    print(f"Original word: {word.text}, log probability: {word.logprob:.4f}")
    print(f"Proposed replacements: {replacements}")
    print()