# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file provides the ASR+VAD inference pipeline, with the option to perform only ASR or VAD alone.
There are two types of input: a manifest passed to `manifest_filepath`,
or a directory of audio files passed to `audio_dir` together with `audio_type`.
The input manifest must be a manifest JSON file, where each line is a Python dictionary. The fields ["audio_filepath", "offset", "duration", "text"] are required. An example of a manifest file is:
```
{"audio_filepath": "/path/to/audio_file1", "offset": 0, "duration": 10000, "text": "a b c d e"}
{"audio_filepath": "/path/to/audio_file2", "offset": 0, "duration": 10000, "text": "f g h i j"}
```
To run the code with ASR+VAD default settings:
```bash
python speech_to_text_with_vad.py \
manifest_filepath=/PATH/TO/MANIFEST.json \
vad_model=vad_multilingual_marblenet \
asr_model=stt_en_conformer_ctc_large \
vad_config=../conf/vad/vad_inference_postprocess.yaml
```
To use only ASR and disable VAD, set `vad_model=None` and `use_rttm=False`.
To use only VAD, set `asr_model=None` and specify both `vad_model` and `vad_config` (see the example below).
To enable profiling, set `profiling=True`; note that this will significantly slow down the program.
To enable or disable feature masking, set `use_rttm` to `True` or `False`.
To normalize features before masking, set `normalize=pre_norm`;
to mask before normalization, set `normalize=post_norm`.
To use a specific value for feature masking, set `feat_mask_val` to the desired value.
Default is `feat_mask_val=None`, in which case -16.530 is used for `post_norm` and 0 is used for `pre_norm`.
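For example, a VAD-only run might look like the following (a minimal sketch mirroring the default command above;
depending on the Hydra/OmegaConf version, you may need `null` instead of `None` to unset a model):
```bash
python speech_to_text_with_vad.py \
    manifest_filepath=/PATH/TO/MANIFEST.json \
    asr_model=None \
    vad_model=vad_multilingual_marblenet \
    vad_config=../conf/vad/vad_inference_postprocess.yaml
```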
See more options in the `InferenceConfig` class.
"""
import contextlib
import json
import os
import time
from dataclasses import dataclass, is_dataclass
from pathlib import Path
from typing import Callable, Optional
import torch
import yaml
from omegaconf import DictConfig, OmegaConf
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import tqdm
from nemo.collections.asr.data import feature_to_text_dataset
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import CTCDecodingConfig, word_error_rate
from nemo.collections.asr.models import ASRModel, EncDecClassificationModel
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
from nemo.collections.asr.parts.utils.vad_utils import (
extract_audio_features,
generate_overlap_vad_seq,
generate_vad_segment_table,
get_vad_stream_status,
init_vad_model,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
try:
from torch.cuda.amp import autocast
except ImportError:
@contextlib.contextmanager
def autocast(enabled=None):
yield
@dataclass
class InferenceConfig:
# Required configs
asr_model: Optional[str] = None # Path to a .nemo file or a pretrained NeMo model on NGC
vad_model: Optional[str] = None # Path to a .nemo file or a pretrained NeMo model on NGC
vad_config: Optional[str] = None # Path to a yaml file containing VAD post-processing configs
manifest_filepath: Optional[str] = None # Path to dataset's JSON manifest
audio_dir: Optional[str] = None
use_rttm: bool = True # whether to use RTTM
feat_mask_val: Optional[float] = None # value used to mask features based on RTTM, set None to use defaults
normalize: Optional[
str
] = "post_norm" # whether and where to normalize feature, choices=[None, `pre_norm`, `post_norm`]
normalize_type: str = "per_feature" # how to determine mean and std used for normalization
use_pure_noise: bool = False # whether input is pure noise or not.
profiling: bool = False # whether to enable pytorch profiling
# General configs
    batch_size: int = 1  # batch size for ASR. Feature extraction and VAD only support a single sample per batch.
num_workers: int = 8
sample_rate: int = 16000
frame_unit_time_secs: float = 0.01 # unit time per frame in seconds, equal to `window_stride` in ASR configs.
audio_type: str = "wav"
# Output settings, no need to change
output_dir: Optional[str] = None # will be automatically set by the program
output_filename: Optional[str] = None # will be automatically set by the program
    pred_name_postfix: Optional[str] = None  # If you need to use another model name, rather than the standard one.
# Set to True to output language ID information
compute_langs: bool = False
# Decoding strategy for CTC models
ctc_decoding: CTCDecodingConfig = CTCDecodingConfig()
# Decoding strategy for RNNT models
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)
@hydra_runner(config_name="InferenceConfig", schema=InferenceConfig)
def main(cfg):
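    """Run the ASR+VAD inference pipeline: prepare the input manifest, extract features, then run VAD and/or ASR as configured."""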
if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)
if cfg.output_dir is None:
cfg.output_dir = "./outputs"
output_dir = Path(cfg.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
    # setup profiling, note that profiling will significantly increase the total runtime
if cfg.profiling:
logging.info("Profiling enabled")
profile_fn = profile
record_fn = record_function
else:
logging.info("Profiling disabled")
@contextlib.contextmanager
def profile_fn(*args, **kwargs):
yield
@contextlib.contextmanager
def record_fn(*args, **kwargs):
yield
input_manifest_file = prepare_inference_manifest(cfg)
if cfg.manifest_filepath is None:
cfg.manifest_filepath = str(input_manifest_file)
with profile_fn(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
) as prof:
input_manifest_file = extract_audio_features(input_manifest_file, cfg, record_fn)
if cfg.vad_model is not None:
logging.info(f"Running VAD with model: {cfg.vad_model}")
input_manifest_file = run_vad_inference(input_manifest_file, cfg, record_fn)
if cfg.asr_model is not None:
logging.info(f"Running ASR with model: {cfg.asr_model}")
run_asr_inference(input_manifest_file, cfg, record_fn)
if cfg.profiling:
print("--------------------------------------------------------------------\n")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
print("--------------------------------------------------------------------\n")
logging.info("Done.")
def prepare_inference_manifest(cfg: DictConfig) -> str:
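    """
    Build a temporary inference manifest from either `manifest_filepath` or `audio_dir`.
    Relative audio paths are resolved against the manifest's parent directory, and dummy
    `label`/`text` fields required by the dataloaders are added to each entry.
    """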
if cfg.audio_dir is not None and cfg.manifest_filepath is None:
manifest_data = []
for audio_file in Path(cfg.audio_dir).glob(f"**/*.{cfg.audio_type}"):
item = {"audio_filepath": str(audio_file.absolute()), "duration": 1000000, "offset": 0}
manifest_data.append(item)
parent_dir = Path(cfg.audio_dir)
else:
manifest_data = read_manifest(cfg.manifest_filepath)
parent_dir = Path(cfg.manifest_filepath).parent
new_manifest_data = []
for item in manifest_data:
audio_file = Path(item["audio_filepath"])
if len(str(audio_file)) < 255 and not audio_file.is_file() and not audio_file.is_absolute():
new_audio_file = parent_dir / audio_file
if new_audio_file.is_file():
item["audio_filepath"] = str(new_audio_file.absolute())
else:
item["audio_filepath"] = os.path.expanduser(str(audio_file))
else:
item["audio_filepath"] = os.path.expanduser(str(audio_file))
item["label"] = "infer"
item["text"] = "-"
new_manifest_data.append(item)
new_manifest_filepath = str(Path(cfg.output_dir) / Path("temp_manifest_input.json"))
write_manifest(new_manifest_filepath, new_manifest_data)
return new_manifest_filepath
def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: Callable) -> str:
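    """
    Extract audio features using the preprocessor of a pretrained VAD model and save them as .pt files,
    then write a new manifest recording each sample's `feature_file`.
    Extraction is skipped if features are already present in the manifest or the output directory.
    """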
file_list = []
manifest_data = []
out_dir = Path(cfg.output_dir) / Path("features")
new_manifest_filepath = str(Path(cfg.output_dir) / Path("temp_manifest_input_feature.json"))
if Path(new_manifest_filepath).is_file():
logging.info("Features already exist in output_dir, skipping feature extraction.")
return new_manifest_filepath
has_feat = False
with open(manifest_filepath, 'r', encoding='utf-8') as fin:
for line in fin.readlines():
item = json.loads(line.strip())
manifest_data.append(item)
file_list.append(Path(item['audio_filepath']).stem)
if "feature_file" in item:
has_feat = True
if has_feat:
logging.info("Features already exist in manifest, skipping feature extraction.")
return manifest_filepath
out_dir.mkdir(parents=True, exist_ok=True)
torch.set_grad_enabled(False)
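    # The pretrained VAD model is loaded here only for its audio preprocessor, which produces the features.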
vad_model = EncDecClassificationModel.from_pretrained("vad_multilingual_marblenet")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vad_model = vad_model.to(device)
vad_model.eval()
vad_model.setup_test_data(
test_data_config={
'batch_size': 1,
'vad_stream': False,
'sample_rate': cfg.sample_rate,
'manifest_filepath': manifest_filepath,
'labels': ['infer',],
'num_workers': cfg.num_workers,
'shuffle': False,
}
)
logging.info(f"Extracting features on {len(file_list)} audio files...")
with record_fn("feat_extract_loop"):
for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))):
test_batch = [x.to(vad_model.device) for x in test_batch]
with autocast():
with record_fn("feat_extract_infer"):
processed_signal, processed_signal_length = vad_model.preprocessor(
input_signal=test_batch[0], length=test_batch[1],
)
with record_fn("feat_extract_other"):
processed_signal = processed_signal.squeeze(0)[:, :processed_signal_length]
processed_signal = processed_signal.cpu()
outpath = os.path.join(out_dir, file_list[i] + ".pt")
outpath = str(Path(outpath).absolute())
torch.save(processed_signal, outpath)
manifest_data[i]["feature_file"] = outpath
del test_batch
logging.info(f"Features saved at: {out_dir}")
write_manifest(new_manifest_filepath, manifest_data)
return new_manifest_filepath
def run_vad_inference(manifest_filepath: str, cfg: DictConfig, record_fn: Callable) -> str:
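    """
    Run frame-level VAD on the pre-extracted features, optionally smooth overlapping predictions,
    post-process the frame predictions into RTTM speech segments, and return a new manifest
    with an `rttm_file` entry for each sample.
    """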
logging.info("Start VAD inference pipeline...")
vad_model = init_vad_model(cfg.vad_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vad_model = vad_model.to(device)
vad_model.eval()
vad_yaml = Path(cfg.vad_config)
if not vad_yaml.is_file():
raise ValueError(f"VAD config file not found: {cfg.vad_config}")
with vad_yaml.open("r") as fp:
vad_cfg = yaml.safe_load(fp)
vad_cfg = DictConfig(vad_cfg)
test_data_config = {
'vad_stream': True,
'manifest_filepath': manifest_filepath,
'labels': ['infer',],
'num_workers': cfg.num_workers,
'shuffle': False,
'window_length_in_sec': vad_cfg.vad.parameters.window_length_in_sec,
'shift_length_in_sec': vad_cfg.vad.parameters.shift_length_in_sec,
}
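    # use_feat=True makes the VAD dataloader consume the pre-extracted features instead of raw audio.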
vad_model.setup_test_data(test_data_config=test_data_config, use_feat=True)
pred_dir = Path(cfg.output_dir) / Path("vad_frame_pred")
if pred_dir.is_dir():
logging.info(f"VAD frame-level prediction already exists: {pred_dir}, skipped")
else:
logging.info("Generating VAD frame-level prediction")
pred_dir.mkdir(parents=True)
t0 = time.time()
pred_dir = generate_vad_frame_pred(
vad_model=vad_model,
window_length_in_sec=vad_cfg.vad.parameters.window_length_in_sec,
shift_length_in_sec=vad_cfg.vad.parameters.shift_length_in_sec,
manifest_vad_input=manifest_filepath,
out_dir=str(pred_dir),
use_feat=True,
record_fn=record_fn,
)
t1 = time.time()
logging.info(f"Time elapsed: {t1 - t0: .2f} seconds")
logging.info(
f"Finished generating VAD frame level prediction with window_length_in_sec={vad_cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={vad_cfg.vad.parameters.shift_length_in_sec}"
)
frame_length_in_sec = vad_cfg.vad.parameters.shift_length_in_sec
# overlap smoothing filter
if vad_cfg.vad.parameters.smoothing:
# Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments.
# smoothing_method would be either in majority vote (median) or average (mean)
logging.info("Generating predictions with overlapping input segments")
t0 = time.time()
smoothing_pred_dir = generate_overlap_vad_seq(
frame_pred_dir=pred_dir,
smoothing_method=vad_cfg.vad.parameters.smoothing,
overlap=vad_cfg.vad.parameters.overlap,
window_length_in_sec=vad_cfg.vad.parameters.window_length_in_sec,
shift_length_in_sec=vad_cfg.vad.parameters.shift_length_in_sec,
num_workers=cfg.num_workers,
out_dir=vad_cfg.smoothing_out_dir,
)
        logging.info(
            f"Finished generating predictions with overlapping input segments with smoothing_method={vad_cfg.vad.parameters.smoothing} and overlap={vad_cfg.vad.parameters.overlap}"
        )
t1 = time.time()
logging.info(f"Time elapsed: {t1 - t0: .2f} seconds")
pred_dir = smoothing_pred_dir
frame_length_in_sec = 0.01
# Turn frame-wise prediction into speech intervals
logging.info(f"Generating segment tables with postprocessing params: {vad_cfg.vad.parameters.postprocessing}")
segment_dir_name = "vad_rttm"
for key, val in vad_cfg.vad.parameters.postprocessing.items():
if key == "use_rttm":
continue
segment_dir_name = segment_dir_name + "-" + str(key) + str(val)
segment_dir = Path(cfg.output_dir) / Path(segment_dir_name)
if segment_dir.is_dir():
logging.info(f"VAD speech segments already exists: {segment_dir}, skipped")
else:
segment_dir.mkdir(parents=True)
t0 = time.time()
vad_cfg.vad.parameters.postprocessing.use_rttm = True
segment_dir = generate_vad_segment_table(
vad_pred_dir=pred_dir,
postprocessing_params=vad_cfg.vad.parameters.postprocessing,
frame_length_in_sec=frame_length_in_sec,
num_workers=cfg.num_workers,
out_dir=segment_dir,
)
t1 = time.time()
logging.info(f"Time elapsed: {t1 - t0: .2f} seconds")
logging.info("Finished generating RTTM files from VAD predictions.")
rttm_map = {}
for filepath in Path(segment_dir).glob("*.rttm"):
rttm_map[filepath.stem] = str(filepath.absolute())
manifest_data = read_manifest(manifest_filepath)
for i in range(len(manifest_data)):
key = Path(manifest_data[i]["audio_filepath"]).stem
manifest_data[i]["rttm_file"] = rttm_map[key]
new_manifest_filepath = str(Path(cfg.output_dir) / Path(f"temp_manifest_{segment_dir_name}.json"))
write_manifest(new_manifest_filepath, manifest_data)
return new_manifest_filepath
def generate_vad_frame_pred(
vad_model: EncDecClassificationModel,
window_length_in_sec: float,
shift_length_in_sec: float,
manifest_vad_input: str,
out_dir: str,
use_feat: bool = False,
record_fn: Callable = None,
) -> str:
"""
Generate VAD frame level prediction and write to out_dir
"""
time_unit = int(window_length_in_sec / shift_length_in_sec)
trunc = int(time_unit / 2)
trunc_l = time_unit - trunc
all_len = 0
data = []
with open(manifest_vad_input, 'r', encoding='utf-8') as fin:
for line in fin.readlines():
file = json.loads(line)['audio_filepath'].split("/")[-1]
data.append(file.split(".wav")[0])
logging.info(f"Inference on {len(data)} audio files/json lines!")
status = get_vad_stream_status(data)
with record_fn("vad_infer_loop"):
for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))):
test_batch = [x.to(vad_model.device) for x in test_batch]
with autocast():
with record_fn("vad_infer_model"):
if use_feat:
log_probs = vad_model(processed_signal=test_batch[0], processed_signal_length=test_batch[1])
else:
log_probs = vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
with record_fn("vad_infer_other"):
probs = torch.softmax(log_probs, dim=-1)
pred = probs[:, 1]
if status[i] == 'start':
to_save = pred[:-trunc]
elif status[i] == 'next':
to_save = pred[trunc:-trunc_l]
elif status[i] == 'end':
to_save = pred[trunc_l:]
else:
to_save = pred
all_len += len(to_save)
outpath = os.path.join(out_dir, data[i] + ".frame")
with open(outpath, "a", encoding='utf-8') as fout:
for f in range(len(to_save)):
fout.write('{0:0.4f}\n'.format(to_save[f]))
del test_batch
if status[i] == 'end' or status[i] == 'single':
all_len = 0
return out_dir
def init_asr_model(model_path: str) -> ASRModel:
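    """Load an ASR model from a local .nemo/.ckpt file, or download a pretrained model from NGC by name."""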
if model_path.endswith('.nemo'):
logging.info(f"Using local ASR model from {model_path}")
asr_model = ASRModel.restore_from(restore_path=model_path)
elif model_path.endswith('.ckpt'):
asr_model = ASRModel.load_from_checkpoint(checkpoint_path=model_path)
else:
logging.info(f"Using NGC ASR model {model_path}")
asr_model = ASRModel.from_pretrained(model_name=model_path)
return asr_model
def run_asr_inference(manifest_filepath, cfg, record_fn) -> str:
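    """
    Transcribe the pre-extracted (and optionally RTTM-masked) features with the ASR model,
    write predictions to the output manifest, and report WER (or hallucination statistics for pure-noise input).
    """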
logging.info("Start ASR inference pipeline...")
asr_model = init_asr_model(cfg.asr_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
asr_model = asr_model.to(device)
asr_model.eval()
# Setup decoding strategy
decode_function = None
if hasattr(asr_model, 'change_decoding_strategy'):
# Check if ctc or rnnt model
if hasattr(asr_model, 'joint'): # RNNT model
cfg.rnnt_decoding.fused_batch_size = -1
cfg.rnnt_decoding.compute_langs = cfg.compute_langs
asr_model.change_decoding_strategy(cfg.rnnt_decoding)
decode_function = asr_model.decoding.rnnt_decoder_predictions_tensor
else:
asr_model.change_decoding_strategy(cfg.ctc_decoding)
decode_function = asr_model.decoding.ctc_decoder_predictions_tensor
else:
raise ValueError(f"Only support CTC or RNNT models that have `change_decoding_strategy()` implemented.")
# Compute output filename
if cfg.output_filename is None:
# create default output filename
if cfg.pred_name_postfix is not None:
cfg.output_filename = cfg.manifest_filepath.replace('.json', f'_{cfg.pred_name_postfix}.json')
else:
tag = f"{cfg.normalize}_{cfg.normalize_type}"
if cfg.use_rttm:
vad_tag = Path(manifest_filepath).stem
vad_tag = vad_tag[len("temp_manifest_vad_rttm_") :]
tag += f"-mask{cfg.feat_mask_val}-{vad_tag}"
cfg.output_filename = cfg.manifest_filepath.replace('.json', f'-{Path(cfg.asr_model).stem}-{tag}.json')
cfg.output_filename = Path(cfg.output_dir) / Path(cfg.output_filename).name
logging.info("Setting up dataloader for ASR...")
data_config = {
"manifest_filepath": manifest_filepath,
"normalize": cfg.normalize,
"normalize_type": cfg.normalize_type,
"use_rttm": cfg.use_rttm,
"feat_mask_val": cfg.feat_mask_val,
"frame_unit_time_secs": cfg.frame_unit_time_secs,
}
logging.info(f"use_rttm = {cfg.use_rttm}")
if hasattr(asr_model, "tokenizer"):
dataset = feature_to_text_dataset.get_bpe_dataset(config=data_config, tokenizer=asr_model.tokenizer)
else:
data_config["labels"] = asr_model.decoder.vocabulary
dataset = feature_to_text_dataset.get_char_dataset(config=data_config)
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=cfg.batch_size,
collate_fn=dataset._collate_fn,
drop_last=False,
shuffle=False,
num_workers=cfg.get('num_workers', 0),
pin_memory=cfg.get('pin_memory', False),
)
logging.info("Start transcribing...")
hypotheses = []
all_hypotheses = []
t0 = time.time()
with autocast():
with torch.no_grad():
with record_fn("asr_infer_loop"):
for test_batch in tqdm(dataloader, desc="Transcribing"):
with record_fn("asr_infer_model"):
outputs = asr_model.forward(
processed_signal=test_batch[0].to(device),
processed_signal_length=test_batch[1].to(device),
)
with record_fn("asr_infer_other"):
logits, logits_len = outputs[0], outputs[1]
current_hypotheses, all_hyp = decode_function(logits, logits_len, return_hypotheses=False,)
hypotheses += current_hypotheses
if all_hyp is not None:
all_hypotheses += all_hyp
else:
all_hypotheses += current_hypotheses
del logits
del test_batch
t1 = time.time()
logging.info(f"Time elapsed: {t1 - t0: .2f} seconds")
logging.info("Finished transcribing.")
# Save output to manifest
input_manifest_data = read_manifest(manifest_filepath)
manifest_data = read_manifest(cfg.manifest_filepath)
groundtruth = []
for i in range(len(manifest_data)):
groundtruth.append(manifest_data[i]["text"])
manifest_data[i]["pred_text"] = hypotheses[i]
manifest_data[i]["feature_file"] = input_manifest_data[i]["feature_file"]
if "rttm_file" in input_manifest_data[i]:
manifest_data[i]["feature_file"] = input_manifest_data[i]["feature_file"]
write_manifest(cfg.output_filename, manifest_data)
if cfg.use_pure_noise:
hypotheses = " ".join(hypotheses)
words = hypotheses.split()
chars = "".join(words)
logging.info("-----------------------------------------")
logging.info(f"Number of hallucinated characters={len(chars)}")
logging.info(f"Number of hallucinated words={len(words)}")
logging.info(f"Concatenated predictions: {hypotheses}")
logging.info("-----------------------------------------")
else:
wer_score = word_error_rate(hypotheses=hypotheses, references=groundtruth)
logging.info("-----------------------------------------")
logging.info(f"WER={wer_score*100:.2f}")
logging.info("-----------------------------------------")
logging.info(f"ASR output saved at {cfg.output_filename}")
return cfg.output_filename
if __name__ == "__main__":
main()