File size: 6,047 Bytes
7934b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from nemo.collections.nlp.models.language_modeling.megatron_retro_fine_tune_model import MegatronRetroFinetuneModel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager
def _modify_config(retro_cfg, cfg, add_cfg_to_tree=False):
    """
    Overlay fine-tuning settings from ``cfg`` onto a RETRO pre-training config.

    Args:
        retro_cfg: config restored from the pre-trained RETRO model. It is mutated in place and also returned.
        cfg: the full fine-tuning config. Values are read from ``cfg.model`` and ``cfg.trainer``.
        add_cfg_to_tree: when True, resolve interpolations and place the config under a top-level ``cfg`` key.
            This layout is needed for ``hparams.yaml`` files passed to ``load_from_checkpoint()``.

    Returns:
        The modified ``retro_cfg``.
    """
    finetune_model_cfg = cfg.model
    OmegaConf.set_struct(retro_cfg, True)
    # Struct mode forbids adding new keys; open_dict lifts that restriction for the overrides below.
    with open_dict(retro_cfg):
        # Fine-tuning values replace whatever was stored at pre-training time.
        retro_cfg.megatron_amp_O2 = finetune_model_cfg.get('megatron_amp_O2', False)
        retro_cfg.data = finetune_model_cfg.data
        retro_cfg.precision = cfg.trainer.precision
        retro_cfg.optim = finetune_model_cfg.optim
        retro_cfg.micro_batch_size = finetune_model_cfg.micro_batch_size

        # Only the hparams-file path (loading `.ckpt` files) needs the extra `cfg` level;
        # configs inside `.nemo` files are used as-is.
        if add_cfg_to_tree:
            OmegaConf.resolve(retro_cfg)
            retro_cfg.cfg = retro_cfg
    return retro_cfg
def load_from_nemo(cls, cfg, trainer, retro_cfg, modify_confg_fn, save_restore_connector):
    """
    Restore a ``cls`` model from ``cfg.model.restore_path`` with fine-tuning overrides applied.

    Args:
        cls: model class whose ``restore_from`` classmethod performs the restore.
        cfg: fine-tuning config. ``cfg.model.restore_path`` names the file or directory to restore.
        trainer: trainer handed through to ``restore_from``.
        retro_cfg: config stored with the pre-trained model.
        modify_confg_fn: callable ``(retro_cfg, cfg, add_cfg_to_tree=...)``. It is called with
            ``add_cfg_to_tree=False`` and must return the config used as ``override_config_path``.
        save_restore_connector: connector handed through to ``restore_from``.

    Returns:
        Whatever ``cls.restore_from`` returns.
    """
    overridden_cfg = modify_confg_fn(retro_cfg, cfg, add_cfg_to_tree=False)
    return cls.restore_from(
        restore_path=cfg.model.restore_path,
        trainer=trainer,
        override_config_path=overridden_cfg,
        save_restore_connector=save_restore_connector,
    )
@hydra_runner(config_path="conf", config_name="megatron_retro_finetune_config")
def main(cfg) -> None:
    """
    Fine-tune a Megatron RETRO model as described by the Hydra config ``cfg``.

    Steps:
    1. Build the DDP strategy and the precision/cluster plugins.
    2. Create the Lightning ``Trainer`` and attach NeMo's experiment manager.
    3. Either restore a pre-trained model from ``cfg.model.restore_path`` (with fine-tuning
       overrides applied) or build a fresh ``MegatronRetroFinetuneModel`` from ``cfg.model``.
    4. Run ``trainer.fit``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    # NOTE: disabled workaround for the num_workers=0 issue; uncomment these two lines if it reappears.
    # import torch.multiprocessing as mp
    # mp.set_start_method("spawn", force=True)

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    plugins = []
    strategy = NLPDDPStrategy(
        # The DDP communication hook is turned off exactly when megatron_amp_O2 is enabled.
        no_ddp_communication_hook=bool(megatron_amp_o2),
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
        timeout=datetime.timedelta(seconds=18000),  # 5 hours
    )

    if cfg.trainer.precision in [16, 'bf16']:
        # A GradScaler is only created for fp16 (precision == 16); bf16 passes scaler=None.
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # Rebuild the checkpoint connector so it uses the resume path found by exp_manager.
    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    if resume_from_checkpoint is not None:
        # Only announce a resume when there is actually a checkpoint to resume from.
        logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    # Replace any Timer callback with a stateless one bounded by cfg.trainer.max_time.
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time)

    # Restore a pre-trained RETRO model with fine-tuning overrides, or build a new one from cfg.model.
    if cfg.model.get("restore_path", None):
        save_restore_connector = NLPSaveRestoreConnector()
        # Presumably lets restore_path point at an already-extracted model directory
        # instead of a packed `.nemo` file.
        if os.path.isdir(cfg.model.restore_path):
            save_restore_connector.model_extracted_dir = cfg.model.restore_path
        model_cfg = MegatronRetroFinetuneModel.restore_from(
            restore_path=cfg.model.restore_path,
            trainer=trainer,
            return_config=True,
            save_restore_connector=save_restore_connector,
        )
        # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
        model = load_from_nemo(
            MegatronRetroFinetuneModel,
            cfg,
            trainer,
            model_cfg,
            modify_confg_fn=_modify_config,
            save_restore_connector=save_restore_connector,
        )
    else:
        model = MegatronRetroFinetuneModel(cfg.model, trainer=trainer)

    trainer.fit(model)
# Script entry point. `main` is wrapped by `hydra_runner`, which presumably builds `cfg` from
# conf/megatron_retro_finetune_config plus command-line overrides (TODO confirm against hydra_runner).
if __name__ == '__main__':
    main()
|