Spaces: Running on Zero
Correctly reset batch iterator at each arrow create_iter call. (#74)
Browse files
bytelatent/data/iterators/arrow_iterator.py
CHANGED
@@ -197,9 +197,6 @@ class ArrowFileIterator(StatefulIterator):
197 |   self.dataset = pa.dataset.dataset(
198 |   self.dataset_files, format=self.file_format, filesystem=filesystem
199 |   )
200 | - self.batch_iterator = self.dataset.to_batches(
201 | - batch_size=self.arrow_batch_size
202 | - )
203 |   self.iter_id += 1
204 |   if self.batch_to_consume is not None:
205 |   batch_columns: dict[str, list] = self.batch_to_consume
@@ -229,6 +226,7 @@ class ArrowFileIterator(StatefulIterator):
229 |   if (self.row_num - 1) % self.num_workers == self.worker_id:
230 |   yield out
231 |
232 |   for batch in self.batch_iterator:
233 |   batch_columns = batch.to_pydict()
234 |   if self.file_format == "arrow":
197 |   self.dataset = pa.dataset.dataset(
198 |   self.dataset_files, format=self.file_format, filesystem=filesystem
199 |   )
200 |   self.iter_id += 1
201 |   if self.batch_to_consume is not None:
202 |   batch_columns: dict[str, list] = self.batch_to_consume

226 |   if (self.row_num - 1) % self.num_workers == self.worker_id:
227 |   yield out
228 |
229 | + self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
230 |   for batch in self.batch_iterator:
231 |   batch_columns = batch.to_pydict()
232 |   if self.file_format == "arrow":