NeMo / examples /nlp /language_modeling /megatron_retro_fine_tune.py
camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from nemo.collections.nlp.models.language_modeling.megatron_retro_fine_tune_model import MegatronRetroFinetuneModel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager
def _modify_config(retro_cfg, cfg, add_cfg_to_tree=False):
"""
This function modifies the original retro pre-training config with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
OmegaConf.set_struct(retro_cfg, True)
with open_dict(retro_cfg):
retro_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
retro_cfg.data = cfg.model.data
retro_cfg.precision = cfg.trainer.precision
retro_cfg.optim = cfg.model.optim
retro_cfg.micro_batch_size = cfg.model.micro_batch_size
# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
OmegaConf.resolve(retro_cfg)
retro_cfg.cfg = retro_cfg
return retro_cfg
def load_from_nemo(cls, cfg, trainer, retro_cfg, modify_confg_fn, save_restore_connector):
retro_cfg = modify_confg_fn(retro_cfg, cfg, add_cfg_to_tree=False)
model = cls.restore_from(
restore_path=cfg.model.restore_path,
trainer=trainer,
override_config_path=retro_cfg,
save_restore_connector=save_restore_connector,
)
return model
@hydra_runner(config_path="conf", config_name="megatron_retro_finetune_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
###### following is the workaround for num_workers=0 issue #####
# import torch.multiprocessing as mp
# mp.set_start_method("spawn", force=True)
#####################################################
megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
plugins = []
strategy = NLPDDPStrategy(
no_ddp_communication_hook=True if megatron_amp_o2 else False,
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
timeout=datetime.timedelta(seconds=18000),
)
if cfg.trainer.precision in [16, 'bf16']:
scaler = None
if cfg.trainer.precision == 16:
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
if megatron_amp_o2:
plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
else:
plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())
trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
exp_manager(trainer, cfg.exp_manager)
# update resume from checkpoint found by exp_manager
resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
# Override timer callback to a stateless one
for idx, callback in enumerate(trainer.callbacks):
if isinstance(callback, Timer):
trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)
# load existing or init new soft prompt GPT model
if cfg.model.get("restore_path", None):
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(cfg.model.restore_path):
save_restore_connector.model_extracted_dir = cfg.model.restore_path
model_cfg = MegatronRetroFinetuneModel.restore_from(
restore_path=cfg.model.restore_path,
trainer=trainer,
return_config=True,
save_restore_connector=save_restore_connector,
)
# hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
model = load_from_nemo(
MegatronRetroFinetuneModel,
cfg,
trainer,
model_cfg,
modify_confg_fn=_modify_config,
save_restore_connector=save_restore_connector,
)
else:
model = MegatronRetroFinetuneModel(cfg.model, trainer=trainer)
trainer.fit(model)
if __name__ == '__main__':
main()