mebubo committed
Commit 6641473 · 1 Parent(s): e72ea09
Files changed (2)
  1. app.py +25 -23
  2. text_processing.py +13 -12
app.py CHANGED
@@ -1,10 +1,8 @@
 #%%
 import time
-from tqdm import tqdm
 from text_processing import split_into_words, Word
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
-from tokenizers import Encoding
 from typing import cast
 
 type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
@@ -35,10 +33,9 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     return list(zip(tokens.tolist(), token_log_probs.tolist()))
 
 
-def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix_tokens: list[int], device: torch.device, num_samples: int = 5) -> list[str]:
-    input_context = {"input_ids": torch.tensor([prefix_tokens]).to(device)}
-    input_ids = input_context["input_ids"]
-    attention_mask = input_context["attention_mask"]
+def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, contexts: list[list[int]], device: torch.device, num_samples: int = 5) -> list[list[str]]:
+    input_ids = torch.tensor(contexts).to(device)
+    attention_mask = torch.ones_like(input_ids)
     with torch.no_grad():
         outputs = model.generate(
             input_ids=input_ids,
@@ -50,12 +47,15 @@ def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer
             top_p=0.95,
             do_sample=True
         )
-    new_words = []
-    for i in range(num_samples):
-        generated_ids = outputs[i][input_ids.shape[-1]:]
-        new_word = tokenizer.decode(generated_ids, skip_special_tokens=True).split()[0]
-        new_words.append(new_word)
-    return new_words
+    all_new_words = []
+    for i in range(len(contexts)):
+        replacements = []
+        for j in range(num_samples):
+            generated_ids = outputs[i * num_samples + j][input_ids.shape[-1]:]
+            new_word = tokenizer.decode(generated_ids, skip_special_tokens=True).split()[0]
+            replacements.append(new_word)
+        all_new_words.append(replacements)
+    return all_new_words
 
 #%%
 
@@ -71,11 +71,17 @@ input_ids, attention_mask = tokenize(input_text, tokenizer, device)
 
 #%%
 
-token_probs: list[tuple[str, float]] = calculate_log_probabilities(model, tokenizer, input_ids, attention_mask)
+token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, input_ids, attention_mask)
 
 #%%
 
-words = split_into_words(token_probs)
+import importlib
+import text_processing
+
+importlib.reload(text_processing)
+from text_processing import split_into_words, Word
+
+words = split_into_words(token_probs, tokenizer)
 log_prob_threshold = -5.0
 low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
 
@@ -83,18 +89,14 @@ low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
 
 start_time = time.time()
 
-for word in tqdm(low_prob_words, desc="Processing words"):
-    iteration_start_time = time.time()
-    prefix_index = word.first_token_index
-    prefix_tokens = tokenizer.convert_tokens_to_ids([token for token, _ in token_probs][:prefix_index + 1])
-    replacements = generate_replacements(model, tokenizer, prefix_tokens, device)
+contexts = [word.context for word in low_prob_words]
+replacements_batch = generate_replacements(model, tokenizer, contexts, device)
+
+for word, replacements in zip(low_prob_words, replacements_batch):
     print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
     print(f"Proposed replacements: {replacements}")
-    print()
-    iteration_end_time = time.time()
-    print(f"Time taken for this iteration: {iteration_end_time - iteration_start_time:.4f} seconds")
 
 end_time = time.time()
-print(f"Total time taken for the loop: {end_time - start_time:.4f} seconds")
+print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
 
 # %%
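
Note: the new generate_replacements samples continuations for all low-probability words in one batch. The indexing outputs[i * num_samples + j] only works if model.generate returns num_samples rows per input, grouped per context; the middle of the generate call is outside the hunk, so presumably it passes num_return_sequences=num_samples. A minimal sketch of that pattern, under that assumption and with gpt2 as a stand-in model:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model choice, an assumption
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = ["The weather today is", "My favorite color is"]
encoded = [tokenizer.encode(p) for p in prompts]
n = min(len(e) for e in encoded)
contexts = [e[:n] for e in encoded]  # torch.tensor(contexts) needs equal-length rows
num_samples = 3

input_ids = torch.tensor(contexts)
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=5,
        num_return_sequences=num_samples,  # presumed, per the indexing below
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

# generate returns len(contexts) * num_samples rows, grouped per input:
# samples for contexts[i] occupy rows i*num_samples .. (i+1)*num_samples - 1.
for i in range(len(contexts)):
    for j in range(num_samples):
        continuation = outputs[i * num_samples + j][input_ids.shape[-1]:]
        print(i, j, tokenizer.decode(continuation, skip_special_tokens=True))

This also shows a constraint the commit inherits: torch.tensor(contexts) requires all contexts to have the same token length; ragged prefixes would need padding and a real attention mask rather than torch.ones_like.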
text_processing.py CHANGED
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from tokenizers import Tokenizer
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+
+type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 
 @dataclass
 class Word:
@@ -15,25 +17,24 @@ def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer)
     current_word_first_token_index: int = 0
     all_tokens: list[int] = [token_id for token_id, _ in token_probs]
 
+    def append_current_word():
+        if current_word:
+            words.append(Word(current_word,
+                              tokenizer.decode(current_word),
+                              sum(current_log_probs),
+                              all_tokens[:current_word_first_token_index]))
+
     for i, (token_id, logprob) in enumerate(token_probs):
-        token: str = tokenizer.decode([token_id])
+        token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
         if not token.startswith(chr(9601)) and token.isalpha():
             current_word.append(token_id)
             current_log_probs.append(logprob)
         else:
-            if current_word:
-                words.append(Word(current_word,
-                                  tokenizer.decode(current_word),
-                                  sum(current_log_probs),
-                                  all_tokens[:current_word_first_token_index]))
+            append_current_word()
             current_word = [token_id]
             current_log_probs = [logprob]
             current_word_first_token_index = i
 
-    if current_word:
-        words.append(Word(current_word,
-                          tokenizer.decode(current_word),
-                          sum(current_log_probs),
-                          all_tokens[:current_word_first_token_index]))
+    append_current_word()
 
     return words
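
Note: the switch from tokenizer.decode([token_id]) to tokenizer.convert_ids_to_tokens([token_id])[0] matters because decode cleans up the token text and strips the SentencePiece word-boundary marker "▁" (chr(9601)) that the startswith check relies on, while convert_ids_to_tokens returns the raw vocabulary string. A quick illustration, using t5-small purely as an example of a SentencePiece tokenizer (the commit does not say which model the app loads):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # stand-in SentencePiece tokenizer
ids = tokenizer.encode("hello world", add_special_tokens=False)

for token_id in ids:
    raw = tokenizer.convert_ids_to_tokens([token_id])[0]  # e.g. '▁hello': marker kept
    cleaned = tokenizer.decode([token_id])                # e.g. 'hello': marker stripped
    print(f"{token_id}: raw={raw!r} cleaned={cleaned!r} "
          f"starts_new_word={raw.startswith(chr(9601))}")

With decode, every alphabetic piece looked like a word continuation (the "▁" was gone before the startswith test), so word boundaries were lost; with the raw token string, "▁"-prefixed pieces start a new Word as intended.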