Fix multiprocessing dataloader checkpointing and use it in the train script (#50)
bytelatent/args.py
CHANGED
@@ -1,10 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import json
 import logging
 import os
 from typing import Any
 
-import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
bytelatent/data/iterators/abstract_iterator.py
CHANGED
@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
     @abc.abstractmethod
     def build(self) -> StatefulIterator[T, C]:
         pass
+
+
+def get_state_and_refresh(iterator: StatefulIterator):
+    # Re-init dataloader and iterator is necessary since get_state()
+    # on mp iterator shuts down MP to correctly persist state and it needs
+    # to be restarted.
+    state = iterator.get_state()
+    data_loader = state.build()
+    py_iterator = data_loader.create_iter()
+    return state, data_loader, py_iterator
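The new helper exists because calling get_state() on the multiprocess iterator shuts its worker down, so the caller has to swap in the rebuilt loader and iterator it returns. A minimal sketch of the intended call pattern, assuming hypothetical build_train_dataloader, train_step, and save_checkpoint stand-ins that are not part of this diff:

from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh

data_loader = build_train_dataloader(args)   # hypothetical: returns a StatefulIterator
batch_iterator = data_loader.create_iter()

for step, batch in enumerate(batch_iterator, start=1):
    train_step(batch)                        # hypothetical training step
    if step % checkpoint_every == 0:
        # get_state() tears down the MP worker, so rebind the loader and
        # iterator returned here before consuming any more batches.
        state, data_loader, batch_iterator = get_state_and_refresh(data_loader)
        save_checkpoint(state)               # hypothetical persistence call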
bytelatent/data/iterators/arrow_iterator.py
CHANGED
@@ -60,6 +60,13 @@ def shard_sort_key(file: str):
     return shard_number
 
 
+def maybe_truncate_string(text: str, max_length: int):
+    if len(text) <= max_length:
+        return text
+    else:
+        return text[:max_length] + "..."
+
+
 class ArrowFileIterator(StatefulIterator):
     def __init__(
         self,
@@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator):
             yield out
 
     def _set_row_num(self, target_row_num: int):
-        logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
-        )
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
+        logger.info(f"Setting arrow position to {target_row_num} for {data_str}")
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
             self.dataset = None
@@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator):
             else:
                 curr_remaining -= len(batch)
         self.row_num = target_row_num
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
         logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {data_str}"
        )
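maybe_truncate_string caps how much of the dataset file list makes it into the _set_row_num log lines, since str(self.dataset_files) can run very long on jobs with many shards. A quick illustration of its behavior (the shard names below are made up):

from bytelatent.data.iterators.arrow_iterator import maybe_truncate_string

files = [f"shard_{i:05d}.arrow" for i in range(500)]   # made-up shard names
print(maybe_truncate_string(str(files), 200))          # first 200 chars followed by "..."
print(maybe_truncate_string("short string", 200))      # short input is returned unchanged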
bytelatent/data/iterators/multiprocess_iterator.py
CHANGED
@@ -54,9 +54,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.
+            logging.debug(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
+            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@ def start_work_from_state(
                     ngram_ids=None,
                 )
             )
+            logging.debug(
+                "Worker thread: is_final=True batch put in queue, breaking from loop."
+            )
             break
 
         try:
-            logging.
-            state_queue.put(
-            logging.
+            logging.debug("Worker thread: outputting state")
+            state_queue.put(stateful_iterator.get_state(), timeout=1)
+            logging.debug("Worker thread: state dump complete")
             state_dumped_event.set()
+            logging.debug("Worker thread: set state_dump_event")
         except Full:
             raise ValueError(
                 "Attempted to dump state into the state queue, but it was full"
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
             )
         else:
-            logging.
+            logging.debug("Main thread: Sending stop iteration event")
             self.stop_iterating_event.set()
-            logging.
-
+            logging.debug(
+                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+            )
             self.prefetch_buffer = []
             final_batch_received = False
             while True:
                 try:
                     batch = self.batch_queue.get(timeout=1)
                     if batch.is_final:
+                        logging.debug(
+                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                        )
                         final_batch_received = True
                         break
                     self.prefetch_buffer.append(batch)
@@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
                     logging.warning("Main thread: batch_queue is abnormally empty")
             assert final_batch_received
 
+            logging.debug("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()
+
             try:
                 base_iterator_state = self.state_queue.get(timeout=1)
                 assert isinstance(base_iterator_state, IteratorState)
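The added debug logging and the new self.state_dumped_event.wait() spell out a handshake: the main thread signals stop, drains the batch queue until the is_final batch arrives, waits for the worker to dump its state, and only then reads the state queue. A self-contained sketch of that handshake using plain multiprocessing primitives (the FINAL marker and the toy worker loop are illustrative stand-ins, not the library's classes):

import multiprocessing as mp
from queue import Empty

FINAL = "FINAL"  # stand-in for Batch(is_final=True)

def worker(stop_event, state_dumped_event, batch_queue, state_queue):
    step = 0
    while True:
        if stop_event.is_set():
            batch_queue.put(FINAL)           # signal end of output to the main thread
            state_queue.put({"step": step})  # dump the iterator state
            state_dumped_event.set()         # only now is the state safe to read
            return
        batch_queue.put(step)                # normal batch production
        step += 1

if __name__ == "__main__":
    stop_event, state_dumped_event = mp.Event(), mp.Event()
    batch_queue, state_queue = mp.Queue(maxsize=8), mp.Queue(maxsize=1)
    p = mp.Process(target=worker, args=(stop_event, state_dumped_event, batch_queue, state_queue))
    p.start()
    stop_event.set()                         # ask the worker to stop and persist state
    while True:                              # drain until the final marker arrives
        try:
            if batch_queue.get(timeout=1) == FINAL:
                break
        except Empty:
            pass
    state_dumped_event.wait()                # the fix: wait before touching state_queue
    print(state_queue.get(timeout=1))        # e.g. {'step': 7}
    p.join()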
bytelatent/train.py
CHANGED
@@ -26,6 +26,7 @@ from torch.optim import lr_scheduler
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -35,7 +36,6 @@ from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
     dist_mean,
-    dist_mean_dict,
     dist_sum,
     get_device_mesh,
     get_is_master,
@@ -88,6 +88,13 @@ def get_iterator_state_name(iterator_state):
     raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
 
 
+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
@@ -603,20 +610,20 @@ def train(args: TrainArgs):
             # step: Metric at a step
             # interval: Metric averaged/summed across all steps since the last log interval.
             # Typically, this is 10
-            step_loss_per_gpu = loss
-            step_loss_across_gpus = dist_mean(step_loss_per_gpu)
-            interval_loss_per_gpu = np.mean(step_losses)
-            interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)
+            step_loss_per_gpu = loss
+            step_loss_across_gpus = dist_mean(step_loss_per_gpu)
+            interval_loss_per_gpu = np.mean(step_losses)
+            interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)
 
             stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
-            interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
+            interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
             interval_total_tok_loss_across_gpus = dist_sum(
                 interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
-            )
-            interval_total_n_bytes_per_gpu = n_bytes
+            )
+            interval_total_n_bytes_per_gpu = n_bytes
             interval_total_n_bytes_across_gpus = dist_sum(
                 n_bytes, reduce_dtype=torch.bfloat16
-            )
+            )
 
             interval_bpb_per_gpu = (
                 interval_total_tok_loss_per_gpu
@@ -645,18 +652,20 @@ def train(args: TrainArgs):
             },
             "memory": gpu_mem_stats._asdict(),
             "loss": {
-                "step_per_gpu": step_loss_per_gpu,
-                "step_across_gpu": step_loss_across_gpus,
-                "interval_per_gpu": interval_loss_per_gpu,
-                "interval_across_gpu": interval_loss_across_gpus,
+                "step_per_gpu": to_py_num(step_loss_per_gpu),
+                "step_across_gpu": to_py_num(step_loss_across_gpus),
+                "interval_per_gpu": to_py_num(interval_loss_per_gpu),
+                "interval_across_gpu": to_py_num(interval_loss_across_gpus),
             },
             "bpb": {
-                "interval_per_gpu": interval_bpb_per_gpu,
-                "interval_across_gpus": interval_bpb_across_gpus,
+                "interval_per_gpu": to_py_num(interval_bpb_per_gpu),
+                "interval_across_gpus": to_py_num(interval_bpb_across_gpus),
             },
             "n_bytes": {
-                "interval_per_gpu": interval_total_n_bytes_per_gpu,
-                "interval_across_gpus": interval_total_n_bytes_across_gpus,
+                "interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu),
+                "interval_across_gpus": to_py_num(
+                    interval_total_n_bytes_across_gpus
+                ),
             },
         }
 
@@ -676,8 +685,8 @@ def train(args: TrainArgs):
         logger.info(
             f"step: {train_state.step}"
             f" acc: {train_state.acc_step}"
-            f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
-            f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
+            f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}"
+            f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}"
             f" bpb_gpu: {interval_bpb_per_gpu:3f}"
             f" bpb_avg: {interval_bpb_across_gpus:3f}"
             f" grad: {grad_norm:.2e}"
@@ -702,6 +711,9 @@ def train(args: TrainArgs):
         if every_n_steps(
             train_state, args.checkpoint.dump.every, acc_step=0
         ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+            train_state.data_loader_state, data_loader, batch_iterator = (
+                get_state_and_refresh(data_loader)
+            )
            saved = checkpoint.save(
                model,
                optimizer,
@@ -743,6 +755,9 @@ def train(args: TrainArgs):
 
        if preemption_flag["flag"]:
            if not saved:
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                checkpoint.save(
                    model,
                    optimizer,
@@ -754,6 +769,9 @@ def train(args: TrainArgs):
            sys.exit(0)
 
    if not saved:
+        train_state.data_loader_state, data_loader, batch_iterator = (
+            get_state_and_refresh(data_loader)
+        )
        checkpoint.save(
            model,
            optimizer,
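to_py_num is needed because several of the logged metrics are torch tensors (outputs of dist_mean, dist_sum, and tensor sums), and the metrics dict is easier to serialize and log once everything is a plain Python number. A small self-contained illustration with made-up values:

import numpy as np
import torch

def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
    # Tensors and NumPy arrays are unwrapped with .item(); plain numbers pass through.
    if isinstance(num, (torch.Tensor, np.ndarray)):
        return num.item()
    else:
        return num

metrics = {
    "interval_per_gpu": to_py_num(torch.tensor(0.8731)),  # tensor -> float
    "interval_across_gpus": to_py_num(np.array(0.8650)),  # ndarray -> float
    "n_bytes": to_py_num(1024),                           # ints pass through untouched
}
print(metrics)  # every value is now a plain Python number, safe to JSON-serialize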