feat: Add type annotations to generate_replacements function
app.py CHANGED
@@ -27,7 +27,10 @@ def calculate_log_probabilities(model, tokenizer, inputs, input_ids):
     tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
     return list(zip(tokens[1:], token_log_probs.tolist()))
 
-def generate_replacements(model, tokenizer, prefix, device, num_samples=5):
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from typing import List
+
+def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix: str, device: torch.device, num_samples: int = 5) -> List[str]:
     input_context = tokenizer(prefix, return_tensors="pt").to(device)
     input_ids = input_context["input_ids"]
     with torch.no_grad():
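
For reference, a minimal usage sketch of the newly annotated signature. None of this is part of the commit: the checkpoint name ("gpt2"), the "from app import generate_replacements" import path, and the example prefix are assumptions for illustration only.

# Usage sketch (assumptions: app.py exposes generate_replacements; "gpt2" is a placeholder checkpoint)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import generate_replacements  # assumed import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Call matches the annotated signature:
# (PreTrainedModel, PreTrainedTokenizer, str, torch.device, int) -> List[str]
replacements = generate_replacements(model, tokenizer, "The quick brown", device, num_samples=3)
for candidate in replacements:
    print(candidate)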