mebubo commited on
Commit
76e131a
·
1 Parent(s): b47d499

feat: Update generate_replacements to accept prefix as token list

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -30,8 +30,8 @@ def calculate_log_probabilities(model, tokenizer, inputs, input_ids, attention_m
30
  return list(zip(tokens[1:], token_log_probs.tolist()))
31
 
32
 
33
- def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix: str, device: torch.device, num_samples: int = 5) -> list[str]:
34
- input_context = tokenizer(prefix, return_tensors="pt").to(device)
35
  input_ids = input_context["input_ids"]
36
  attention_mask = input_context["attention_mask"]
37
  with torch.no_grad():
@@ -73,8 +73,7 @@ for word in tqdm(low_prob_words, desc="Processing words"):
73
  iteration_start_time = time.time()
74
  prefix_index = word.first_token_index
75
  prefix_tokens = [token for token, _ in result][:prefix_index + 1]
76
- prefix = tokenizer.convert_tokens_to_string(prefix_tokens)
77
- replacements = generate_replacements(model, tokenizer, prefix, device)
78
  print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
79
  print(f"Proposed replacements: {replacements}")
80
  print()
 
30
  return list(zip(tokens[1:], token_log_probs.tolist()))
31
 
32
 
33
+ def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix_tokens: list[int], device: torch.device, num_samples: int = 5) -> list[str]:
34
+ input_context = {"input_ids": torch.tensor([prefix_tokens]).to(device)}
35
  input_ids = input_context["input_ids"]
36
  attention_mask = input_context["attention_mask"]
37
  with torch.no_grad():
 
73
  iteration_start_time = time.time()
74
  prefix_index = word.first_token_index
75
  prefix_tokens = [token for token, _ in result][:prefix_index + 1]
76
+ replacements = generate_replacements(model, tokenizer, prefix_tokens, device)
 
77
  print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
78
  print(f"Proposed replacements: {replacements}")
79
  print()