# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.utils.transcribe_utils import (
    compute_output_filename,
    prepare_audio_data,
    setup_model,
    transcribe_partial_audio,
    write_transcription,
)
from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
from nemo.core.config import hydra_runner
from nemo.utils import logging

"""
Transcribe audio files on a single CPU/GPU. Useful for transcription of moderate amounts of audio data.

# Arguments
  model_path: path to .nemo ASR checkpoint
  pretrained_name: name of pretrained ASR model (from NGC registry)
  audio_dir: path to directory with audio files
  dataset_manifest: path to dataset JSON manifest file (in NeMo format)

  compute_timestamps: Bool to request greedy time stamp information (if the model supports it)
  compute_langs: Bool to request language ID information (if the model supports it)

  (Optionally: You can limit the type of timestamp computations using below overrides)
  ctc_decoding.ctc_timestamp_type="all"  # (default all, can be [all, char, word])
  rnnt_decoding.rnnt_timestamp_type="all"  # (default all, can be [all, char, word])

  output_filename: Output filename where the transcriptions will be written
  batch_size: batch size during inference
  cuda: Optional int to enable or disable execution of the model on a specific CUDA device
  amp: Bool to decide if Automatic Mixed Precision should be used during inference
  audio_type: Str filetype of the audio. Supported = wav, flac, mp3
  overwrite_transcripts: Bool which when set allows repeated transcriptions to overwrite previous results.

  ctc_decoding: Decoding sub-config for CTC. Refer to documentation for specific values.
  rnnt_decoding: Decoding sub-config for RNNT. Refer to documentation for specific values.

# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
append_pred - optional. Allows you to add more than one prediction to an existing .json
pred_name_postfix - optional. The name you want to be written for the current model
Results are returned in a JSON manifest file.
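
Each line of `dataset_manifest` is a JSON object in NeMo manifest style. An illustrative entry
(the path and values are placeholders; fields other than the audio path may be optional
depending on the utility used):

    {"audio_filepath": "/path/to/audio.wav", "duration": 10.5, "text": "optional reference text"}
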
python transcribe_speech.py \
    model_path=null \
    pretrained_name=null \
    audio_dir="" \
    dataset_manifest="" \
    output_filename="" \
    batch_size=32 \
    compute_timestamps=False \
    compute_langs=False \
    cuda=0 \
    amp=True \
    append_pred=False \
    pred_name_postfix=""
"""


@dataclass
class ModelChangeConfig:

    # Sub-config for changes specific to the Conformer Encoder
    conformer: ConformerChangeConfig = ConformerChangeConfig()


@dataclass
class TranscriptionConfig:
    # Required configs
    model_path: Optional[str] = None  # Path to a .nemo file
    pretrained_name: Optional[str] = None  # Name of a pretrained model
    audio_dir: Optional[str] = None  # Path to a directory which contains audio files
    dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest
    channel_selector: Optional[int] = None  # Used to select a single channel from multi-channel audio files
    audio_key: str = 'audio_filepath'  # Used to override the default audio key in dataset_manifest
    eval_config_yaml: Optional[str] = None  # Path to a yaml file containing the evaluation config

    # General configs
    output_filename: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False  # If True, new predictions are added as an extra field to the existing .json manifest
    pred_name_postfix: Optional[str] = None  # If you need to use another model name, rather than the standard one
    random_seed: Optional[int] = None  # seed number going to be used in seed_everything()

    # Set to True to output greedy timestamp information (only supported models)
    compute_timestamps: bool = False

    # Set to True to output language ID information
    compute_langs: bool = False

    # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
    # device anyway, and do inference on CPU only if CUDA device is not found.
    # If `cuda` is a negative number, inference will be on CPU only.
    cuda: Optional[int] = None
    amp: bool = False
    audio_type: str = "wav"

    # Recompute model transcription, even if the output folder exists with scores.
    overwrite_transcripts: bool = True

    # Decoding strategy for CTC models
    ctc_decoding: CTCDecodingConfig = CTCDecodingConfig()

    # Decoding strategy for RNNT models
    rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)

    # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoders for Joint RNNT/CTC models
    decoder_type: Optional[str] = None

    # Use this for model-specific changes before transcription
    model_change: ModelChangeConfig = ModelChangeConfig()


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.random_seed:
        pl.seed_everything(cfg.random_seed)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    # Load augmentor from an external yaml file which contains eval info; could be extended to other features such as VAD or P&C
    augmentor = None
    if cfg.eval_config_yaml:
        eval_config = OmegaConf.load(cfg.eval_config_yaml)
        augmentor = eval_config.test_ds.get("augmentor")
        logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor} ")

    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]  # use 0th CUDA device
            accelerator = 'gpu'
        else:
            device = 1
            accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
    logging.info(f"Inference will be done on device: {device}")

    asr_model, model_name = setup_model(cfg, map_location)

    trainer = pl.Trainer(devices=device, accelerator=accelerator)
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()

    # collect additional transcription information
    return_hypotheses = True  # we will adjust this flag if the model does not support it
    compute_timestamps = cfg.compute_timestamps
    compute_langs = cfg.compute_langs

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy'):
        if cfg.decoder_type is not None:
            # TODO: Support compute_langs in CTC eventually
            if cfg.compute_langs and cfg.decoder_type == 'ctc':
                raise ValueError("CTC models do not support `compute_langs` at the moment")

            decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
            decoding_cfg.compute_timestamps = cfg.compute_timestamps  # both ctc and rnnt support it
            if 'preserve_alignments' in decoding_cfg:
                decoding_cfg.preserve_alignments = cfg.compute_timestamps
            if 'compute_langs' in decoding_cfg:
                decoding_cfg.compute_langs = cfg.compute_langs

            asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)

        # Check if ctc or rnnt model
        elif hasattr(asr_model, 'joint'):  # RNNT model
            cfg.rnnt_decoding.fused_batch_size = -1
            cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps
            cfg.rnnt_decoding.compute_langs = cfg.compute_langs
            if 'preserve_alignments' in cfg.rnnt_decoding:
                cfg.rnnt_decoding.preserve_alignments = cfg.compute_timestamps

            asr_model.change_decoding_strategy(cfg.rnnt_decoding)
        else:
            if cfg.compute_langs:
                raise ValueError("CTC models do not support `compute_langs` at the moment.")
            cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps

            asr_model.change_decoding_strategy(cfg.ctc_decoding)
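
    # Note: when `decoder_type` is set (e.g. `decoder_type=ctc` or `decoder_type=rnnt` for joint
    # RNNT/CTC models), the matching decoding sub-config above is applied explicitly; otherwise
    # the decoder is inferred from the presence of a `joint` module on the model.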

    # prepare audio filepaths and decide whether it's partial audio
    filepaths, partial_audio = prepare_audio_data(cfg)

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:
        # no-op context manager when AMP is unavailable or disabled
        @contextlib.contextmanager
        def autocast():
            yield

    # Compute output filename
    cfg = compute_output_filename(cfg, model_name)

    # if transcripts should not be overwritten, and they already exist, skip the re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts`"
            f" is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg

    # transcribe audio
    with autocast():
        with torch.no_grad():
            if partial_audio:
                if isinstance(asr_model, EncDecCTCModel):
                    transcriptions = transcribe_partial_audio(
                        asr_model=asr_model,
                        path2manifest=cfg.dataset_manifest,
                        batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers,
                        return_hypotheses=return_hypotheses,
                        channel_selector=cfg.channel_selector,
                        augmentor=augmentor,
                    )
                else:
                    logging.warning(
                        "RNNT models do not support transcribing partial audio for now. Transcribing full audio."
                    )
                    transcriptions = asr_model.transcribe(
                        paths2audio_files=filepaths,
                        batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers,
                        return_hypotheses=return_hypotheses,
                        channel_selector=cfg.channel_selector,
                        augmentor=augmentor,
                    )
            else:
                transcriptions = asr_model.transcribe(
                    paths2audio_files=filepaths,
                    batch_size=cfg.batch_size,
                    num_workers=cfg.num_workers,
                    return_hypotheses=return_hypotheses,
                    channel_selector=cfg.channel_selector,
                    augmentor=augmentor,
                )

    logging.info(f"Finished transcribing {len(filepaths)} files!")
    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    # if transcriptions form a tuple (from RNNT), extract just the "best" hypotheses
    if type(transcriptions) == tuple and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    # write audio transcriptions
    output_filename = write_transcription(
        transcriptions,
        cfg,
        model_name,
        filepaths=filepaths,
        compute_langs=compute_langs,
        compute_timestamps=compute_timestamps,
    )
    logging.info(f"Finished writing predictions to {output_filename}!")

    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
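

# Example override requesting word-level timestamps (illustrative; the model name, manifest path,
# and output path below are placeholders, not fixed defaults of this script):
#
#   python transcribe_speech.py \
#       pretrained_name="stt_en_conformer_ctc_large" \
#       dataset_manifest="/path/to/manifest.json" \
#       output_filename="/path/to/predictions.json" \
#       compute_timestamps=True \
#       ctc_decoding.ctc_timestamp_type="word" \
#       batch_size=16 \
#       cuda=0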