from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
    PipelineMixedPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
|
|
@hydra_runner(config_path="conf", config_name="aayn_base_megatron")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    plugins = []
|
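    # NeMo's DDP strategy for Megatron-style models: no default DDP communication hook
    # and no unused-parameter search.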
    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
|
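    # Mixed precision setup: fp16 needs a dynamic-loss-scaling GradScaler (bf16 does not);
    # the plugin choice depends on whether Megatron O2-style AMP is enabled.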
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
|
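    # On Base Command Platform (BCP) clusters, use the TorchElastic cluster environment.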
    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer, callbacks=[ModelSummary(max_depth=3)])
|
    exp_manager(trainer, cfg.exp_manager)
|
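    # Resume from an explicitly configured checkpoint if provided, otherwise from the one found by exp_manager.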
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
|
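    # Keep the model config's precision in sync with the trainer precision.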
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision
|
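    # Fine-tune from a pretrained T5/BART checkpoint when one is provided; otherwise train from scratch.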
    if hasattr(cfg.model, 'pretrained_model_path') and cfg.model.pretrained_model_path is not None:
        if not hasattr(cfg.model, 'pretrained_model_type'):
            raise ValueError("Pretrained model type must be specified and must be one of ['T5', 'BART'].")

        assert cfg.model.pretrained_model_type in ['T5', 'BART']
        if cfg.model.pretrained_model_type == 'T5':
            pretrained_cfg = MegatronT5Model.restore_from(
                cfg.model.pretrained_model_path, trainer=trainer, return_config=True
            )
        else:
            pretrained_cfg = MegatronBARTModel.restore_from(
                cfg.model.pretrained_model_path, trainer=trainer, return_config=True
            )
        OmegaConf.set_struct(pretrained_cfg, True)
|
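        # Override the pretrained config fields that need to change for NMT fine-tuning.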
        with open_dict(pretrained_cfg):
            pretrained_cfg.masked_softmax_fusion = False

            pretrained_cfg.src_language = cfg.model.src_language
            pretrained_cfg.tgt_language = cfg.model.tgt_language
            pretrained_cfg.multilingual = cfg.model.multilingual
            pretrained_cfg.shared_tokenizer = True
|
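            # Limit on how many tokens beyond the source length the model may generate.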
            pretrained_cfg.max_generation_delta = cfg.model.max_generation_delta
|
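            # Label smoothing for the cross-entropy loss.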
            pretrained_cfg.label_smoothing = cfg.model.label_smoothing
|
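            # Reuse the pretrained model's tokenizer for both encoder and decoder.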
            pretrained_cfg.encoder_tokenizer = pretrained_cfg.tokenizer
            pretrained_cfg.decoder_tokenizer = pretrained_cfg.tokenizer
|
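            # Force the legacy SentencePiece tokenizer behavior expected by the pretrained checkpoints.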
            pretrained_cfg.encoder_tokenizer.sentencepiece_legacy = True
            pretrained_cfg.decoder_tokenizer.sentencepiece_legacy = True
|
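            # Override dropout. Some pretrained checkpoints have a single flat config instead of
            # separate encoder/decoder sections.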
            if not hasattr(pretrained_cfg, 'encoder'):
                assert not hasattr(pretrained_cfg, 'decoder')
                logging.warning(
                    "No separate encoder configuration found in the pretrained model; using the encoder dropout settings everywhere."
                )
                pretrained_cfg.hidden_dropout = cfg.model.encoder.hidden_dropout
                pretrained_cfg.attention_dropout = cfg.model.encoder.attention_dropout
            else:
                assert hasattr(pretrained_cfg, 'decoder') and hasattr(pretrained_cfg, 'encoder')
                pretrained_cfg.encoder.hidden_dropout = cfg.model.encoder.hidden_dropout
                pretrained_cfg.encoder.attention_dropout = cfg.model.encoder.attention_dropout
                pretrained_cfg.decoder.hidden_dropout = cfg.model.decoder.hidden_dropout
                pretrained_cfg.decoder.attention_dropout = cfg.model.decoder.attention_dropout
|
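            # Match the restored model's precision to the trainer.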
            pretrained_cfg.precision = trainer.precision
|
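            # Override micro and global batch sizes for fine-tuning.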
            pretrained_cfg.micro_batch_size = cfg.model.micro_batch_size
            pretrained_cfg.global_batch_size = cfg.model.global_batch_size
|
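            # Whether to use Megatron O2-style mixed precision, as configured for fine-tuning.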
            pretrained_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
|
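            # Attach the fine-tuning datasets; a validation dataset is required.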
            pretrained_cfg.train_ds = cfg.model.train_ds
            pretrained_cfg.train_ds.micro_batch_size = cfg.model.micro_batch_size
            pretrained_cfg.train_ds.global_batch_size = cfg.model.global_batch_size
            if hasattr(cfg.model, 'validation_ds'):
                pretrained_cfg.validation_ds = cfg.model.validation_ds
            else:
                raise AttributeError("No validation dataset found in the model config.")
            if hasattr(cfg.model, 'test_ds'):
                pretrained_cfg.test_ds = cfg.model.test_ds
|
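            # Set the config's target class so the restored model is instantiated as MegatronNMTModel.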
            pretrained_cfg.target = (
                "nemo.collections.nlp.models.machine_translation.megatron_nmt_model.MegatronNMTModel"
            )
|
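            # Use the optimizer settings from the fine-tuning config, not the pretrained one.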
            pretrained_cfg.optim = cfg.model.optim
|
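        # Restore the pretrained weights into a MegatronNMTModel using the overridden config.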
        model = MegatronNMTModel.restore_from(
            cfg.model.pretrained_model_path,
            trainer=trainer,
            override_config_path=pretrained_cfg,
            save_restore_connector=NLPSaveRestoreConnector(),
        )
    else:
        model = MegatronNMTModel(cfg.model, trainer)
|
    trainer.fit(model)
    trainer.validate(model)
|
|
if __name__ == '__main__':
    main()