Split big batches to avoid CUDA OOM
expand_llm.py +23 -9
expand_llm.py
CHANGED
@@ -39,19 +39,33 @@ class LLMBatchExpander(BatchExpander):
     model: PreTrainedModel
     tokenizer: Tokenizer
     threshold: float
+    chunk_size: int = 16  # Default chunk size, can be adjusted as needed
 
     def expand(self, batch: Batch) -> BatchCandidates:
-        inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
-        next_tokens = find_next_tokens(self.model, inputs, self.threshold)
         start_time = time.time()
-
-
-
-
-
-
+        all_results = []
+
+        # Split batch.items into chunks to avoid CUDA out of memory
+        for i in range(0, len(batch.items), self.chunk_size):
+            chunk_items = batch.items[i:i + self.chunk_size]
+            print(f"Processing chunk {i//self.chunk_size + 1}/{(len(batch.items) + self.chunk_size - 1)//self.chunk_size} with {len(chunk_items)} items")
+
+            # Process this chunk
+            inputs = prepare_inputs([s.get_all_tokens() for s in chunk_items], self.tokenizer, self.model.device)
+            chunk_next_tokens = find_next_tokens(self.model, inputs, self.threshold)
+
+            # Create token candidates for this chunk
+            for s, next_tokens in zip(chunk_items, chunk_next_tokens):
+                expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
+                all_results.append(TokenCandidates(series=s, expansions=expansions))
+
+            # Clear CUDA cache to free up memory
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+        print(f"Total batch size: {len(batch.items)}, processed in {(len(batch.items) + self.chunk_size - 1)//self.chunk_size} chunks")
         print(f"Token candidates done, took {time.time() - start_time} seconds")
-        return BatchCandidates(items=
+        return BatchCandidates(items=all_results)
 
 def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
     def stopping_criterion(series: Series, expansion: Expansion) -> bool: