mebubo committed
Commit d3ef10a · 1 Parent(s): b174bd4
Files changed (1):
  1. app.py +22 -18
app.py CHANGED
@@ -51,15 +51,12 @@ def load_model_and_tokenizer(model_name: str, device: torch.device) -> tuple[Pre
     model.to(device)
     return model, tokenizer

-def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
-    inputs: BatchEncoding = tokenizer(input_text, return_tensors="pt").to(device)
-    input_ids = cast(torch.Tensor, inputs["input_ids"])
-    attention_mask = cast(torch.Tensor, inputs["attention_mask"])
-    return input_ids, attention_mask
+def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
+    return tokenizer(input_text, return_tensors="pt").to(device)

-def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[int, float]]:
+def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, inputs: BatchEncoding) -> list[tuple[int, float]]:
     with torch.no_grad():
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
+        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], labels=inputs["input_ids"])
         # B x T x V
         logits: torch.Tensor = outputs.logits[:, :-1, :]
         # B x T x V
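This hunk collapses tokenize to a single line by returning the tokenizer's BatchEncoding directly, and calculate_log_probabilities now takes that encoding in place of unpacked input_ids/attention_mask tensors. A minimal sketch of the resulting call flow, assuming a causal LM checkpoint (the "gpt2" name and the prompt are placeholders, not from the commit):

    model, tokenizer = load_model_and_tokenizer("gpt2", device)
    inputs = tokenize("The quick brown fox", tokenizer, device)   # BatchEncoding
    token_log_probs = calculate_log_probabilities(model, tokenizer, inputs)

The cast calls disappear because BatchEncoding already supports dict-style access, and its .to(device) moves all contained tensors in one step.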
@@ -70,16 +67,14 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     tokens: torch.Tensor = input_ids[0][1:]
     return list(zip(tokens.tolist(), token_log_probs.tolist()))

-def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
     texts = [tokenizer.decode(context, skip_special_tokens=True) for context in contexts]
-    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(device)
-    input_ids = cast(torch.FloatTensor, inputs["input_ids"])
-    attention_mask = cast(torch.FloatTensor, inputs["attention_mask"])
-    return input_ids, attention_mask
+    return tokenizer(texts, return_tensors="pt", padding=True).to(device)

-def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, contexts: list[list[int]],
+def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, inputs: BatchEncoding,
                           device: torch.device, num_samples: int = 5) -> tuple[GenerateOutput | torch.LongTensor, list[list[str]]]:
-    input_ids, attention_mask = prepare_inputs(contexts, tokenizer, device)
+    input_ids = inputs["input_ids"]
+    attention_mask = inputs["attention_mask"]
     with torch.no_grad():
         outputs = model.generate(
             input_ids=input_ids,
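prepare_inputs gets the same treatment, returning the padded BatchEncoding whole, and generate_replacements now receives it instead of re-tokenizing the contexts itself. One caveat worth flagging: the unchanged context line tokens: torch.Tensor = input_ids[0][1:] still refers to input_ids, which after this refactor is no longer a parameter of calculate_log_probabilities. Unless a line outside this hunk rebinds that name, the function would need something like this hypothetical follow-up (not part of this commit):

    # Hypothetical fix: read the ids from the BatchEncoding parameter
    tokens: torch.Tensor = inputs["input_ids"][0][1:]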
@@ -92,12 +87,13 @@ def generate_replacements(model: PreTrainedModel, tokenizer: Tokenizer, contexts
             do_sample=True
         )
     all_new_words = []
-    for i in range(len(contexts)):
+    for i in range(len(input_ids)):
         replacements = []
         for j in range(num_samples):
             generated_ids = outputs[i * num_samples + j][input_ids.shape[-1]:]
-            new_word = tokenizer.decode(generated_ids, skip_special_tokens=False).split()[0]
-            replacements.append(new_word)
+            new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
+            if new_word.startswith(chr(9601)):
+                replacements.append(new_word)
         all_new_words.append(replacements)
     return outputs, all_new_words

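The sampling loop now filters at the token level instead of decoding and splitting on whitespace: chr(9601) is "▁" (U+2581), the marker that SentencePiece-style tokenizers prepend to tokens that start a new word, so only samples beginning a fresh word are kept. A side effect is that a context can now yield fewer than num_samples replacements. A self-contained illustration (the token strings are invented):

    WORD_BOUNDARY = chr(9601)  # "▁" (U+2581), SentencePiece word-start marker
    sampled = ["▁cat", "dog", "▁house"]  # invented example tokens
    kept = [t for t in sampled if t.startswith(WORD_BOUNDARY)]
    print(kept)  # ['▁cat', '▁house']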
 
@@ -125,11 +121,13 @@ low_prob_words = [word for word in words if word.logprob < log_prob_threshold]

 #%%
 contexts = [word.context for word in low_prob_words]
+inputs = prepare_inputs(contexts, tokenizer, device)
+input_ids = inputs["input_ids"]

 #%%

 start_time = time.time()
-replacements_batch = generate_replacements(model, tokenizer, contexts, device, num_samples=5)
+outputs, replacements_batch = generate_replacements(model, tokenizer, inputs, device, num_samples=5)
 end_time = time.time()
 print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")

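At the call site, tokenization is hoisted into an explicit prepare_inputs call and the result of generate_replacements is unpacked into two names. The unpacking also fixes a latent bug: the function already returned an (outputs, words) tuple, so the old single-name binding made the later zip(low_prob_words, replacements_batch) pair words with the tuple's two elements rather than with the per-word replacement lists.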
 
@@ -140,3 +138,9 @@ for word, replacements in zip(low_prob_words, replacements_batch):
     print(f"Proposed replacements: {replacements}")

 # %%
+
+generated_ids = outputs[:, input_ids.shape[-1]:]
+for g in generated_ids:
+    print(tokenizer.convert_ids_to_tokens(g.tolist()))
+
+# %%
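The new trailing cell prints every sampled continuation for inspection. Because prepare_inputs pads all contexts to one length, each row of outputs starts its generated tokens at the same column, so a single slice at input_ids.shape[-1] strips every prompt at once. A toy example of the slicing (the ids are invented):

    import torch
    outputs = torch.tensor([[1, 2, 3, 9, 8],
                            [4, 5, 6, 7, 0]])  # 3 prompt ids + 2 generated ids per row
    prompt_len = 3                             # stands in for input_ids.shape[-1]
    print(outputs[:, prompt_len:])             # tensor([[9, 8], [7, 0]])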