Unwrap the single-element tensor when returning logprob
completions.py  +1 −1
@@ -118,7 +118,7 @@ def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: T
     log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
     result = []
     for probs in log_probs:
-        result.append([(i, p) for i, p in enumerate(probs)])
+        result.append([(i, p.item()) for i, p in enumerate(probs)])
     return result

 def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
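
For context, a minimal standalone sketch (not the repository's code, using toy logits) of what the change does: iterating over a row of log_softmax output yields 0-dim tensors, and .item() unwraps each single-element tensor into a plain Python float, which is what the commit title refers to.

import torch

# Standalone sketch with assumed toy logits; not taken from completions.py.
logits = torch.tensor([[2.0, 0.5, -1.0]])
log_probs = torch.log_softmax(logits, dim=-1)

for probs in log_probs:
    # Before the change: each p is a 0-dim torch.Tensor.
    wrapped = [(i, p) for i, p in enumerate(probs)]
    # After the change: p.item() unwraps the tensor into a Python float.
    unwrapped = [(i, p.item()) for i, p in enumerate(probs)]

print(type(wrapped[0][1]))    # <class 'torch.Tensor'>
print(type(unwrapped[0][1]))  # <class 'float'>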