# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron_t5_seq2seq_finetune import load_from_checkpoint_dir, load_from_nemo, validate_checkpoint_loading_args
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin

from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel
from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel
from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, MegatronHalfPrecisionPlugin, NLPDDPStrategy
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def _modify_config(t5_cfg, cfg, add_cfg_to_tree=False):
    """
    This function modifies the original t5 pre-training config (t5_cfg) with attributes from the finetuning config (cfg).
    The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
    """
    OmegaConf.set_struct(t5_cfg, True)
    with open_dict(t5_cfg):
        t5_cfg.precision = cfg.trainer.precision
        # Overwrite data configs
        if cfg.model.data.validation_ds.get('src_file_name', None) is not None:
            logging.info(
                'Found `validation_ds.src_file_name` in the config. Overriding the fine-tuned model config with this value.'
            )
            t5_cfg.data.validation_ds.src_file_name = cfg.model.data.validation_ds.src_file_name
        if cfg.model.data.validation_ds.get('tgt_file_name', None) is not None:
            logging.info(
                'Found `validation_ds.tgt_file_name` in the config. Overriding the fine-tuned model config with this value.'
            )
            t5_cfg.data.validation_ds.tgt_file_name = cfg.model.data.validation_ds.tgt_file_name

        if "write_predictions_to_file" in cfg.model.data.validation_ds:
            t5_cfg.data.validation_ds.write_predictions_to_file = (
                cfg.model.data.validation_ds.write_predictions_to_file
            )
        if "output_file_path_prefix" in cfg.model.data.validation_ds:
            t5_cfg.data.validation_ds.output_file_path_prefix = cfg.model.data.validation_ds.output_file_path_prefix

        t5_cfg.data.validation_ds.micro_batch_size = cfg.model.data.validation_ds.micro_batch_size
        t5_cfg.data.validation_ds.global_batch_size = cfg.model.data.validation_ds.global_batch_size

        # This is needed when modifying an hparams file directly in order to load `.ckpt` files.
        # It is not needed when modifying the cfg of `.nemo` files.
        if add_cfg_to_tree:
            OmegaConf.resolve(t5_cfg)
            t5_cfg.cfg = t5_cfg

    return t5_cfg


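# Example invocation (a minimal sketch; the script filename, paths, and override values below
# are placeholders, while the dotted keys correspond to the config consumed by this script):
#
#   python megatron_t5_seq2seq_eval.py \
#       trainer.devices=1 \
#       trainer.precision=16 \
#       model.restore_from_path=/path/to/finetuned_t5.nemo \
#       model.data.validation_ds.src_file_name=/path/to/val_src.txt \
#       model.data.validation_ds.tgt_file_name=/path/to/val_tgt.txt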
@hydra_runner(config_path="conf", config_name="megatron_t5_config_finetune_glue_eval")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    plugins = []
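    # Distributed strategy for Megatron-based models: the default DDP communication hook is
    # disabled, since gradient all-reduce is handled outside of DDP.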
    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
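    # Mixed precision: fp16 requires a loss-scaling GradScaler, bf16 does not.
    # With megatron_amp_O2 enabled, the Megatron half-precision plugin is used
    # instead of PyTorch Lightning's native AMP plugin.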
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

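    # On Base Command Platform (BCP) clusters, use the TorchElastic cluster environment.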
    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)

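    # Set up logging, checkpointing, and experiment directories via NeMo's experiment manager.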
    exp_manager(trainer, cfg.exp_manager)

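    # Select the model class from the validation dataset config:
    #   `task_name`  -> GLUE-style evaluation (MegatronT5GLUEModel)
    #   `file_names` -> T0-style evaluation (MegatronT0Model)
    #   otherwise    -> generic seq2seq fine-tuned model (MegatronT5FinetuneModel)
    # In each case the model is restored either from a `.nemo` file (restore_from_path)
    # or from a pre-trained checkpoint dir, with its config patched by `_modify_config`.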
    if hasattr(cfg.model.data.validation_ds, 'task_name'):
        if cfg.model.restore_from_path:
            t5_cfg = MegatronT5GLUEModel.restore_from(
                restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
            )
            model = load_from_nemo(MegatronT5GLUEModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
        else:
            validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
            model = load_from_checkpoint_dir(MegatronT5GLUEModel, cfg, trainer, modify_confg_fn=_modify_config)
    elif hasattr(cfg.model.data.validation_ds, 'file_names'):
        if cfg.model.restore_from_path:
            t5_cfg = MegatronT0Model.restore_from(
                restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
            )
            model = load_from_nemo(MegatronT0Model, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
        else:
            validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
            model = load_from_checkpoint_dir(MegatronT0Model, cfg, trainer, modify_confg_fn=_modify_config)
    else:
        if cfg.model.restore_from_path:
            t5_cfg = MegatronT5FinetuneModel.restore_from(
                restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
            )
            model = load_from_nemo(MegatronT5FinetuneModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config)
        else:
            validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
            model = load_from_checkpoint_dir(MegatronT5FinetuneModel, cfg, trainer, modify_confg_fn=_modify_config)

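    # Evaluation only: freeze the model weights, run validation, and run test if a test_ds is configured.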
    model.freeze()
    trainer.validate(model)
    if hasattr(cfg.model.data, 'test_ds'):
        trainer.test(model)


if __name__ == '__main__':
    main()