# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_retro_fine_tune_model import MegatronRetroFinetuneModel
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager


def _modify_config(retro_cfg, cfg, add_cfg_to_tree=False):
    """
    This function modifies the original RETRO pre-training config with attributes from the finetuning config (cfg).
    The `add_cfg_to_tree` arg adds `cfg` to the top of the YAML tree, which is required for `hparams.yaml` files when they are passed as an arg to `load_from_checkpoint()`.
    """
    OmegaConf.set_struct(retro_cfg, True)
    with open_dict(retro_cfg):
        retro_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
        retro_cfg.data = cfg.model.data
        retro_cfg.precision = cfg.trainer.precision
        retro_cfg.optim = cfg.model.optim
        retro_cfg.micro_batch_size = cfg.model.micro_batch_size
        # This is needed when modifying a hparams file directly to load `.ckpt` files.
        # It is not needed when modifying the cfg for `.nemo` files.
        if add_cfg_to_tree:
            OmegaConf.resolve(retro_cfg)
            retro_cfg.cfg = retro_cfg
    return retro_cfg


def load_from_nemo(cls, cfg, trainer, retro_cfg, modify_config_fn, save_restore_connector):
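    """
    Restore a pretrained RETRO model from a `.nemo` file, overriding its original
    config with the finetuning settings applied by `modify_config_fn`.
    """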
    retro_cfg = modify_config_fn(retro_cfg, cfg, add_cfg_to_tree=False)
    model = cls.restore_from(
        restore_path=cfg.model.restore_path,
        trainer=trainer,
        override_config_path=retro_cfg,
        save_restore_connector=save_restore_connector,
    )
    return model


@hydra_runner(config_path="conf", config_name="megatron_retro_finetune_config")
def main(cfg) -> None:
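    """Finetune a pretrained RETRO model with the settings from `megatron_retro_finetune_config`."""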
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
    ###### The following is a workaround for the num_workers=0 issue #####
    # import torch.multiprocessing as mp
    # mp.set_start_method("spawn", force=True)
    #####################################################
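    # megatron_amp_O2 selects Megatron's O2-style mixed precision, which uses the
    # MegatronHalfPrecisionPlugin and disables the DDP communication hook below.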
    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    plugins = []
    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=megatron_amp_o2,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
        timeout=datetime.timedelta(seconds=18000),
    )

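    # Configure mixed precision: a GradScaler is only created for fp16 (bf16 runs without
    # loss scaling); O2 uses the Megatron half-precision plugin, otherwise PTL's native AMP.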
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

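    # Build the PTL Trainer with the NLP DDP strategy and the plugins configured above,
    # then let exp_manager set up logging, checkpointing, and auto-resume.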
    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # Update the trainer to resume from the checkpoint path found by exp_manager.
    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    # Replace the stateful Timer callback with a StatelessTimer so the max_time budget
    # applies to this run rather than being restored from a previous checkpoint.
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time)

    # Load an existing RETRO model from a .nemo checkpoint, or initialize a new one from cfg.model.
    if cfg.model.get("restore_path", None):
        save_restore_connector = NLPSaveRestoreConnector()
        if os.path.isdir(cfg.model.restore_path):
            save_restore_connector.model_extracted_dir = cfg.model.restore_path

        model_cfg = MegatronRetroFinetuneModel.restore_from(
            restore_path=cfg.model.restore_path,
            trainer=trainer,
            return_config=True,
            save_restore_connector=save_restore_connector,
        )
        # Hydra interpolation does not work here because the interpolation key is lost when PTL saves hparams.
        model = load_from_nemo(
            MegatronRetroFinetuneModel,
            cfg,
            trainer,
            model_cfg,
            modify_config_fn=_modify_config,
            save_restore_connector=save_restore_connector,
        )
    else:
        model = MegatronRetroFinetuneModel(cfg.model, trainer=trainer)

    trainer.fit(model)


if __name__ == '__main__':
    main()