# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
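"""Port a .nemo NLP checkpoint saved with an older NeMo release so it loads under the current release.

For Megatron-BERT models trained on NeMo < 1.5, pass --megatron-legacy together with
--megatron-checkpoint pointing at the checkpoint produced by megatron_lm_ckpt_to_nemo.py.
The restored model is re-saved to the output path in the current .nemo format.
"""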
import argparse
import logging
import sys
import pytorch_lightning as pl
from omegaconf import OmegaConf, open_dict
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.core import ModelPT
from nemo.core.config import TrainerConfig


def get_args(argv):
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Update NLP models trained on previous NeMo versions to the current version",
    )
parser.add_argument("source", help="Source .nemo file")
parser.add_argument("out", help="Location to write result to")
parser.add_argument("--megatron-legacy", help="If the source model is megatron-bert trained on NeMo < 1.5")
parser.add_argument(
"--megatron-checkpoint",
type=str,
help="Path of the MegatronBert nemo checkpoint converted from MegatronLM using megatron_lm_ckpt_to_nemo.py file (Not NLP model checkpoint)",
)
parser.add_argument("--verbose", default=None, help="Verbose level for logging, numeric")
args = parser.parse_args(argv)
return args


def nemo_convert(argv):
    """Convert a .nemo model trained on a previous NeMo version into a .nemo file for the current version."""
    args = get_args(argv)
    loglevel = logging.INFO
    # Convert the log level name from the command line (e.g. --verbose DEBUG or
    # --verbose debug) to the corresponding numeric logging level.
    if args.verbose is not None:
        numeric_level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % args.verbose)
        loglevel = numeric_level
logger = logging.getLogger(__name__)
if logger.handlers:
for handler in logger.handlers:
logger.removeHandler(handler)
logging.basicConfig(level=loglevel, format='%(asctime)s [%(levelname)s] %(message)s')
logging.info("Logging level set to {}".format(loglevel))
"""Convert a .nemo saved model trained on previous versions of nemo into a nemo fie with current version."""
nemo_in = args.source
out = args.out
# Create a PL trainer object which is required for restoring Megatron models
cfg_trainer = TrainerConfig(
gpus=1,
accelerator="ddp",
num_nodes=1,
# Need to set the following two to False as ExpManager will take care of them differently.
logger=False,
enable_checkpointing=False,
)
    # Unpack the config dataclass into keyword arguments; passing it positionally would be
    # taken as the Trainer's first (logger) argument instead of configuring the trainer.
    trainer = pl.Trainer(**vars(cfg_trainer))
logging.info("Restoring NeMo model from '{}'".format(nemo_in))
try:
# If the megatron based NLP model was trained on NeMo < 1.5, then we need to update the lm_checkpoint on the model config
if args.megatron_legacy:
if args.megatron_checkpoint:
connector = NLPSaveRestoreConnector()
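                # First restore only the model config (return_config=True) so it can be patched
                # before the full model and its weights are loaded.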
model_cfg = ModelPT.restore_from(
restore_path=nemo_in, save_restore_connector=connector, trainer=trainer, return_config=True
)
OmegaConf.set_struct(model_cfg, True)
with open_dict(model_cfg):
model_cfg.language_model.lm_checkpoint = args.megatron_checkpoint
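                    # Assumption: legacy checkpoints use the old Megatron parameter layout and were
                    # trained without the fused masked-softmax / bias-GeLU kernels, so disable them.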
model_cfg['megatron_legacy'] = True
model_cfg['masked_softmax_fusion'] = False
model_cfg['bias_gelu_fusion'] = False
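                # Restore the full model, overriding the config packaged in the .nemo file with
                # the patched one above.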
model = ModelPT.restore_from(
restore_path=nemo_in,
save_restore_connector=connector,
trainer=trainer,
override_config_path=model_cfg,
)
            else:
                logging.error("Megatron checkpoint must be provided if --megatron-legacy is chosen")
                raise ValueError("--megatron-checkpoint is required when --megatron-legacy is set")
else:
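            # Non-legacy models only need the default restore path; re-saving below writes a
            # .nemo file in the current format.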
model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
logging.info("Model {} restored from '{}'".format(model.cfg.target, nemo_in))
# Save the model
model.save_to(out)
logging.info("Successfully converted to {}".format(out))
del model
except Exception as e:
        logging.error(
            "Failed to restore model from NeMo file: {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
raise e


if __name__ == '__main__':
    nemo_convert(sys.argv[1:])
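

# Example invocations (file names are illustrative):
#   python nlp_checkpoint_port.py old_model.nemo ported_model.nemo
#   python nlp_checkpoint_port.py old_bert.nemo ported_bert.nemo \
#       --megatron-legacy --megatron-checkpoint /path/to/converted_megatron_lm.nemo --verbose DEBUG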