# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import gc
import logging
import os
import sys
from contextlib import ExitStack
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, Dict, Optional

import torch
import torch.distributed
import wandb
import xformers.profiler
from lingua.args import dataclass_from_dict, dump_config, flatten_dict
from lingua.data import (
    DataArgs,
    PackTokensState,
    build_dataloader_from_args,
    init_dataloader_state_from_args,
)
from lingua.tokenizers.build_tokenizer import TokenizerArgs
from omegaconf import OmegaConf
from pydantic import BaseModel
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler

from bytelatent.checkpoint import (
    CheckpointArgs,
    CheckpointManager,
    load_from_checkpoint,
)
from bytelatent.distributed import (
    DistributedArgs,
    EnvironmentArgs,
    check_model_value_range,
    clean_env,
    dist_mean_dict,
    get_device_mesh,
    get_is_master,
    get_world_size,
    init_signal_handler,
    parallelize_model,
    requeue_slurm_job,
    setup_env,
    setup_torch_distributed,
)
from bytelatent.logger import init_logger
from bytelatent.metrics import (
    GPUMemoryMonitor,
    LoggingArgs,
    MetricLogger,
    get_num_params,
)
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.probe import AutoProbeD
from bytelatent.profiling import ProfilerArgs, maybe_run_profiler
from bytelatent.stool import StoolArgs, launch_job
from bytelatent.transformer import (
    LMTransformer,
    LMTransformerArgs,
    build_fsdp_grouping_plan,
    get_no_recompute_ops,
    get_num_flop_per_token,
    tp_parallelize,
)

logger = logging.getLogger()
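

# Top-level training configuration. Nested argument groups (data, optim, model,
# distributed, checkpointing, profiling, logging) can be overridden from a YAML
# config file and/or dot-list CLI arguments; see main() below.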
@dataclass
class TrainArgs:
    name: str = "lingua"
    dump_dir: str = ""

    seed: int = 42

    # Number of gradient accumulation steps
    # Total batch size is batch_size*grad_acc_steps
    grad_acc_steps: int = 1

    gc_collect_freq: int = 1000
    probe_freq: int | None = None

    # Number of optimizer steps to take
    steps: int = 1000

    data: DataArgs = field(default_factory=DataArgs)
    optim: OptimArgs = field(default_factory=OptimArgs)
    model: LMTransformerArgs = field(default_factory=LMTransformerArgs)
    distributed: DistributedArgs = field(default_factory=DistributedArgs)
    env: EnvironmentArgs = field(default_factory=EnvironmentArgs)

    checkpoint: CheckpointArgs = field(default_factory=CheckpointArgs)
    profiling: ProfilerArgs = field(default_factory=ProfilerArgs)
    logging: LoggingArgs = field(default_factory=LoggingArgs)

    # If set to None, eval is run locally; otherwise it launches a new job with the given number of GPUs
    async_eval_gpus: int | None = None
    eval: Any | None = None
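

# Checkpointable training-loop state (optimizer step count, accumulation step,
# LR scheduler and dataloader state), saved and restored through
# torch.distributed.checkpoint's Stateful protocol.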
@dataclass
class TrainState(Stateful):
    step: int  # Number of steps taken by the optimizer
    acc_step: int  # Number of accumulation steps done since last optimizer step
    scheduler: lr_scheduler.LambdaLR
    data_loader_state: PackTokensState

    def state_dict(self) -> Dict[str, Any]:
        return {
            "step": self.step,
            "acc_step": self.acc_step,
            "data_loader_state": self.data_loader_state,
            "scheduler": self.scheduler.state_dict(),
        }

    def load_state_dict(self, state_dict):
        self.step = state_dict["step"]
        self.acc_step = state_dict["acc_step"]
        self.data_loader_state = PackTokensState(**state_dict["data_loader_state"])
        self.scheduler.load_state_dict(state_dict["scheduler"])
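

# Fills in derived defaults (vocab size, checkpoint path, data-parallel sizes) and
# checks that the parallelism layout matches the world size before training starts.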
def validate_train_args(args: TrainArgs, output_size: int):
    if args.model.vocab_size < 0:
        logger.info(f"Setting model output size to {output_size}")
        args.model.vocab_size = output_size
    assert (
        args.model.vocab_size == output_size
    ), "Vocab size should be the same as output size"

    assert args.dump_dir, "Dump dir not set"

    if args.checkpoint.path is None:
        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
        logger.info(f"Setting checkpoint path to {args.checkpoint.path}")

    for source in args.data.sources:
        data_path = os.path.join(args.data.root_dir, source)
        assert os.path.exists(data_path), f"{data_path} doesn't exist"

    if (
        args.distributed.dp_replicate
        * args.distributed.dp_shard
        * args.distributed.tp_size
        != get_world_size()
    ):
        assert get_world_size() % args.distributed.dp_shard == 0
        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard

        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
        args.distributed.dp_replicate = (
            args.distributed.dp_replicate // args.distributed.tp_size
        )

        logger.warning(
            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
        )
        assert (
            args.distributed.dp_replicate
            * args.distributed.dp_shard
            * args.distributed.tp_size
            == get_world_size()
        )

        if args.distributed.fsdp_type == "no_shard":
            assert (
                args.distributed.dp_shard == 1
                and args.distributed.dp_replicate == get_world_size()
            )

    args.model.max_seqlen = args.data.seq_len
    if args.distributed.tp_size > 1:
        logger.warning(
            "Tensor parallelism has not been tested for a while, use at your own risk"
        )
    assert (
        args.probe_freq != args.profiling.mem_steps
    ), "Don't profile during probe step"
    assert (
        args.probe_freq != args.profiling.profile_steps
    ), "Don't profile during probe step"

    if args.logging.wandb is not None:
        args.logging.wandb.name = args.name

    if args.probe_freq is not None:
        assert (
            args.distributed.tp_size == 1
        ), "Probing not supported with tensor parallelism"
        assert (
            args.distributed.selective_activation_checkpointing is False
        ), "Probing not supported with selective activation checkpointing"
preemption_flag = dict(flag=False)


def set_preemption_flag(signum, frame):
    logger.warning(f"Signal handler called with signal {signum}")
    logger.warning("Preemption! Checkpointing ASAP and exiting.")
    preemption_flag["flag"] = True
def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
    test = train_state.step % freq == 0
    if acc_step is not None:
        test = test and (train_state.acc_step == acc_step)
    elif acc_freq is not None:
        test = test and ((train_state.acc_step % acc_freq) == 0)
    return test
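

# Main training entry point: sets up the distributed environment, builds the
# tokenizer and dataloader, initializes the (possibly sharded) model, then runs the
# optimization loop with periodic logging, checkpointing, probing, and eval launches.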
def train(args: TrainArgs):
    with ExitStack() as context_stack:
        tokenizer_args = TokenizerArgs(
            name=args.data.name,
            init_kwargs=args.data.tokenizer.init_kwargs,
        )
        tokenizer = tokenizer_args.build()
        validate_train_args(
            args,
            tokenizer.n_words,
        )
        if get_is_master():
            os.makedirs(args.dump_dir, exist_ok=True)
            dump_config(args, Path(args.dump_dir) / "config.yaml")
        init_logger(Path(args.dump_dir) / "train.log")
        init_signal_handler(set_preemption_flag)  # For handling preemption signals.
        setup_env(args.env)
        setup_torch_distributed(args.distributed)
        world_mesh = get_device_mesh(args.distributed)
        logger.info(f"Starting job: {args.name}")

        # build dataloader
        # need dp world size and rank
        dp_mesh = world_mesh["dp_replicate"]
        dp_degree = dp_mesh.size()
        dp_rank = dp_mesh.get_local_rank()
        if args.distributed.dp_shard > 1:
            dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
            dp_degree *= world_mesh["dp_shard"].size()

        logger.info(f"Running on dp rank: {dp_rank}")
        logger.info(f"Running on dp size: {dp_degree}")

        torch.manual_seed(args.seed)
        logger.info("Building model")

        # Initializing the model on the meta device lets us build models much bigger
        # than a single GPU's memory
        with torch.device("meta"):
            model = LMTransformer(args.model)
        logger.info("Model is built!")

        model_param_count = get_num_params(model)

        model = parallelize_model(
            model,
            world_mesh,
            args.model,
            args.distributed,
            fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
            tp_parallelize=tp_parallelize,
            no_recompute_ops=get_no_recompute_ops(),
        )

        # Once we shard the model on different gpus we can actually initialize the model
        # First we create empty tensors of the correct shapes
        model = model.to_empty(device="cuda")
        # Then we init the model. Please make sure this function initializes *ALL* parameters
        # and buffers, otherwise you will have random values in the uninitialized tensors,
        # which will silently fail (give NaN gradients, for example)
        if args.checkpoint.init_ckpt_path:
            logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
            load_from_checkpoint(
                args.checkpoint.init_ckpt_path, model, model_key="model"
            )  # Put model_key="" if it's directly the model checkpoint
            model.rope_embeddings.reset_parameters()  # For RoPE initialization: since it's a buffer it might not be loaded
        else:
            with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
                torch.manual_seed(args.model.seed)
                model.init_weights()
        check_model_value_range(model, range=10.0, std=1.0)

        # log model size
        logger.info(f"Model size: {model_param_count:,} total parameters")

        gpu_memory_monitor = GPUMemoryMonitor("cuda")
        logger.info(
            f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
            f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
        )
        logger.info(f"GPU memory usage: {gpu_memory_monitor}")

        # build optimizer after applying parallelisms to the model
        optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
        data_loader_state = init_dataloader_state_from_args(
            args.data, dp_rank, dp_degree
        )

        train_state = TrainState(
            step=0,
            acc_step=0,
            data_loader_state=data_loader_state,
            scheduler=scheduler,
        )

        checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
        # Either load from latest checkpoint or start from scratch
        checkpoint.load(model, optimizer, train_state, world_mesh)

        if args.probe_freq is not None:
            if get_is_master():
                os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
            torch.distributed.barrier()
            probe = AutoProbeD(
                model,
                (
                    Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
                    if (dp_rank % 128 == 0)
                    else None
                ),
            )
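            # When torch.compile is enabled, probe the underlying eager module
            # (_orig_mod) so the probe hooks see the original layers rather than
            # the compiled wrapper.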
            probe_mod = model._orig_mod if args.distributed.compile else model

        gc.disable()

        # train loop
        model.train()
        metric_logger = context_stack.enter_context(
            MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
        )
        data_loader = context_stack.enter_context(
            build_dataloader_from_args(
                args.data,
                state=train_state.data_loader_state,
            )
        )
        torch_profiler = context_stack.enter_context(
            maybe_run_profiler(args.dump_dir, model, args.profiling)
        )

        nwords_since_last_log = 0
        time_last_log = timer()
        gc.collect()
        while train_state.step < args.steps:
            # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
            train_state.acc_step += 1
            train_state.acc_step = train_state.acc_step % args.grad_acc_steps

            # get batch
            curr_lr = float(optimizer.param_groups[0]["lr"])
            data_load_start = timer()
            batch, train_state.data_loader_state = next(data_loader)
            batch = torch.tensor(
                batch,
                dtype=torch.long,
            )

            if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
                logger.info("garbage collection")
                # we do garbage collection manually otherwise different processes
                # run the GC at different times so they slow down the whole pipeline
                gc.collect()
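
            # Each packed batch stacks token ids and their next-token labels along the
            # last dimension: batch[:, :, 0] are inputs, batch[:, :, 1] are targets.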
            input_ids = batch[:, :, 0].cuda()
            labels = batch[:, :, 1].cuda()
            data_load_time = round(timer() - data_load_start, 4)
            nwords_since_last_log += input_ids.numel()

            bsz, seqlen = labels.shape

            # forward
            start_timer = torch.cuda.Event(enable_timing=True)
            end_timer = torch.cuda.Event(enable_timing=True)
            start_timer.record()

            # This is an automatic probe that will compute statistics
            # of all linears' inputs, weights and outputs
            # along with attention logits and entropy
            # both in forward and backward pass
            if (args.probe_freq is not None) and every_n_steps(
                train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
            ):
                # Here we do a fake forward and backward pass on a smaller
                # batch size to avoid OOM
                # This assumes the model has no stateful layers (batch norm..)
                assert (
                    next(probe_mod.parameters()).grad is None
                ), "Can't probe model if grads are not reset"

                with probe:
                    probe.metadata = {
                        "it": train_state.step,
                        "global_step": train_state.step,
                        "loop": "lingua",
                    }
                    # Non compiled model uses roughly 2x memory in our exps
                    # So we divide bsz by 2 or seqlen by 2
                    probe_bsz = max(1, bsz // 2)
                    probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
                    probe_loss = probe_mod(
                        input_ids[:probe_bsz, :probe_seq],
                        labels[:probe_bsz, :probe_seq],
                    )
                    probe_loss.backward()
                    # We zero grads to cancel this fake step
                    optimizer.zero_grad()

                assert (
                    next(probe_mod.parameters()).grad is None
                ), "Probe model shouldn't have grads at this point"

            loss = model(input_ids, labels)

            # We scale loss with grad_acc_steps so the gradient is the same
            # regardless of grad_acc_steps
            loss = loss / args.grad_acc_steps

            # backward on scaled loss to create scaled gradients
            loss.backward()
            # For logging we undo that scaling
            loss = loss.detach() * args.grad_acc_steps
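
            # Clip gradients; under FSDP/TP the returned norm is a DTensor, so it is
            # materialized into a local scalar below before logging.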
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=args.optim.clip, foreach=True
            )
            grad_norm = (
                grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
            ).item()

            # optimizer step
            if train_state.acc_step == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                train_state.step += 1

            # training iteration complete
            end_timer.record()
            torch.cuda.synchronize()
            curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)

            # if profiler is active
            if torch_profiler:
                xformers.profiler.step()

            # log metrics
            if every_n_steps(
                train_state,
                args.logging.freq,
                acc_step=None if args.logging.acc_freq else 0,
                acc_freq=args.logging.acc_freq,
            ):
                time_delta = timer() - time_last_log
                wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)

                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

                total_acc_steps = (
                    args.grad_acc_steps * train_state.step + train_state.acc_step
                )
                tokens_per_gpu = (
                    total_acc_steps * args.data.batch_size * args.data.seq_len
                )
                total_tokens = dp_degree * tokens_per_gpu
                # This is an estimate and the correct values may change
                # if you change the architecture
                # Use xformer's analyze profile trace to get actual measurement
                FLOPS = (
                    get_num_flop_per_token(
                        model_param_count - args.model.vocab_size * args.model.dim,
                        args.model.n_layers,
                        args.model.dim,
                        args.data.seq_len,
                    )
                    * wps
                )
                metrics = flatten_dict(
                    {
                        "global_step": train_state.step,
                        "acc_step": train_state.acc_step,
                        "speed": {
                            "wps": wps,
                            "FLOPS": FLOPS,
                            "curr_iter_time": curr_iter_time,
                            "data_load_time": data_load_time,
                        },
                        "optim": {
                            "grad_norm": grad_norm,
                            "lr": curr_lr,
                            "total_tokens": total_tokens,
                        },
                        "memory": gpu_mem_stats._asdict(),
                    },
                    sep="/",
                )

                to_sync = {}
                to_sync["loss/out"] = loss.item()
                metrics.update(dist_mean_dict(to_sync))

                if get_is_master():
                    metric_logger.log(metrics)

                gpu_memory_monitor.reset_peak_stats()
                nwords_since_last_log = 0
                time_last_log = timer()

                logger.info(
                    f"step: {train_state.step}"
                    f" acc: {train_state.acc_step}"
                    f" loss: {round(loss.item(), 4):>7}"
                    f" grad: {grad_norm:.2e}"
                    f" flops: {FLOPS:.2e}"
                    f" wps: {wps:.2e}"
                    f" iter: {curr_iter_time:>7}"
                    f" data: {data_load_time:>5}"
                    f" lr: {curr_lr:.2e}"
                    f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
                    f" pow: {gpu_mem_stats.power_draw / 1000} W"
                )
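
            # Checkpoint on the dump/eval cadence; `saved` prevents writing the same
            # step again in the preemption and end-of-training paths below.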
            saved = False
            if every_n_steps(
                train_state, args.checkpoint.dump.every, acc_step=0
            ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
                saved = checkpoint.save(
                    model,
                    optimizer,
                    train_state,
                    args,
                    device_mesh=world_mesh,
                )

            if args.eval is not None and every_n_steps(
                train_state, args.checkpoint.eval.every, acc_step=0
            ):
                from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval

                eval_args = dataclass_from_dict(EvalArgs, args.eval)

                eval_args.global_step = train_state.step
                eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
                eval_args.dump_dir = str(
                    os.path.join(
                        args.dump_dir,
                        "evals",
                        EVAL_FOLDER_NAME.format(train_state.step),
                    )
                )
                eval_args.metric_log_dir = args.dump_dir
                if args.async_eval_gpus is None:
                    launch_eval(eval_args)
                elif get_is_master():
                    if wandb.run is not None and args.logging.wandb is not None:
                        eval_args.wandb = deepcopy(args.logging.wandb)
                    assert args.async_eval_gpus > 0
                    logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
                    with clean_env():
                        launch_job(
                            StoolArgs(
                                asdict(eval_args),
                                script="apps.main.eval",
                                copy_code=False,
                                nodes=args.async_eval_gpus // 8,
                                qos="lowest",
                            )
                        )

            if preemption_flag["flag"]:
                if not saved:
                    checkpoint.save(
                        model,
                        optimizer,
                        train_state,
                        args,
                        device_mesh=world_mesh,
                    )
                requeue_slurm_job()
                sys.exit(0)

        if not saved:
            checkpoint.save(
                model,
                optimizer,
                train_state,
                args,
                device_mesh=world_mesh,
            )
        gc.collect()


def main():
    """
    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments

    This accepts arguments as a dot list.
    So if the dataclass looks like

    @dataclass
    class DummyArgs:
        name: str
        model: LMTransformerArgs

    @dataclass
    class LMTransformerArgs:
        dim: int

    then you can pass model.dim=32 to change values in LMTransformerArgs,
    or just name=tictac for top-level attributes.

    The behavior here is as follows:
    1. We instantiate TrainArgs with its default values
    2. We override those default values with the ones in the provided config file
    3. We override the result with the additional arguments provided through the command line

    For example, if the config is the following

    model:
        dim: 128
        n_layers: 4

    and you call train.py with train.py model.dim=64

    then the final TrainArgs will have

    model:
        dim: 64
        n_layers: 4

    plus all the default values in the TrainArgs dataclass.
    """
    cli_args = OmegaConf.from_cli()
    file_cfg = OmegaConf.load(cli_args.config)
    # We remove 'config' attribute from config as the underlying DataClass does not have it
    del cli_args.config

    default_cfg = OmegaConf.structured(TrainArgs())
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
    cfg = OmegaConf.to_object(cfg)
    train(cfg)


if __name__ == "__main__":
    main()