|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""This script converts old Jasper/QuartzNet models from NeMo 0.11.* to NeMo v1.0.0* |
|
""" |
|
|
|
import argparse |
|
|
|
import torch |
|
from omegaconf import DictConfig |
|
from ruamel.yaml import YAML |
|
|
|
import nemo.collections.asr as nemo_asr |
|
from nemo.utils import logging |
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the legacy checkpoint converter.

    Returns:
        argparse.ArgumentParser: parser exposing four required path options
        (``--config_path``, ``--encoder_ckpt``, ``--decoder_ckpt``,
        ``--output_path``) and the optional ``--model_type`` selector, which
        defaults to ``'asr'``.
    """
    cli = argparse.ArgumentParser(description="Converts old Jasper/QuartzNet models to NeMo v1.0beta")

    # All path options share the same shape (required, no default); register
    # them from a table so the flag/help pairs are easy to scan.
    required_paths = (
        ("--config_path", "Path to model config (NeMo v1.0beta)"),
        ("--encoder_ckpt", "Encoder checkpoint path"),
        ("--decoder_ckpt", "Decoder checkpoint path"),
        ("--output_path", "Output checkpoint path (should be .nemo)"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, default=None, required=True, help=description)

    cli.add_argument(
        "--model_type",
        type=str,
        default='asr',
        choices=['asr', 'speech_label', 'speaker'],
        help="Type of decoder used by the model.",
    )

    return cli
|
|
|
|
|
def main(config_path, encoder_ckpt, decoder_ckpt, output_path, model_type):
    """Port legacy encoder/decoder checkpoints into a single saved NeMo model.

    Loads the YAML config, instantiates the model class selected by
    ``model_type`` from the config's ``model`` section, restores the encoder
    and decoder weights from the given checkpoints, and saves the result.

    Args:
        config_path: Path to the YAML model config. The file must contain a
            top-level ``model`` key.
        encoder_ckpt: Path to the encoder state dict, read with ``torch.load``.
        decoder_ckpt: Path to the decoder state dict, read with ``torch.load``.
        output_path: Destination handed to ``model.save_to``.
        model_type: ``'asr'`` builds an ``EncDecCTCModel``, ``'speech_label'``
            builds an ``EncDecClassificationModel``; any other value builds an
            ``EncDecSpeakerLabelModel`` (the CLI restricts this to
            ``'speaker'``).

    Raises:
        KeyError: If the config has no ``model`` section.
    """
    yaml = YAML(typ='safe')
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(config_path, encoding='utf-8') as f:
        params = yaml.load(f)

    model_cfg = DictConfig(params['model'])

    if model_type == 'asr':
        logging.info("Creating ASR NeMo 1.0 model")
        model = nemo_asr.models.EncDecCTCModel(cfg=model_cfg)
    elif model_type == 'speech_label':
        logging.info("Creating speech label NeMo 1.0 model")
        model = nemo_asr.models.EncDecClassificationModel(cfg=model_cfg)
    else:
        # Fall-through kept for backward compatibility: every other value is
        # treated as a speaker recognition model.
        logging.info("Creating Speaker Recognition NeMo 1.0 model")
        model = nemo_asr.models.EncDecSpeakerLabelModel(cfg=model_cfg)

    # map_location='cpu' lets checkpoints that were saved from CUDA tensors be
    # deserialized on machines without a GPU; load_state_dict then copies the
    # values into the model's existing parameters.
    # NOTE: torch.load unpickles the file -- only convert trusted checkpoints.
    model.encoder.load_state_dict(torch.load(encoder_ckpt, map_location='cpu'))
    model.decoder.load_state_dict(torch.load(decoder_ckpt, map_location='cpu'))
    logging.info("Succesfully ported old checkpoint")

    model.save_to(output_path)
    logging.info("new model saved at {}".format(output_path))
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line and forward each option to
    # main() under its matching parameter name.
    cli_args = get_parser().parse_args()
    main(
        config_path=cli_args.config_path,
        encoder_ckpt=cli_args.encoder_ckpt,
        decoder_ckpt=cli_args.decoder_ckpt,
        output_path=cli_args.output_path,
        model_type=cli_args.model_type,
    )
|
|