mebubo commited on
Commit
c980c58
·
1 Parent(s): 6849fc4

token_log_probs: simplify

Browse files
Files changed (1) hide show
  1. completions.py +2 -2
completions.py CHANGED
@@ -66,9 +66,9 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
66
  # B x T x V
67
  log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
68
  # T - 1
69
- token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]
70
- # T - 1
71
  tokens: torch.Tensor = input_ids[0][1:]
 
 
72
  return list(zip(tokens.tolist(), token_log_probs.tolist()))
73
 
74
  #%%
 
66
  # B x T x V
67
  log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
68
  # T - 1
 
 
69
  tokens: torch.Tensor = input_ids[0][1:]
70
+ # T - 1
71
+ token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), tokens]
72
  return list(zip(tokens.tolist(), token_log_probs.tolist()))
73
 
74
  #%%