mebubo committed on
Commit 426b33e · 1 Parent(s): 98f1760
Files changed (2)
  1. app.py +38 -24
  2. text_processing.py +3 -3
app.py CHANGED
@@ -3,31 +3,36 @@ import time
 from tqdm import tqdm
 from text_processing import split_into_words, Word
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer
-from pprint import pprint
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
+from tokenizers import Encoding
+from typing import cast
 
-def load_model_and_tokenizer(model_name):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
+
+def load_model_and_tokenizer(model_name: str, device: torch.device) -> tuple[PreTrainedModel, Tokenizer]:
+    tokenizer: Tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name)
     model.to(device)
-    return model, tokenizer, device
+    return model, tokenizer
 
-def process_input_text(input_text, tokenizer, device):
-    """Process input text to obtain input IDs and attention mask."""
-    inputs = tokenizer(input_text, return_tensors="pt").to(device)
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
-    return inputs, input_ids, attention_mask
+def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+    inputs: BatchEncoding = tokenizer(input_text, return_tensors="pt").to(device)
+    input_ids = cast(torch.Tensor, inputs["input_ids"])
+    attention_mask = cast(torch.Tensor, inputs["attention_mask"])
+    return input_ids, attention_mask
 
-def calculate_log_probabilities(model, tokenizer, inputs, input_ids, attention_mask):
+def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[str, float]]:
     with torch.no_grad():
         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
-    logits = outputs.logits[0, :-1, :]
-    log_probs = torch.log_softmax(logits, dim=-1)
-    token_log_probs = log_probs[range(log_probs.shape[0]), input_ids[0][1:]]
-    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
-    return list(zip(tokens[1:], token_log_probs.tolist()))
+    # B x T x V
+    logits: torch.Tensor = outputs.logits[:, :-1, :]
+    # B x T x V
+    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
+    # T - 1
+    token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]
+    # T - 1
+    tokens: list[str] = tokenizer.convert_ids_to_tokens(input_ids[0])[1:]
+    return list(zip(tokens, token_log_probs.tolist()))
 
 
 def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix_tokens: list[int], device: torch.device, num_samples: int = 5) -> list[str]:
@@ -53,15 +58,24 @@ def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer
     return new_words
 
 #%%
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 model_name = "mistralai/Mistral-7B-v0.1"
-model, tokenizer, device = load_model_and_tokenizer(model_name)
+model, tokenizer = load_model_and_tokenizer(model_name, device)
+
+#%%
 
 input_text = "He asked me to prostrate myself before the king, but I rifused."
-inputs, input_ids, attention_mask = process_input_text(input_text, tokenizer, device)
+input_ids, attention_mask = tokenize(input_text, tokenizer, device)
 
-result = calculate_log_probabilities(model, tokenizer, inputs, input_ids, attention_mask)
+#%%
+
+token_probs: list[tuple[str, float]] = calculate_log_probabilities(model, tokenizer, input_ids, attention_mask)
+
+#%%
 
-words = split_into_words([token for token, _ in result], [logprob for _, logprob in result])
+words = split_into_words(token_probs)
 log_prob_threshold = -5.0
 low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
 
@@ -72,7 +86,7 @@ start_time = time.time()
 for word in tqdm(low_prob_words, desc="Processing words"):
     iteration_start_time = time.time()
     prefix_index = word.first_token_index
-    prefix_tokens = tokenizer.convert_tokens_to_ids([token for token, _ in result][:prefix_index + 1])
+    prefix_tokens = tokenizer.convert_tokens_to_ids([token for token, _ in token_probs][:prefix_index + 1])
     replacements = generate_replacements(model, tokenizer, prefix_tokens, device)
     print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
     print(f"Proposed replacements: {replacements}")
text_processing.py CHANGED
@@ -2,18 +2,18 @@ from dataclasses import dataclass
 
 @dataclass
 class Word:
-    tokens: list[int]
+    tokens: list[str]
     text: str
     logprob: float
     first_token_index: int
 
-def split_into_words(tokens, log_probs) -> list[Word]:
+def split_into_words(token_probs: list[tuple[str, float]]) -> list[Word]:
     words = []
     current_word = []
     current_log_probs = []
     current_word_first_token_index = 0
 
-    for i, (token, logprob) in enumerate(zip(tokens, log_probs)):
+    for i, (token, logprob) in enumerate(token_probs):
         if not token.startswith(chr(9601)) and token.isalpha():
             current_word.append(token)
             current_log_probs.append(logprob)
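chr(9601) is "▁" (U+2581), the SentencePiece marker that Mistral's tokenizer prefixes to word-initial tokens, so a purely alphabetic token without the marker continues the current word. A small sketch of the grouping, with made-up (token, logprob) pairs standing in for real model output:

from text_processing import split_into_words

# Hand-written pairs for illustration; "\u2581" (chr(9601)) marks a word start.
token_probs = [
    ("\u2581He", -0.2),
    ("\u2581ri", -6.1),   # improbable word start: the misspelled "rifused"
    ("fused", -1.4),      # no marker and alphabetic, so it extends "ri"
    ("\u2581refused", -0.9),
]

for word in split_into_words(token_probs):
    print(word.text, word.first_token_index, round(word.logprob, 2))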