r""" |
|
Conversion script to convert PTL checkpoints into nemo checkpoint. |
|
Example to run this conversion script: |
|
python -m torch.distributed.launch --nproc_per_node=<tensor_model_parallel_size> * <pipeline_model_parallel_size> \ |
|
megatron_ckpt_to_nemo.py \ |
|
--checkpoint_folder <path_to_PTL_checkpoints_folder> \ |
|
--checkpoint_name <checkpoint_name> \ |
|
--nemo_file_path <path_to_output_nemo_file> \ |
|
--tensor_model_parallel_size <tensor_model_parallel_size> \ |
|
--pipeline_model_parallel_size <pipeline_model_parallel_size> |
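
For instance, a hypothetical single-node GPT run trained with 2-way tensor parallelism and
2-way pipeline parallelism (paths and checkpoint name below are illustrative only):

python -m torch.distributed.launch --nproc_per_node=4 \
    megatron_ckpt_to_nemo.py \
    --checkpoint_folder /results/megatron_gpt/checkpoints \
    --checkpoint_name 'megatron_gpt--val_loss=6.34-step=649-last.ckpt' \
    --nemo_file_path /results/megatron_gpt.nemo \
    --model_type gpt \
    --gpus_per_node 4 \
    --tensor_model_parallel_size 2 \
    --pipeline_model_parallel_size 2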
""" |
|
|
|
import os |
|
from argparse import ArgumentParser |
|
|
|
import torch |
|
from apex.transformer import parallel_state |
|
from pytorch_lightning.plugins.environments import TorchElasticEnvironment |
|
from pytorch_lightning.trainer.trainer import Trainer |
|
|
|
from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel |
|
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel |
|
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel |
|
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel |
|
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model |
|
from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel |
|
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector |
|
from nemo.utils import AppState, logging |
|
from nemo.utils.distributed import initialize_distributed |
|
from nemo.utils.model_utils import inject_model_parallel_rank |
|
|
|
|
|
def get_args(): |
|
    parser = ArgumentParser()
    parser.add_argument(
        "--checkpoint_folder",
        type=str,
        default=None,
        required=True,
        help="Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints",
    )
    parser.add_argument(
        "--checkpoint_name",
        type=str,
        default=None,
        required=True,
        help="Name of checkpoint to be used. Ex: megatron_gpt--val_loss=6.34-step=649-last.ckpt",
    )

    parser.add_argument(
        "--hparams_file",
        type=str,
        default=None,
        required=False,
        help="Path to the hparams config used for restoring. It is created during training and may need to be modified if the restore environment differs from the training environment. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
    )
    parser.add_argument("--nemo_file_path", type=str, default=None, required=True, help="Path to output .nemo file.")
parser.add_argument("--gpus_per_node", type=int, required=True, default=None) |
|
parser.add_argument("--tensor_model_parallel_size", type=int, required=True, default=None) |
|
parser.add_argument("--pipeline_model_parallel_size", type=int, required=True, default=None) |
|
parser.add_argument( |
|
"--pipeline_model_parallel_split_rank", |
|
type=int, |
|
required=False, |
|
default=None, |
|
help="If pipeline parallel size > 1, this is the rank at which the encoder ends and the decoder begins.", |
|
) |
|
parser.add_argument( |
|
"--model_type", type=str, required=True, default="gpt", choices=["gpt", "t5", "bert", "nmt", "bart", "retro"] |
|
) |
|
parser.add_argument("--local_rank", type=int, required=False, default=os.getenv('LOCAL_RANK', -1)) |
|
parser.add_argument("--bcp", action="store_true", help="Whether on BCP platform") |
|
|
|
args = parser.parse_args() |
|
return args |
|
|
|
|
|
def convert(local_rank, rank, world_size, args): |
|
|
|
    app_state = AppState()
    app_state.data_parallel_rank = 0
    num_nodes = world_size // args.gpus_per_node
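
    # The restored Megatron models need a PTL Trainer to attach to; on BCP the
    # TorchElasticEnvironment plugin lets the Trainer pick up rank/world-size
    # information from the torchelastic environment variables.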
    if args.bcp:
        trainer = Trainer(
            devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu', plugins=[TorchElasticEnvironment()]
        )
    else:
        trainer = Trainer(devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu')

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
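
    # Encoder-decoder models (T5/BART/NMT) split the pipeline between encoder and decoder
    # stages; if no split rank is given, default to an even encoder/decoder split.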
    if args.pipeline_model_parallel_size > 1 and args.model_type in ['t5', 'bart', 'nmt']:
        if args.pipeline_model_parallel_split_rank is not None:
            app_state.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_split_rank
        else:
            if args.pipeline_model_parallel_size % 2 != 0:
                raise ValueError(
                    f"Pipeline model parallel size {args.pipeline_model_parallel_size} must be even if split rank is not specified."
                )
            else:
                app_state.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_size // 2
    else:
        app_state.pipeline_model_parallel_split_rank = None

    app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size
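
    # Initialize the Megatron (Apex) model-parallel process groups so this process knows
    # its tensor-parallel and pipeline-parallel rank.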
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
        pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
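
    # inject_model_parallel_rank() rewrites the checkpoint path so each rank loads the shard
    # matching its own tensor-/pipeline-parallel rank (e.g. an mp_rank_* subdirectory when
    # model parallelism is used).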
    checkpoint_path = inject_model_parallel_rank(os.path.join(args.checkpoint_folder, args.checkpoint_name))

    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} for tp_rank: {app_state.tensor_model_parallel_rank} and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )
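
    # Restore the PTL checkpoint with the NeMo model class that matches --model_type.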
    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
        )
    elif args.model_type == 't5':
        model = MegatronT5Model.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'bart':
        model = MegatronBARTModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
        )
    elif args.model_type == 'nmt':
        model = MegatronNMTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    elif args.model_type == 'retro':
        model = MegatronRetrievalModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
        )
    model._save_restore_connector = NLPSaveRestoreConnector()
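
    # Wait for all ranks to finish loading their shards, then package the model into the
    # output .nemo file at the requested path.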
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(args.nemo_file_path)

    logging.info(f'NeMo model saved to: {args.nemo_file_path}')


if __name__ == '__main__':
    args = get_args()

    local_rank, rank, world_size = initialize_distributed(args)

    convert(local_rank, rank, world_size, args)