# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Script for calibrating a pretrained ASR model for quantization """ from argparse import ArgumentParser import torch from omegaconf import open_dict from nemo.collections.asr.models import EncDecCTCModel from nemo.utils import logging try: from pytorch_quantization import calib from pytorch_quantization import nn as quant_nn from pytorch_quantization import quant_modules from pytorch_quantization.tensor_quant import QuantDescriptor except ImportError: raise ImportError( "pytorch-quantization is not installed. Install from " "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." ) try: from torch.cuda.amp import autocast except ImportError: from contextlib import contextmanager @contextmanager def autocast(enabled=None): yield can_gpu = torch.cuda.is_available() def main(): parser = ArgumentParser() parser.add_argument( "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'", ) parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data") parser.add_argument("--batch_size", type=int, default=256) parser.add_argument( "--dont_normalize_text", default=False, action='store_false', help="Turn off trasnscript normalization. Recommended for non-English.", ) parser.add_argument('--num_calib_batch', default=1, type=int, help="Number of batches for calibration.") parser.add_argument('--calibrator', type=str, choices=["max", "histogram"], default="max") parser.add_argument('--percentile', nargs='+', type=float, default=[99.9, 99.99, 99.999, 99.9999]) parser.add_argument("--amp", action="store_true", help="Use AMP in calibration.") parser.set_defaults(amp=False) args = parser.parse_args() torch.set_grad_enabled(False) # Initialize quantization quant_desc_input = QuantDescriptor(calib_method=args.calibrator) quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input) quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input) quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input) if args.asr_model.endswith('.nemo'): logging.info(f"Using local ASR model from {args.asr_model}") asr_model_cfg = EncDecCTCModel.restore_from(restore_path=args.asr_model, return_config=True) with open_dict(asr_model_cfg): asr_model_cfg.encoder.quantize = True asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model, override_config_path=asr_model_cfg) else: logging.info(f"Using NGC cloud ASR model {args.asr_model}") asr_model_cfg = EncDecCTCModel.from_pretrained(model_name=args.asr_model, return_config=True) with open_dict(asr_model_cfg): asr_model_cfg.encoder.quantize = True asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model, override_config_path=asr_model_cfg) asr_model.setup_test_data( test_data_config={ 'sample_rate': 16000, 'manifest_filepath': args.dataset, 'labels': asr_model.decoder.vocabulary, 'batch_size': args.batch_size, 'normalize_transcripts': args.dont_normalize_text, 'shuffle': True, } ) asr_model.preprocessor.featurizer.dither = 0.0 asr_model.preprocessor.featurizer.pad_to = 0 if can_gpu: asr_model = asr_model.cuda() asr_model.eval() # Enable calibrators for name, module in asr_model.named_modules(): if isinstance(module, quant_nn.TensorQuantizer): if module._calibrator is not None: module.disable_quant() module.enable_calib() else: module.disable() for i, test_batch in enumerate(asr_model.test_dataloader()): if can_gpu: test_batch = [x.cuda() for x in test_batch] if args.amp: with autocast(): _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1]) else: _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1]) if i >= args.num_calib_batch: break # Save calibrated model(s) model_name = args.asr_model.replace(".nemo", "") if args.asr_model.endswith(".nemo") else args.asr_model if not args.calibrator == "histogram": compute_amax(asr_model, method="max") asr_model.save_to(F"{model_name}-max-{args.num_calib_batch*args.batch_size}.nemo") else: for percentile in args.percentile: print(F"{percentile} percentile calibration") compute_amax(asr_model, method="percentile") asr_model.save_to(F"{model_name}-percentile-{percentile}-{args.num_calib_batch*args.batch_size}.nemo") for method in ["mse", "entropy"]: print(F"{method} calibration") compute_amax(asr_model, method=method) asr_model.save_to(F"{model_name}-{method}-{args.num_calib_batch*args.batch_size}.nemo") def compute_amax(model, **kwargs): for name, module in model.named_modules(): if isinstance(module, quant_nn.TensorQuantizer): if module._calibrator is not None: if isinstance(module._calibrator, calib.MaxCalibrator): module.load_calib_amax() else: module.load_calib_amax(**kwargs) print(F"{name:40}: {module}") if can_gpu: model.cuda() if __name__ == '__main__': main() # noqa pylint: disable=no-value-for-parameter