feat: Update generate_replacements to accept prefix as token list
Browse files
app.py
CHANGED
@@ -30,8 +30,8 @@ def calculate_log_probabilities(model, tokenizer, inputs, input_ids, attention_m
|
|
30 |
return list(zip(tokens[1:], token_log_probs.tolist()))
|
31 |
|
32 |
|
33 |
-
def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
|
34 |
-
input_context =
|
35 |
input_ids = input_context["input_ids"]
|
36 |
attention_mask = input_context["attention_mask"]
|
37 |
with torch.no_grad():
|
@@ -73,8 +73,7 @@ for word in tqdm(low_prob_words, desc="Processing words"):
|
|
73 |
iteration_start_time = time.time()
|
74 |
prefix_index = word.first_token_index
|
75 |
prefix_tokens = [token for token, _ in result][:prefix_index + 1]
|
76 |
-
|
77 |
-
replacements = generate_replacements(model, tokenizer, prefix, device)
|
78 |
print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
|
79 |
print(f"Proposed replacements: {replacements}")
|
80 |
print()
|
|
|
30 |
return list(zip(tokens[1:], token_log_probs.tolist()))
|
31 |
|
32 |
|
33 |
+
def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix_tokens: list[int], device: torch.device, num_samples: int = 5) -> list[str]:
|
34 |
+
input_context = {"input_ids": torch.tensor([prefix_tokens]).to(device)}
|
35 |
input_ids = input_context["input_ids"]
|
36 |
attention_mask = input_context["attention_mask"]
|
37 |
with torch.no_grad():
|
|
|
73 |
iteration_start_time = time.time()
|
74 |
prefix_index = word.first_token_index
|
75 |
prefix_tokens = [token for token, _ in result][:prefix_index + 1]
|
76 |
+
replacements = generate_replacements(model, tokenizer, prefix_tokens, device)
|
|
|
77 |
print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
|
78 |
print(f"Proposed replacements: {replacements}")
|
79 |
print()
|