par-meta committed
Commit 8c61ab5 · unverified · 1 Parent(s): 85c2f28

Fix multiprocessing dataloader checkpointing and use it in the train script (#50)

bytelatent/args.py CHANGED
@@ -1,10 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import json
 import logging
 import os
 from typing import Any
 
-import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
bytelatent/data/iterators/abstract_iterator.py CHANGED
@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
     @abc.abstractmethod
     def build(self) -> StatefulIterator[T, C]:
         pass
+
+
+def get_state_and_refresh(iterator: StatefulIterator):
+    # Re-initializing the dataloader and iterator is necessary since get_state()
+    # on the mp iterator shuts down MP to correctly persist state, and it needs
+    # to be restarted.
+    state = iterator.get_state()
+    data_loader = state.build()
+    py_iterator = data_loader.create_iter()
+    return state, data_loader, py_iterator
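
The new helper exists because get_state() on the multiprocessing iterator shuts the worker down, so the caller has to rebind both the loader and its Python iterator before training can continue. A minimal usage sketch (variable names mirror the train loop later in this diff; how the state is written out is left as a placeholder, not the project's API):

# At a checkpoint boundary, persist the state and rebind both objects,
# since the old multiprocessing iterator has been shut down by get_state().
state, data_loader, batch_iterator = get_state_and_refresh(data_loader)
# `state` is what gets written into the checkpoint; `data_loader` and
# `batch_iterator` are fresh objects that resume from exactly that position.
batch = next(batch_iterator)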
bytelatent/data/iterators/arrow_iterator.py CHANGED
@@ -60,6 +60,13 @@ def shard_sort_key(file: str):
     return shard_number
 
 
+def maybe_truncate_string(text: str, max_length: int):
+    if len(text) <= max_length:
+        return text
+    else:
+        return text[:max_length] + "..."
+
+
 class ArrowFileIterator(StatefulIterator):
     def __init__(
         self,
@@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator):
             yield out
 
     def _set_row_num(self, target_row_num: int):
-        logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
-        )
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
+        logger.info(f"Setting arrow position to {target_row_num} for {data_str}")
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
             self.dataset = None
@@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator):
             else:
                 curr_remaining -= len(batch)
        self.row_num = target_row_num
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
        logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {data_str}"
        )
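
maybe_truncate_string keeps these log lines bounded when a dataset spans many files. Its behaviour, illustrated with arbitrary values:

maybe_truncate_string("short string", 200)  # returned unchanged
maybe_truncate_string("x" * 500, 200)       # first 200 characters followed by "..."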
bytelatent/data/iterators/multiprocess_iterator.py CHANGED
@@ -54,9 +54,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.info(
+            logging.debug(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
+            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@ def start_work_from_state(
                     ngram_ids=None,
                 )
             )
+            logging.debug(
+                "Worker thread: is_final=True batch put in queue, breaking from loop."
+            )
             break
 
         try:
-            logging.info("Worker thread: outputting state")
-            state_queue.put(iterator.get_state(), timeout=1)
-            logging.info("Worker thread: state dump complete")
+            logging.debug("Worker thread: outputting state")
+            state_queue.put(stateful_iterator.get_state(), timeout=1)
+            logging.debug("Worker thread: state dump complete")
             state_dumped_event.set()
-            logging.info("Worker thread: set state_dump_event")
+            logging.debug("Worker thread: set state_dump_event")
         except Full:
             raise ValueError(
                 "Attempted to dump state into the state queue, but it was full"
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
             )
         else:
-            logging.info("Main thread: Sending stop iteration event")
+            logging.debug("Main thread: Sending stop iteration event")
             self.stop_iterating_event.set()
-            logging.info("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
+            logging.debug(
+                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+            )
             self.prefetch_buffer = []
             final_batch_received = False
             while True:
                 try:
                     batch = self.batch_queue.get(timeout=1)
                     if batch.is_final:
+                        logging.debug(
+                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                        )
                         final_batch_received = True
                         break
                     self.prefetch_buffer.append(batch)
@@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
                     logging.warning("Main thread: batch_queue is abnormally empty")
             assert final_batch_received
 
+            logging.debug("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()
+
            try:
                base_iterator_state = self.state_queue.get(timeout=1)
                assert isinstance(base_iterator_state, IteratorState)
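
The reordering above is the core of the fix: the worker may be blocked on batch_queue.put(...) when the stop event fires, so the main thread must drain the queue up to the is_final sentinel before waiting on state_dumped_event; waiting first can deadlock. A simplified, self-contained sketch of the same handshake (plain tuples instead of the project's Batch/IteratorState classes):

import multiprocessing as mp

def worker(batch_queue, state_queue, stop_event, state_dumped_event):
    # Produces batches until asked to stop, then emits a sentinel, dumps its
    # state, and signals that the dump is complete.
    step = 0
    while not stop_event.is_set():
        batch_queue.put(("batch", step))  # may block while the queue is full
        step += 1
    batch_queue.put(("final", None))      # stand-in for the is_final=True batch
    state_queue.put({"step": step})
    state_dumped_event.set()

def get_state(batch_queue, state_queue, stop_event, state_dumped_event):
    stop_event.set()
    prefetch = []
    # Drain the batch queue up to the sentinel *before* waiting on the event;
    # the worker may be blocked on put(), so waiting first can deadlock.
    while True:
        kind, payload = batch_queue.get()
        if kind == "final":
            break
        prefetch.append(payload)
    state_dumped_event.wait()
    return state_queue.get(), prefetch

if __name__ == "__main__":
    batch_q, state_q = mp.Queue(maxsize=2), mp.Queue()
    stop, dumped = mp.Event(), mp.Event()
    proc = mp.Process(target=worker, args=(batch_q, state_q, stop, dumped))
    proc.start()
    state, prefetch = get_state(batch_q, state_q, stop, dumped)
    proc.join()
    print(state, len(prefetch))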
bytelatent/train.py CHANGED
@@ -26,6 +26,7 @@ from torch.optim import lr_scheduler
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -35,7 +36,6 @@ from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
     dist_mean,
-    dist_mean_dict,
     dist_sum,
     get_device_mesh,
     get_is_master,
@@ -88,6 +88,13 @@ def get_iterator_state_name(iterator_state):
     raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
 
 
+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
@@ -603,20 +610,20 @@ def train(args: TrainArgs):
                 # step: Metric at a step
                 # interval: Metric averaged/summed across all steps since the last log interval.
                 # Typically, this is 10
-                step_loss_per_gpu = loss.item()
-                step_loss_across_gpus = dist_mean(step_loss_per_gpu).item()
-                interval_loss_per_gpu = np.mean(step_losses).item()
-                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item()
+                step_loss_per_gpu = loss
+                step_loss_across_gpus = dist_mean(step_loss_per_gpu)
+                interval_loss_per_gpu = np.mean(step_losses)
+                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)
 
                 stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
-                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item()
+                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
                 interval_total_tok_loss_across_gpus = dist_sum(
                     interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
-                ).item()
-                interval_total_n_bytes_per_gpu = n_bytes.item()
+                )
+                interval_total_n_bytes_per_gpu = n_bytes
                 interval_total_n_bytes_across_gpus = dist_sum(
                     n_bytes, reduce_dtype=torch.bfloat16
-                ).item()
+                )
 
                 interval_bpb_per_gpu = (
                     interval_total_tok_loss_per_gpu
@@ -645,18 +652,20 @@ def train(args: TrainArgs):
                     },
                     "memory": gpu_mem_stats._asdict(),
                     "loss": {
-                        "step_per_gpu": step_loss_per_gpu,
-                        "step_across_gpu": step_loss_across_gpus,
-                        "interval_per_gpu": interval_loss_per_gpu,
-                        "interval_across_gpu": interval_loss_across_gpus,
+                        "step_per_gpu": to_py_num(step_loss_per_gpu),
+                        "step_across_gpu": to_py_num(step_loss_across_gpus),
+                        "interval_per_gpu": to_py_num(interval_loss_per_gpu),
+                        "interval_across_gpu": to_py_num(interval_loss_across_gpus),
                     },
                     "bpb": {
-                        "interval_per_gpu": interval_bpb_per_gpu,
-                        "interval_across_gpus": interval_bpb_across_gpus,
+                        "interval_per_gpu": to_py_num(interval_bpb_per_gpu),
+                        "interval_across_gpus": to_py_num(interval_bpb_across_gpus),
                     },
                     "n_bytes": {
-                        "interval_per_gpu": interval_total_n_bytes_per_gpu,
-                        "interval_across_gpus": interval_total_n_bytes_across_gpus,
+                        "interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu),
+                        "interval_across_gpus": to_py_num(
+                            interval_total_n_bytes_across_gpus
+                        ),
                     },
                 }
 
@@ -676,8 +685,8 @@ def train(args: TrainArgs):
                 logger.info(
                     f"step: {train_state.step}"
                     f" acc: {train_state.acc_step}"
-                    f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
-                    f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
+                    f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}"
+                    f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}"
                     f" bpb_gpu: {interval_bpb_per_gpu:3f}"
                     f" bpb_avg: {interval_bpb_across_gpus:3f}"
                     f" grad: {grad_norm:.2e}"
@@ -702,6 +711,9 @@ def train(args: TrainArgs):
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 saved = checkpoint.save(
                     model,
                     optimizer,
@@ -743,6 +755,9 @@ def train(args: TrainArgs):
 
         if preemption_flag["flag"]:
            if not saved:
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                checkpoint.save(
                    model,
                    optimizer,
@@ -754,6 +769,9 @@ def train(args: TrainArgs):
            sys.exit(0)
 
    if not saved:
+        train_state.data_loader_state, data_loader, batch_iterator = (
+            get_state_and_refresh(data_loader)
+        )
        checkpoint.save(
            model,
            optimizer,
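
The metrics refactor keeps losses and byte counts as tensors/arrays so dist_mean and dist_sum receive them directly, and only converts to Python numbers at the point of logging via to_py_num. A small standalone illustration of the helper's behaviour (re-declared here only so the snippet runs on its own):

import numpy as np
import torch

def to_py_num(num):
    # Mirrors the helper added above: tensors/arrays become Python scalars,
    # plain ints/floats pass through unchanged.
    if isinstance(num, (torch.Tensor, np.ndarray)):
        return num.item()
    return num

print(to_py_num(torch.tensor(0.25)))  # 0.25 as a plain Python float
print(to_py_num(np.array(3)))         # 3 as a plain Python int
print(to_py_num(1.5))                 # 1.5, unchanged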