File size: 10,102 Bytes
7934b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel
from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel
from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import AppState, logging
from nemo.utils.exp_manager import exp_manager
from nemo.utils.model_utils import inject_model_parallel_rank
# Force torch.multiprocessing to use the "spawn" start method; force=True
# overrides any start method that was already selected.
# NOTE(review): presumably required so child processes do not inherit CUDA
# state through fork — confirm against the data-loading / DDP setup.
mp.set_start_method("spawn", force=True)
def _modify_config(t5_cfg, cfg, add_cfg_to_tree=False):
    """
    Overlay settings from the finetuning config (cfg) onto the original t5 pre-training config (t5_cfg).

    Args:
        t5_cfg: pre-training model config; it is updated in place and also returned.
        cfg: finetuning config; values are read from `cfg.model` and `cfg.trainer`.
        add_cfg_to_tree: when True, interpolations in `t5_cfg` are resolved and the config is
            additionally stored under its own `cfg` key. This is needed for `hparams.yaml` files
            that are passed to `load_from_checkpoint()`; it is not needed for `.nemo` files.

    Returns:
        The modified `t5_cfg`.
    """
    OmegaConf.set_struct(t5_cfg, True)
    with open_dict(t5_cfg):
        finetune_model_cfg = cfg.model
        t5_cfg.megatron_amp_O2 = finetune_model_cfg.get('megatron_amp_O2', False)

        has_split_stacks = all(hasattr(t5_cfg, stack_name) for stack_name in ('encoder', 'decoder'))
        if has_split_stacks:
            # Config with separate encoder / decoder sections: update each section.
            for stack_cfg in (t5_cfg.encoder, t5_cfg.decoder):
                stack_cfg.masked_softmax_fusion = False
                stack_cfg.hidden_dropout = finetune_model_cfg.get('hidden_dropout', 0.1)
                # ffn_dropout is only overridden when the section already defines it.
                if hasattr(stack_cfg, 'ffn_dropout'):
                    stack_cfg.ffn_dropout = finetune_model_cfg.get('ffn_dropout', 0.1)
        else:
            # Flat config: the dropout / fusion settings live at the top level.
            t5_cfg.hidden_dropout = finetune_model_cfg.get('hidden_dropout', 0.1)
            t5_cfg.attention_dropout = finetune_model_cfg.get('attention_dropout', 0.1)
            t5_cfg.masked_softmax_fusion = False

        # Data, precision, optimizer and batch sizes always come from the finetuning config.
        t5_cfg.data = finetune_model_cfg.data
        t5_cfg.precision = cfg.trainer.precision
        t5_cfg.optim = finetune_model_cfg.optim
        train_ds_cfg = finetune_model_cfg.data.train_ds
        t5_cfg.micro_batch_size = train_ds_cfg.micro_batch_size
        t5_cfg.global_batch_size = train_ds_cfg.global_batch_size

        # XNLI has eval languages in the yaml config.
        if hasattr(finetune_model_cfg, 'eval_languages'):
            t5_cfg.eval_languages = finetune_model_cfg.eval_languages

        # Only required when an hparams file is edited directly to load `.ckpt` files;
        # the cfg inside `.nemo` files does not need this extra nesting.
        if add_cfg_to_tree:
            OmegaConf.resolve(t5_cfg)
            t5_cfg.cfg = t5_cfg

    return t5_cfg
def load_from_nemo(cls, cfg, trainer, t5_cfg, modify_confg_fn):
    """
    Restore a model of type `cls` from the path in `cfg.model.restore_from_path`.

    Args:
        cls: model class exposing `restore_from`.
        cfg: finetuning config.
        trainer: trainer handed to `restore_from`.
        t5_cfg: config of the pre-trained model, passed through `modify_confg_fn` before restoring.
        modify_confg_fn: callable `(t5_cfg, cfg, add_cfg_to_tree=...)` returning the config to
            use as the override config.

    Returns:
        The restored model.
    """
    override_cfg = modify_confg_fn(t5_cfg, cfg, add_cfg_to_tree=False)
    return cls.restore_from(
        restore_path=cfg.model.restore_from_path,
        trainer=trainer,
        override_config_path=override_cfg,
        save_restore_connector=NLPSaveRestoreConnector(),
    )
