token_log_probs: simplify
Browse files- completions.py +2 -2
completions.py
CHANGED
@@ -66,9 +66,9 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
|
|
66 |
# B x T x V
|
67 |
log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
|
68 |
# T - 1
|
69 |
-
token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]
|
70 |
-
# T - 1
|
71 |
tokens: torch.Tensor = input_ids[0][1:]
|
|
|
|
|
72 |
return list(zip(tokens.tolist(), token_log_probs.tolist()))
|
73 |
|
74 |
#%%
|
|
|
66 |
# B x T x V
|
67 |
log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
|
68 |
# T - 1
|
|
|
|
|
69 |
tokens: torch.Tensor = input_ids[0][1:]
|
70 |
+
# T - 1
|
71 |
+
token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), tokens]
|
72 |
return list(zip(tokens.tolist(), token_log_probs.tolist()))
|
73 |
|
74 |
#%%
|