Fix performance by ignoring tokens with probabilities below a threshold
expand_llm.py  +6 −1
@@ -21,8 +21,13 @@ def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: T
     start_time = time.time()
     result = []
     print(f"Resulting tensor: {log_probs.shape}")
+    threshold = -10.0
     for probs in log_probs:
-
+        # Filter out low probability tokens for efficiency
+        above_threshold = torch.where(probs > threshold)
+        filtered_indices = above_threshold[0]
+        filtered_probs = probs[filtered_indices]
+        result.append([(idx.item(), prob.item()) for idx, prob in zip(filtered_indices, filtered_probs)])
     print(f"Result done, took {time.time() - start_time} seconds")
     return result
 
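For reference, the filtering step can be exercised on its own. The sketch below is a minimal, hypothetical stand-in for what find_next_tokens operates on: log_probs is assumed to be a (batch, vocab) tensor of log-probabilities, and -10.0 is the same cutoff the commit introduces (exp(-10) is roughly 4.5e-5, so anything below it contributes almost nothing).

import torch

# Stand-in for the model output assumed by find_next_tokens (batch x vocab log-probabilities).
log_probs = torch.log_softmax(torch.randn(2, 32000), dim=-1)
threshold = -10.0  # log-probability cutoff used by the commit

result = []
for probs in log_probs:
    # torch.where with a single argument returns a tuple of index tensors,
    # so [0] gives the token ids whose log-probability clears the threshold.
    above_threshold = torch.where(probs > threshold)
    filtered_indices = above_threshold[0]
    filtered_probs = probs[filtered_indices]
    result.append([(idx.item(), prob.item()) for idx, prob in zip(filtered_indices, filtered_probs)])

print([len(r) for r in result])  # a handful of candidate tokens per row instead of the full vocabulary

Keeping the per-row loop lets each row retain a different number of candidate tokens; a boolean mask such as probs[probs > threshold] would also work but discards the token indices, which the returned (index, log-probability) pairs need.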