def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn):
    """
    Load a model of type `cls` from a checkpoint file plus an hparams yaml file.

    The locations are read from `cfg.model.pretrained_checkpoint`
    (`checkpoint_dir`, `checkpoint_name`, `hparams_file`).

    Args:
        cls: model class exposing `load_from_checkpoint`.
        cfg: finetuning config.
        trainer: trainer handed to `load_from_checkpoint`; its `global_rank` is also read.
        modify_confg_fn: callable `(t5_cfg, cfg, add_cfg_to_tree=...)` applied to the `cfg`
            entry of the loaded hparams file.

    Returns:
        The loaded model.
    """
    app_state = AppState()
    model_cfg = cfg.model
    tp_size = model_cfg.tensor_model_parallel_size
    pp_size = model_cfg.pipeline_model_parallel_size

    if tp_size > 1 or pp_size > 1:
        # Record the model-parallel layout on the shared AppState before building the path.
        app_state.model_parallel_size = tp_size * pp_size
        app_state.tensor_model_parallel_size = tp_size
        app_state.pipeline_model_parallel_size = pp_size
        (
            tp_rank,
            pp_rank,
            mp_size,
            dp_size,
            pp_split_rank,
            virtual_pp_rank,
        ) = fake_initialize_model_parallel(
            world_size=app_state.model_parallel_size,
            rank=trainer.global_rank,
            tensor_model_parallel_size_=tp_size,
            pipeline_model_parallel_size_=pp_size,
            pipeline_model_parallel_split_rank_=model_cfg.pipeline_model_parallel_split_rank,
        )
        app_state.tensor_model_parallel_rank = tp_rank
        app_state.pipeline_model_parallel_rank = pp_rank
        app_state.model_parallel_size = mp_size
        app_state.data_parallel_size = dp_size
        app_state.pipeline_model_parallel_split_rank = pp_split_rank
        app_state.virtual_pipeline_model_parallel_rank = virtual_pp_rank

    ckpt_cfg = model_cfg.pretrained_checkpoint
    raw_checkpoint_path = os.path.join(ckpt_cfg.checkpoint_dir, ckpt_cfg.checkpoint_name)
    # NOTE(review): presumably rewrites the path to the rank-specific checkpoint shard — confirm.
    checkpoint_path = inject_model_parallel_rank(raw_checkpoint_path)

    loaded_hparams = OmegaConf.load(ckpt_cfg.hparams_file)
    t5_cfg = modify_confg_fn(loaded_hparams.cfg, cfg, add_cfg_to_tree=True)

    # The modified config is written to a temporary yaml file that only needs to
    # exist while `load_from_checkpoint` runs.
    with tempfile.NamedTemporaryFile(suffix='.yaml') as tmp_hparams:
        OmegaConf.save(config=t5_cfg, f=tmp_hparams.name)
        return cls.load_from_checkpoint(
            checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=tmp_hparams.name,
        )
def validate_checkpoint_loading_args(cfg):
    """
    Validate the settings used to load a model from a checkpoint directory.

    Args:
        cfg: object with `checkpoint_dir`, `checkpoint_name` and `hparams_file` attributes.

    Raises:
        ValueError: if `checkpoint_dir` is None or not an existing directory, if
            `checkpoint_name` is None, or if `hparams_file` is None or not an existing file.
    """
    dir_is_usable = cfg.checkpoint_dir is not None and os.path.isdir(cfg.checkpoint_dir)
    if not dir_is_usable:
        raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.')

    # Only presence of the name is checked here, not that the file exists.
    if cfg.checkpoint_name is None:
        raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.')

    hparams_is_usable = cfg.hparams_file is not None and os.path.isfile(cfg.hparams_file)
    if not hparams_is_usable:
        raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.')
@hydra_runner(config_path="conf", config_name="megatron_t5_config_finetune_glue_mnli")
def main(cfg) -> None:
    """
    Fine-tune a pre-trained Megatron T5 model as described by the config `cfg`.

    Steps: log the config, build the trainer (strategy and precision plugins), set up the
    experiment manager, choose the checkpoint to resume from, load the pre-trained model from
    either a `.nemo` file or a checkpoint directory, then run fit, validate and (when
    `cfg.model.data` has a `test_ds` entry) test.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    plugins = []
    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )

    if cfg.trainer.precision in [16, 'bf16']:
        # A gradient scaler is only created for precision 16; 'bf16' runs with scaler=None.
        scaler = (
            GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
            if cfg.trainer.precision == 16
            else None
        )
        precision_plugin_cls = MegatronHalfPrecisionPlugin if megatron_amp_o2 else PipelineMixedPrecisionPlugin
        plugins.append(precision_plugin_cls(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # An explicit checkpoint in the config wins; otherwise use the one exp_manager found.
    resume_from_checkpoint = cfg.model.resume_from_checkpoint
    if resume_from_checkpoint is None:
        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    # The training dataset config determines which model class is fine-tuned.
    train_ds_cfg = cfg.model.data.train_ds
    if hasattr(train_ds_cfg, 'task_name'):
        model_cls = MegatronT5GLUEModel
    elif hasattr(train_ds_cfg, 'file_names'):
        model_cls = MegatronT0Model
    else:
        model_cls = MegatronT5FinetuneModel

    if cfg.model.restore_from_path:
        # `.nemo` path: fetch the stored config first, then restore with the overrides applied.
        t5_cfg = model_cls.restore_from(
            restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
        )
        model = load_from_nemo(model_cls, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
    else:
        # Checkpoint-directory path: check the arguments before attempting to load.
        validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
        model = load_from_checkpoint_dir(model_cls, cfg, trainer, modify_confg_fn=_modify_config)

    trainer.fit(model)
    trainer.validate(model)
    if hasattr(cfg.model.data, 'test_ds'):
        trainer.test(model)
# Script entry point: run the fine-tuning workflow only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()
|