mebubo committed
Commit b174bd4 · 1 Parent(s): 6641473
Files changed (2):
  1. app.py +58 -18
  2. text_processing.py +0 -40
app.py CHANGED
@@ -1,14 +1,52 @@
 #%%
 import time
-from text_processing import split_into_words, Word
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
+from transformers.generation.utils import GenerateOutput
 from typing import cast
+from dataclasses import dataclass
 
 type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 
+@dataclass
+class Word:
+    tokens: list[int]
+    text: str
+    logprob: float
+    context: list[int]
+
+def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
+    words: list[Word] = []
+    current_word: list[int] = []
+    current_log_probs: list[float] = []
+    current_word_first_token_index: int = 0
+    all_tokens: list[int] = [token_id for token_id, _ in token_probs]
+
+    def append_current_word():
+        if current_word:
+            words.append(Word(current_word,
+                              tokenizer.decode(current_word),
+                              sum(current_log_probs),
+                              all_tokens[:current_word_first_token_index]))
+
+    for i, (token_id, logprob) in enumerate(token_probs):
+        token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
+        if not token.startswith(chr(9601)) and token.isalpha():
+            current_word.append(token_id)
+            current_log_probs.append(logprob)
+        else:
+            append_current_word()
+            current_word = [token_id]
+            current_log_probs = [logprob]
+            current_word_first_token_index = i
+
+    append_current_word()
+
+    return words
+
 def load_model_and_tokenizer(model_name: str, device: torch.device) -> tuple[PreTrainedModel, Tokenizer]:
-    tokenizer: Tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer: Tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+    tokenizer.pad_token = tokenizer.eos_token
     model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name)
     model.to(device)
     return model, tokenizer
@@ -32,10 +70,16 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     tokens: torch.Tensor = input_ids[0][1:]
     return list(zip(tokens.tolist(), token_log_probs.tolist()))
 
+def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+    texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
+    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(device)
+    input_ids = cast(torch.FloatTensor, inputs["input_ids"])
+    attention_mask = cast(torch.FloatTensor, inputs["attention_mask"])
+    return input_ids, attention_mask
 
-def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, contexts: list[list[int]], device: torch.device, num_samples: int = 5) -> list[list[str]]:
-    input_ids = torch.tensor(contexts).to(device)
-    attention_mask = torch.ones_like(input_ids)
+def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, contexts: list[list[int]],
+                          device: torch.device, num_samples: int = 5) -> tuple[GenerateOutput | torch.LongTensor, list[list[str]]]:
+    input_ids, attention_mask = prepare_inputs(contexts, tokenizer, device)
     with torch.no_grad():
         outputs = model.generate(
             input_ids=input_ids,
@@ -52,10 +96,10 @@ def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, contexts
         replacements = []
         for j in range(num_samples):
             generated_ids = outputs[i * num_samples + j][input_ids.shape[-1]:]
-            new_word = tokenizer.decode(generated_ids, skip_special_tokens=True).split()[0]
+            new_word = tokenizer.decode(generated_ids, skip_special_tokens=False).split()[0]
             replacements.append(new_word)
         all_new_words.append(replacements)
-    return all_new_words
+    return outputs, all_new_words
 
 #%%
 
@@ -75,28 +119,24 @@ token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokeni
 
 #%%
 
-import importlib
-import text_processing
-
-importlib.reload(text_processing)
-from text_processing import split_into_words, Word
-
 words = split_into_words(token_probs, tokenizer)
 log_prob_threshold = -5.0
 low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
 
+#%%
+contexts = [word.context for word in low_prob_words]
+
 #%%
 
 start_time = time.time()
-
-contexts = [word.context for word in low_prob_words]
-replacements_batch = generate_replacements(model, tokenizer, contexts, device)
+outputs, replacements_batch = generate_replacements(model, tokenizer, contexts, device, num_samples=5)
+end_time = time.time()
+print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
 
+#%%
 
 for word, replacements in zip(low_prob_words, replacements_batch):
     print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
     print(f"Proposed replacements: {replacements}")
 
-end_time = time.time()
-print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
-
 # %%
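
Note on the padding change above: switching to padding_side="left" (and reusing the EOS token as the pad token) is what makes the uniform prompt-stripping slice outputs[...][input_ids.shape[-1]:] in generate_replacements valid for a batch of different-length contexts. Below is a minimal standalone sketch of that idea; "gpt2" and the example texts are illustrative choices only, since this diff does not show which checkpoint the app actually loads.

    # Sketch of left-padded batched generation (illustrative model and texts).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token of its own
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Prompts of different lengths: left padding right-aligns them, so every
    # continuation starts at the same offset, input_ids.shape[-1].
    texts = ["The quick brown", "She"]
    inputs = tokenizer(texts, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=3,
            num_return_sequences=2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Slicing off the prompt works uniformly across the batch thanks to left padding.
    for row in outputs:
        print(tokenizer.decode(row[inputs["input_ids"].shape[-1]:], skip_special_tokens=True))

With right padding, the shorter prompts would end in pad tokens and the model would be asked to continue from those pads, degrading the continuations; left padding keeps every prompt flush against the generation boundary.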
text_processing.py DELETED
@@ -1,40 +0,0 @@
-from dataclasses import dataclass
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-
-type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
-
-@dataclass
-class Word:
-    tokens: list[int]
-    text: str
-    logprob: float
-    context: list[int]
-
-def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
-    words: list[Word] = []
-    current_word: list[int] = []
-    current_log_probs: list[float] = []
-    current_word_first_token_index: int = 0
-    all_tokens: list[int] = [token_id for token_id, _ in token_probs]
-
-    def append_current_word():
-        if current_word:
-            words.append(Word(current_word,
-                              tokenizer.decode(current_word),
-                              sum(current_log_probs),
-                              all_tokens[:current_word_first_token_index]))
-
-    for i, (token_id, logprob) in enumerate(token_probs):
-        token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
-        if not token.startswith(chr(9601)) and token.isalpha():
-            current_word.append(token_id)
-            current_log_probs.append(logprob)
-        else:
-            append_current_word()
-            current_word = [token_id]
-            current_log_probs = [logprob]
-            current_word_first_token_index = i
-
-    append_current_word()
-
-    return words
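
text_processing.py is deleted because its contents move verbatim into app.py (first hunk above). The grouping rule it implements: a token is merged into the current word only when it is alphabetic and does not start with "▁" (chr(9601)), the SentencePiece word-boundary marker. Below is a self-contained toy run of that rule, assuming split_into_words from app.py is in scope; the four-token vocabulary and the stub tokenizer are invented for this illustration.

    # Toy demonstration of split_into_words' grouping rule.
    # VOCAB and StubTokenizer are hypothetical; real use goes through a
    # SentencePiece-style tokenizer where "▁" marks the start of a word.
    VOCAB = {0: "▁The", 1: "▁quick", 2: "est", 3: "."}

    class StubTokenizer:
        def convert_ids_to_tokens(self, ids):
            return [VOCAB[i] for i in ids]

        def decode(self, ids):
            return "".join(VOCAB[i] for i in ids).replace(chr(9601), " ").strip()

    token_probs = [(0, -0.1), (1, -4.0), (2, -2.5), (3, -0.2)]
    for w in split_into_words(token_probs, StubTokenizer()):
        print(w.text, round(w.logprob, 2), w.context)

"▁quick" and "est" merge into one word ("quickest", with summed logprob -6.5), while "." is not alphabetic and so becomes its own entry; each word's context is the list of all token ids preceding it, which is exactly what generate_replacements later feeds back as a prompt.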