mebubo committed on
Commit 8e36e52 · 1 Parent(s): b2b409d

refactor: Modularize code by creating functions for model loading and processing

Files changed (1)
  1. app.py +47 -96
app.py CHANGED
@@ -6,117 +6,68 @@ from pprint import pprint
 
 #%%
 
-model_name="mistralai/Mistral-7B-v0.1"
-
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
-model = AutoModelForCausalLM.from_pretrained(model_name)
-
-# Move the model to GPU if available
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
-
-#%%
-
-input_text = "I just drive to the store to but eggs, but they had some."
-input_text = "He asked me to prostate myself before the king, but I rifused."
-input_text = "He asked me to prostrate myself before the king, but I rifused."
-
-#%%
-inputs = tokenizer(input_text, return_tensors="pt").to(device)
-input_ids = inputs["input_ids"]
-labels = input_ids
-
-#%%
-with torch.no_grad():
-    outputs = model(**inputs, labels=labels)
-
-#%%
-
-# Get logits and shift them
-logits = outputs.logits[0, :-1, :]
-
-# Calculate log probabilities
-log_probs = torch.log_softmax(logits, dim=-1)
-
-# Get the log probability of each token in the sequence
-token_log_probs = log_probs[range(log_probs.shape[0]), input_ids[0][1:]]
-
-# Decode tokens and pair them with their log probabilities
-tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
-result = list(zip(tokens[1:], token_log_probs.tolist()))
-
-#%%
-for token, logprob in result:
-    print(f"Token: {token}, Log Probability: {logprob:.4f}")
-
-# %%
-words = []
-current_word = []
-current_log_probs = []
-
-for token, logprob in result:
-    if not token.startswith(chr(9601)) and token.isalpha():
-        current_word.append(token)
-        current_log_probs.append(logprob)
-    else:
-        if current_word:
-            words.append(("".join(current_word), sum(current_log_probs)))
-        current_word = [token]
-        current_log_probs = [logprob]
-
-if current_word:
-    words.append(("".join(current_word), sum(current_log_probs)))
-
-for word, avg_logprob in words:
-    print(f"Word: {word}, Log Probability: {avg_logprob:.4f}")
-
-# %%
-
-
-words = split_into_words(tokens[1:], token_log_probs)
-
-# Define a threshold for low probability words
-log_prob_threshold = -5.0
-
-# Filter words with log probability below the threshold
-low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
-
-
-#%%
-def generate_replacements(model, tokenizer, prefix, num_samples=5):
+def load_model_and_tokenizer(model_name):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    return model, tokenizer, device
+
+def process_input_text(input_text, tokenizer, device):
+    inputs = tokenizer(input_text, return_tensors="pt").to(device)
+    input_ids = inputs["input_ids"]
+    return inputs, input_ids
+
+def calculate_log_probabilities(model, inputs, input_ids):
+    with torch.no_grad():
+        outputs = model(**inputs, labels=input_ids)
+    logits = outputs.logits[0, :-1, :]
+    log_probs = torch.log_softmax(logits, dim=-1)
+    token_log_probs = log_probs[range(log_probs.shape[0]), input_ids[0][1:]]
+    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
+    return list(zip(tokens[1:], token_log_probs.tolist()))
+
+def generate_replacements(model, tokenizer, prefix, device, num_samples=5):
     input_context = tokenizer(prefix, return_tensors="pt").to(device)
     input_ids = input_context["input_ids"]
-
     new_words = []
     for _ in range(num_samples):
         with torch.no_grad():
             outputs = model.generate(
                 input_ids=input_ids,
-                max_length=input_ids.shape[-1] + 5, # generate a few tokens beyond the prefix
+                max_length=input_ids.shape[-1] + 5,
                 num_return_sequences=1,
                 temperature=1.0,
-                top_k=50, # use top-k sampling
-                top_p=0.95, # use nucleus sampling
+                top_k=50,
+                top_p=0.95,
                 do_sample=True
             )
-
-        generated_ids = outputs[0][input_ids.shape[-1]:] # extract the newly generated part
+        generated_ids = outputs[0][input_ids.shape[-1]:]
         new_word = tokenizer.decode(generated_ids, skip_special_tokens=True).split()[0]
         new_words.append(new_word)
-
     return new_words
 
-# Generate new words for low probability words
-for word in low_prob_words:
-    prefix_index = word.first_token_index
-    prefix_tokens = tokens[:prefix_index + 1] # include the word itself
-    prefix = tokenizer.convert_tokens_to_string(prefix_tokens)
-
-    replacements = generate_replacements(model, tokenizer, prefix)
-
-    print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
-    print(f"Proposed replacements: {replacements}")
-    print()
-
-# %%
+def main():
+    model_name = "mistralai/Mistral-7B-v0.1"
+    model, tokenizer, device = load_model_and_tokenizer(model_name)
+
+    input_text = "He asked me to prostrate myself before the king, but I rifused."
+    inputs, input_ids = process_input_text(input_text, tokenizer, device)
+
+    result = calculate_log_probabilities(model, inputs, input_ids)
+
+    words = split_into_words([token for token, _ in result], [logprob for _, logprob in result])
+    log_prob_threshold = -5.0
+    low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
+
+    for word in low_prob_words:
+        prefix_index = word.first_token_index
+        prefix_tokens = [token for token, _ in result][:prefix_index + 1]
+        prefix = tokenizer.convert_tokens_to_string(prefix_tokens)
+        replacements = generate_replacements(model, tokenizer, prefix, device)
+        print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
+        print(f"Proposed replacements: {replacements}")
+        print()
+
+if __name__ == "__main__":
+    main()
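
Note: the new main() calls split_into_words(...) and then reads word.logprob, word.first_token_index, and word.text, but that helper is not defined anywhere in this diff. Purely as a hypothetical reconstruction based on the inline token-merging loop this commit removes, it might look something like the sketch below; the Word container and every name in it are assumptions for illustration, not code from the repository.

from dataclasses import dataclass

@dataclass
class Word:
    text: str                 # assumed field: main() prints word.text
    logprob: float            # summed log-probability of the word's tokens
    first_token_index: int    # index of the word's first token in the token list passed in

def split_into_words(tokens, log_probs):
    # Mirrors the removed inline loop: a token that does not start with the
    # SentencePiece word-boundary marker (chr(9601), "▁") and is purely
    # alphabetic continues the current word; anything else starts a new word.
    words = []
    current_tokens, current_log_probs, current_start = [], [], 0
    for i, (token, logprob) in enumerate(zip(tokens, log_probs)):
        if not token.startswith(chr(9601)) and token.isalpha():
            current_tokens.append(token)
            current_log_probs.append(logprob)
        else:
            if current_tokens:
                words.append(Word("".join(current_tokens), sum(current_log_probs), current_start))
            current_tokens = [token]
            current_log_probs = [logprob]
            current_start = i
    if current_tokens:
        words.append(Word("".join(current_tokens), sum(current_log_probs), current_start))
    return words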