refactor: Update type hints to use Python 3.12 style in generate_replacements
Browse files
app.py
CHANGED
@@ -28,9 +28,8 @@ def calculate_log_probabilities(model, tokenizer, inputs, input_ids):
|
|
28 |
return list(zip(tokens[1:], token_log_probs.tolist()))
|
29 |
|
30 |
from transformers import PreTrainedModel, PreTrainedTokenizer
|
31 |
-
from typing import List
|
32 |
|
33 |
-
def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix: str, device: torch.device, num_samples: int = 5) ->
|
34 |
input_context = tokenizer(prefix, return_tensors="pt").to(device)
|
35 |
input_ids = input_context["input_ids"]
|
36 |
with torch.no_grad():
|
|
|
28 |
return list(zip(tokens[1:], token_log_probs.tolist()))
|
29 |
|
30 |
from transformers import PreTrainedModel, PreTrainedTokenizer
|
|
|
31 |
|
32 |
+
def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix: str, device: torch.device, num_samples: int = 5) -> list[str]:
|
33 |
input_context = tokenizer(prefix, return_tensors="pt").to(device)
|
34 |
input_ids = input_context["input_ids"]
|
35 |
with torch.no_grad():
|