mebubo committed
Commit 2bbbefa · 1 Parent(s): e0f1806

Fix performance by ignoring tokens with probabilities below a threshold

Files changed (1):
  1. expand_llm.py +6 -1
expand_llm.py CHANGED
@@ -21,8 +21,13 @@ def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: T
     start_time = time.time()
     result = []
     print(f"Resulting tensor: {log_probs.shape}")
+    threshold = -10.0
     for probs in log_probs:
-        result.append([(i, p.item()) for i, p in enumerate(probs)])
+        # Filter out low probability tokens for efficiency
+        above_threshold = torch.where(probs > threshold)
+        filtered_indices = above_threshold[0]
+        filtered_probs = probs[filtered_indices]
+        result.append([(idx.item(), prob.item()) for idx, prob in zip(filtered_indices, filtered_probs)])
     print(f"Result done, took {time.time() - start_time} seconds")
     return result

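For context, the filter works in log space: with threshold = -10.0, only tokens whose probability exceeds e^-10 ≈ 4.5e-5 survive, so for a trained model's peaked distribution the Python-level list construction touches a small fraction of the vocabulary instead of every entry. Below is a minimal standalone sketch of the new filtering step; the vocabulary size and the random logits are illustrative assumptions, while the threshold and the torch.where usage come straight from the diff.

import torch

# Hypothetical stand-in for one row of log_probs in find_next_tokens:
# a [vocab_size] tensor of log-probabilities.
vocab_size = 50_257  # assumed for illustration; the real value depends on the tokenizer
log_probs_row = torch.log_softmax(torch.randn(vocab_size), dim=-1)

threshold = -10.0  # log-space cutoff from the diff: keeps tokens with p > e^-10
above_threshold = torch.where(log_probs_row > threshold)  # tuple of index tensors
filtered_indices = above_threshold[0]
filtered_probs = log_probs_row[filtered_indices]

# Only the surviving entries are materialized as Python (index, log_prob) pairs,
# rather than one pair per vocabulary entry as before this commit.
pairs = [(idx.item(), prob.item()) for idx, prob in zip(filtered_indices, filtered_probs)]
print(f"kept {len(pairs)} of {vocab_size} tokens")

Note that single-argument torch.where(condition) is equivalent to torch.nonzero(condition, as_tuple=True), which is why it returns a tuple of index tensors and the code takes element [0].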