File size: 8,021 Bytes
7934b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Conversion script to convert PTL checkpoints into a .nemo checkpoint.
Example to run this conversion script (launch one process per model-parallel rank,
i.e. tensor_model_parallel_size * pipeline_model_parallel_size processes):
    python -m torch.distributed.launch --nproc_per_node=<tensor_model_parallel_size * pipeline_model_parallel_size> \
    megatron_ckpt_to_nemo.py \
    --checkpoint_folder <path_to_PTL_checkpoints_folder> \
    --checkpoint_name <checkpoint_name> \
    --nemo_file_path <path_to_output_nemo_file> \
    --model_type <gpt|t5|bert|nmt|bart|retro> \
    --gpus_per_node <gpus_per_node> \
    --tensor_model_parallel_size <tensor_model_parallel_size> \
    --pipeline_model_parallel_size <pipeline_model_parallel_size>
"""
import os
from argparse import ArgumentParser
import torch
from apex.transformer import parallel_state
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.trainer import Trainer
from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.utils import AppState, logging
from nemo.utils.distributed import initialize_distributed
from nemo.utils.model_utils import inject_model_parallel_rank
def get_args():
    """Parse the command line for the PTL -> .nemo checkpoint conversion.

    Options are declared as an ordered table of ``(flag, add_argument kwargs)``
    pairs and registered one by one, so ``--help`` lists them in table order.

    Returns:
        argparse.Namespace: the parsed options. ``local_rank`` defaults to the
        ``LOCAL_RANK`` environment variable, or -1 when that variable is unset.
    """
    option_table = (
        (
            "--checkpoint_folder",
            dict(
                type=str,
                default=None,
                required=True,
                help="Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints",
            ),
        ),
        (
            "--checkpoint_name",
            dict(
                type=str,
                default=None,
                required=True,
                help="Name of checkpoint to be used. Ex: megatron_gpt--val_loss=6.34-step=649-last.ckpt",
            ),
        ),
        (
            "--hparams_file",
            dict(
                type=str,
                default=None,
                required=False,
                help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
            ),
        ),
        ("--nemo_file_path", dict(type=str, default=None, required=True, help="Path to output .nemo file.")),
        # Parallel layout used when restoring the checkpoint.
        ("--gpus_per_node", dict(type=int, required=True, default=None)),
        ("--tensor_model_parallel_size", dict(type=int, required=True, default=None)),
        ("--pipeline_model_parallel_size", dict(type=int, required=True, default=None)),
        (
            "--pipeline_model_parallel_split_rank",
            dict(
                type=int,
                required=False,
                default=None,
                help="If pipeline parallel size > 1, this is the rank at which the encoder ends and the decoder begins.",
            ),
        ),
        # NOTE(review): the "gpt" default is never used because the flag is required.
        (
            "--model_type",
            dict(type=str, required=True, default="gpt", choices=["gpt", "t5", "bert", "nmt", "bart", "retro"]),
        ),
        # Evaluated per call; a string taken from the environment is converted by type=int.
        ("--local_rank", dict(type=int, required=False, default=os.getenv('LOCAL_RANK', -1))),
        ("--bcp", dict(action="store_true", help="Whether on BCP platform")),
    )

    cli_parser = ArgumentParser()
    for flag, spec in option_table:
        cli_parser.add_argument(flag, **spec)
    return cli_parser.parse_args()
# Encoder-decoder model types: only these need a pipeline split rank.
_ENCODER_DECODER_MODEL_TYPES = ('t5', 'bart', 'nmt')
# Every model type convert() can load; kept in sync with the --model_type choices.
_SUPPORTED_MODEL_TYPES = ('gpt', 'bert', 't5', 'bart', 'nmt', 'retro')


def _resolve_split_rank(model_type, pipeline_model_parallel_size, pipeline_model_parallel_split_rank):
    """Return the pipeline split rank to use, or None when no split applies.

    Only encoder-decoder models (T5, BART, NMT) with more than one pipeline stage
    get a split rank. An explicitly requested split rank wins. Otherwise the split
    defaults to the middle of the pipeline, which assumes the encoder and decoder
    have the same number of layers.

    Raises:
        ValueError: if a default split rank is needed but the pipeline size is odd.
    """
    if pipeline_model_parallel_size <= 1 or model_type not in _ENCODER_DECODER_MODEL_TYPES:
        return None
    if pipeline_model_parallel_split_rank is not None:
        return pipeline_model_parallel_split_rank
    if pipeline_model_parallel_size % 2 != 0:
        raise ValueError(
            f"Pipeline model parallel size {pipeline_model_parallel_size} must be even if split rank is not specified."
        )
    return pipeline_model_parallel_size // 2


def _load_model(args, checkpoint_path, trainer):
    """Restore the Megatron model class selected by ``args.model_type`` from ``checkpoint_path``."""
    # Built lazily so the model classes are only looked up when a model is actually loaded.
    model_classes = {
        'gpt': MegatronGPTModel,
        'bert': MegatronBertModel,
        't5': MegatronT5Model,
        'bart': MegatronBARTModel,
        'nmt': MegatronNMTModel,
        'retro': MegatronRetrievalModel,
    }
    model_cls = model_classes[args.model_type]
    return model_cls.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)


def convert(local_rank, rank, world_size, args):
    """Restore a model-parallel PTL checkpoint and save it as a .nemo file.

    Each process sets up the model-parallel state and loads the checkpoint for
    its own tensor/pipeline rank. It then waits for all ranks if torch.distributed
    is initialized, and calls ``save_to``.

    Args:
        local_rank: rank of this process on its node (only logged).
        rank: global rank of this process (only logged).
        world_size: total number of processes; divided by ``args.gpus_per_node``
            to obtain the Trainer's node count.
        args: parsed command-line options (see ``get_args``).

    Raises:
        ValueError: if ``args.model_type`` is not supported, or if an
            encoder-decoder model is pipeline-parallel with an odd pipeline size
            and no explicit split rank.
    """
    # Validate cheaply before building the Trainer or the model-parallel groups.
    if args.model_type not in _SUPPORTED_MODEL_TYPES:
        raise ValueError(
            f"Unsupported model_type {args.model_type!r}; expected one of {', '.join(_SUPPORTED_MODEL_TYPES)}."
        )
    split_rank = _resolve_split_rank(
        args.model_type, args.pipeline_model_parallel_size, args.pipeline_model_parallel_split_rank
    )

    app_state = AppState()
    app_state.data_parallel_rank = 0
    num_nodes = world_size // args.gpus_per_node

    trainer_kwargs = dict(devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu')
    if args.bcp:
        trainer_kwargs['plugins'] = [TorchElasticEnvironment()]
    trainer = Trainer(**trainer_kwargs)

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
    app_state.pipeline_model_parallel_split_rank = split_rank
    app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
        pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
    )
    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    # Resolve the checkpoint path for this process's model-parallel rank.
    checkpoint_path = inject_model_parallel_rank(os.path.join(args.checkpoint_folder, args.checkpoint_name))
    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} for tp_rank: {app_state.tensor_model_parallel_rank} and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )

    model = _load_model(args, checkpoint_path, trainer)
    model._save_restore_connector = NLPSaveRestoreConnector()

    # Wait until every rank has finished loading before any of them starts saving.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)
    logging.info(f'NeMo model saved to: {args.nemo_file_path}')
if __name__ == '__main__':
    # Script entry point: parse the CLI, set up distributed state, then convert.
    cli_args = get_args()
    dist_local_rank, dist_rank, dist_world_size = initialize_distributed(cli_args)
    convert(dist_local_rank, dist_rank, dist_world_size, cli_args)
|