par-meta commited on
Commit
c727844
·
unverified ·
1 Parent(s): 08b8c7c

Correctly reset batch iterator at each arrow create_iter call. (#74)

Browse files
bytelatent/data/iterators/arrow_iterator.py CHANGED
@@ -197,9 +197,6 @@ class ArrowFileIterator(StatefulIterator):
197
  self.dataset = pa.dataset.dataset(
198
  self.dataset_files, format=self.file_format, filesystem=filesystem
199
  )
200
- self.batch_iterator = self.dataset.to_batches(
201
- batch_size=self.arrow_batch_size
202
- )
203
  self.iter_id += 1
204
  if self.batch_to_consume is not None:
205
  batch_columns: dict[str, list] = self.batch_to_consume
@@ -229,6 +226,7 @@ class ArrowFileIterator(StatefulIterator):
229
  if (self.row_num - 1) % self.num_workers == self.worker_id:
230
  yield out
231
 
 
232
  for batch in self.batch_iterator:
233
  batch_columns = batch.to_pydict()
234
  if self.file_format == "arrow":
 
197
  self.dataset = pa.dataset.dataset(
198
  self.dataset_files, format=self.file_format, filesystem=filesystem
199
  )
 
 
 
200
  self.iter_id += 1
201
  if self.batch_to_consume is not None:
202
  batch_columns: dict[str, list] = self.batch_to_consume
 
226
  if (self.row_num - 1) % self.num_workers == self.worker_id:
227
  yield out
228
 
229
+ self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
230
  for batch in self.batch_iterator:
231
  batch_columns = batch.to_pydict()
232
  if self.file_format == "arrow":