import argparse
import logging
import sys
from dataclasses import asdict

import pytorch_lightning as pl
from omegaconf import OmegaConf, open_dict

from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.core import ModelPT
from nemo.core.config import TrainerConfig
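
# Example usage (the script file name below is illustrative; substitute this file's actual name):
#   python nlp_checkpoint_port.py old_model.nemo updated_model.nemo
#   python nlp_checkpoint_port.py old_model.nemo updated_model.nemo \
#       --megatron-legacy --megatron-checkpoint /path/to/converted_megatron_bert.nemo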
def get_args(argv):
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Update NLP models trained on previous versions to the current version",
    )
    parser.add_argument("source", help="Source .nemo file")
    parser.add_argument("out", help="Location to write result to")
    parser.add_argument(
        "--megatron-legacy",
        action="store_true",
        help="Set if the source model is megatron-bert trained on NeMo < 1.5",
    )
    parser.add_argument(
        "--megatron-checkpoint",
        type=str,
        help="Path of the MegatronBert nemo checkpoint converted from MegatronLM using megatron_lm_ckpt_to_nemo.py (not an NLP model checkpoint)",
    )
    parser.add_argument("--verbose", default=None, help="Logging level name, e.g. DEBUG, INFO, WARNING")
    args = parser.parse_args(argv)
    return args


def nemo_convert(argv):
    """Convert a .nemo model saved with a previous NeMo version into a .nemo file for the current version."""
    args = get_args(argv)
    loglevel = logging.INFO
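    # Map the optional --verbose level name (e.g. "DEBUG") to its numeric logging constant.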
    if args.verbose is not None:
        numeric_level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % args.verbose)
        loglevel = numeric_level
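    # Remove any handlers already attached to this module's logger (e.g. by previously imported
    # libraries) to avoid duplicate output, then configure the root logger via basicConfig.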
    logger = logging.getLogger(__name__)
    if logger.handlers:
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
    logging.basicConfig(level=loglevel, format='%(asctime)s [%(levelname)s] %(message)s')
    logging.info("Logging level set to {}".format(loglevel))

    nemo_in = args.source
    out = args.out
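
    # A minimal single-GPU trainer: it is only needed so restore_from() has a trainer to attach
    # the restored model to; no training is performed during the conversion.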
    cfg_trainer = TrainerConfig(
        gpus=1,
        accelerator="ddp",
        num_nodes=1,
        # Logging and checkpointing are disabled; they are not needed for a convert-and-save pass.
        logger=False,
        enable_checkpointing=False,
    )
    # Unpack the config dataclass into keyword arguments for the Lightning Trainer.
    trainer = pl.Trainer(**asdict(cfg_trainer))

    logging.info("Restoring NeMo model from '{}'".format(nemo_in))
    try:
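        # Legacy Megatron-BERT models need their config patched before the weights are restored:
        # point the language model at the converted Megatron-LM checkpoint and disable the fused
        # kernel options (masked_softmax_fusion, bias_gelu_fusion) so the config matches the
        # legacy checkpoint, then restore the weights with the overridden config.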
        if args.megatron_legacy:
            if args.megatron_checkpoint:
                connector = NLPSaveRestoreConnector()
                model_cfg = ModelPT.restore_from(
                    restore_path=nemo_in, save_restore_connector=connector, trainer=trainer, return_config=True
                )
                OmegaConf.set_struct(model_cfg, True)
                with open_dict(model_cfg):
                    model_cfg.language_model.lm_checkpoint = args.megatron_checkpoint
                    model_cfg['megatron_legacy'] = True
                    model_cfg['masked_softmax_fusion'] = False
                    model_cfg['bias_gelu_fusion'] = False
                model = ModelPT.restore_from(
                    restore_path=nemo_in,
                    save_restore_connector=connector,
                    trainer=trainer,
                    override_config_path=model_cfg,
                )
            else:
                logging.error("Megatron checkpoint must be provided if megatron legacy is chosen")
                raise ValueError("--megatron-checkpoint is required when --megatron-legacy is set")
        else:
            model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
        logging.info("Model {} restored from '{}'".format(model.cfg.target, nemo_in))
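        # Re-saving the restored model writes it out with the currently installed NeMo version,
        # which is what performs the actual upgrade of the checkpoint format.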
        model.save_to(out)
        logging.info("Successfully converted to {}".format(out))

        del model
    except Exception as e:
        logging.error(
            "Failed to restore model from NeMo file: {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
        raise e


if __name__ == '__main__':
    nemo_convert(sys.argv[1:])