""" |
|
This script provides an ASR+VAD inference pipeline, with the option to run only ASR or only VAD.
|
|
|
There are two types of input: the first is a manifest passed to `manifest_filepath`,

and the second is a directory of audio files passed to `audio_dir`, with the file extension specified by `audio_type`.
|
|
|
The input manifest must be a manifest JSON file, where each line is a dictionary in JSON format. The fields ["audio_filepath", "offset", "duration", "text"] are required. An example of a manifest file is:
|
```
{"audio_filepath": "/path/to/audio_file1", "offset": 0, "duration": 10000, "text": "a b c d e"}
{"audio_filepath": "/path/to/audio_file2", "offset": 0, "duration": 10000, "text": "f g h i j"}
```
|
|
|
To run the code with ASR+VAD default settings: |
|
|
|
```bash
python speech_to_text_with_vad.py \
    manifest_filepath=/PATH/TO/MANIFEST.json \
    vad_model=vad_multilingual_marblenet \
    asr_model=stt_en_conformer_ctc_large \
    vad_config=../conf/vad/vad_inference_postprocess.yaml
```
|
|
|
To use only ASR and disable VAD, set `vad_model=None` and `use_rttm=False`. |
|
|
|
To use only VAD, set `asr_model=None` and specify both `vad_model` and `vad_config`. |
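
For example, a VAD-only run with the same VAD model and config might look like the following (paths are illustrative):

```bash
python speech_to_text_with_vad.py \
    asr_model=None \
    vad_model=vad_multilingual_marblenet \
    vad_config=../conf/vad/vad_inference_postprocess.yaml \
    manifest_filepath=/PATH/TO/MANIFEST.json
```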
|
|
|
To enable profiling, set `profiling=True`, but this will significantly slow down the program. |
|
|
|
To enable or disable feature masking, set `use_rttm` to `True` or `False`.
|
|
|
To normalize features before masking, set `normalize=pre_norm`;

to apply masking before normalization, set `normalize=post_norm`.
|
|
|
To use a specific value for feature masking, set `feat_mask_val` to the desired value. |
|
The default is `feat_mask_val=None`, in which case -16.530 is used for `post_norm` and 0 is used for `pre_norm`.
|
|
|
See more options in the `InferenceConfig` class. |
|
""" |
|
|
|
|
|
import contextlib |
|
import json |
|
import os |
|
import time |
|
from dataclasses import dataclass, is_dataclass |
|
from pathlib import Path |
|
from typing import Callable, Optional |
|
|
|
import torch |
|
import yaml |
|
from omegaconf import DictConfig, OmegaConf |
|
from torch.profiler import ProfilerActivity, profile, record_function |
|
from tqdm import tqdm |
|
|
|
from nemo.collections.asr.data import feature_to_text_dataset |
|
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig |
|
from nemo.collections.asr.metrics.wer import CTCDecodingConfig, word_error_rate |
|
from nemo.collections.asr.models import ASRModel, EncDecClassificationModel |
|
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest |
|
from nemo.collections.asr.parts.utils.vad_utils import (
    generate_overlap_vad_seq,
    generate_vad_segment_table,
    get_vad_stream_status,
    init_vad_model,
)
|
from nemo.core.config import hydra_runner |
|
from nemo.utils import logging |
|
|
|
try: |
|
from torch.cuda.amp import autocast |
|
except ImportError: |
|
|
|
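    # Fallback: a no-op stand-in so `with autocast():` still works when torch.cuda.amp is unavailable.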
@contextlib.contextmanager |
|
def autocast(enabled=None): |
|
yield |
|
|
|
|
|
@dataclass |
|
class InferenceConfig: |
|
|
|
    asr_model: Optional[str] = None  # path to a local .nemo or .ckpt file, or the name of a pretrained NGC model
    vad_model: Optional[str] = None  # path to a local .nemo file, or the name of a pretrained NGC model
    vad_config: Optional[str] = None  # path to a YAML file with VAD inference and post-processing parameters
    manifest_filepath: Optional[str] = None  # path to the input JSON manifest
    audio_dir: Optional[str] = None  # directory of audio files to use when no manifest is provided
|
|
|
    use_rttm: bool = True  # whether to mask non-speech features using RTTM files
    feat_mask_val: Optional[float] = None  # value used for feature masking; None uses -16.530 for post_norm and 0 for pre_norm
    normalize: Optional[str] = "post_norm"  # "pre_norm" normalizes before masking; "post_norm" masks before normalizing
    normalize_type: str = "per_feature"  # type of feature normalization (e.g., "per_feature")
    use_pure_noise: bool = False  # set True when the input is pure noise, to report hallucinated output instead of WER
|
|
|
profiling: bool = False |
|
|
|
|
|
batch_size: int = 1 |
|
num_workers: int = 8 |
|
sample_rate: int = 16000 |
|
frame_unit_time_secs: float = 0.01 |
|
audio_type: str = "wav" |
|
|
|
|
|
output_dir: Optional[str] = None |
|
output_filename: Optional[str] = None |
|
pred_name_postfix: Optional[str] = None |
|
|
|
|
|
compute_langs: bool = False |
|
|
|
|
|
ctc_decoding: CTCDecodingConfig = CTCDecodingConfig() |
|
|
|
|
|
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1) |
|
|
|
|
|
@hydra_runner(config_name="InferenceConfig", schema=InferenceConfig) |
|
def main(cfg): |
|
|
|
if is_dataclass(cfg): |
|
cfg = OmegaConf.structured(cfg) |
|
|
|
if cfg.output_dir is None: |
|
cfg.output_dir = "./outputs" |
|
output_dir = Path(cfg.output_dir) |
|
output_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') |
|
|
|
|
|
if cfg.profiling: |
|
logging.info("Profiling enabled") |
|
profile_fn = profile |
|
record_fn = record_function |
|
else: |
|
logging.info("Profiling disabled") |
|
|
|
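        # No-op stand-ins so the profiling context managers below can be used unconditionally.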
@contextlib.contextmanager |
|
def profile_fn(*args, **kwargs): |
|
yield |
|
|
|
@contextlib.contextmanager |
|
def record_fn(*args, **kwargs): |
|
yield |
|
|
|
input_manifest_file = prepare_inference_manifest(cfg) |
|
|
|
if cfg.manifest_filepath is None: |
|
cfg.manifest_filepath = str(input_manifest_file) |
|
|
|
with profile_fn( |
|
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True |
|
) as prof: |
|
|
|
input_manifest_file = extract_audio_features(input_manifest_file, cfg, record_fn) |
|
|
|
if cfg.vad_model is not None: |
|
logging.info(f"Running VAD with model: {cfg.vad_model}") |
|
input_manifest_file = run_vad_inference(input_manifest_file, cfg, record_fn) |
|
|
|
if cfg.asr_model is not None: |
|
logging.info(f"Running ASR with model: {cfg.asr_model}") |
|
run_asr_inference(input_manifest_file, cfg, record_fn) |
|
|
|
if cfg.profiling: |
|
print("--------------------------------------------------------------------\n") |
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15)) |
|
print("--------------------------------------------------------------------\n") |
|
logging.info("Done.") |
|
|
|
|
|
def prepare_inference_manifest(cfg: DictConfig) -> str: |
|
|
|
if cfg.audio_dir is not None and cfg.manifest_filepath is None: |
|
manifest_data = [] |
|
for audio_file in Path(cfg.audio_dir).glob(f"**/*.{cfg.audio_type}"): |
|
item = {"audio_filepath": str(audio_file.absolute()), "duration": 1000000, "offset": 0} |
|
manifest_data.append(item) |
|
parent_dir = Path(cfg.audio_dir) |
|
else: |
|
manifest_data = read_manifest(cfg.manifest_filepath) |
|
parent_dir = Path(cfg.manifest_filepath).parent |
|
|
|
new_manifest_data = [] |
|
|
|
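    # Resolve relative audio paths against the manifest's (or audio_dir's) parent directory when possible.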
for item in manifest_data: |
|
audio_file = Path(item["audio_filepath"]) |
|
if len(str(audio_file)) < 255 and not audio_file.is_file() and not audio_file.is_absolute(): |
|
new_audio_file = parent_dir / audio_file |
|
if new_audio_file.is_file(): |
|
item["audio_filepath"] = str(new_audio_file.absolute()) |
|
else: |
|
item["audio_filepath"] = os.path.expanduser(str(audio_file)) |
|
else: |
|
item["audio_filepath"] = os.path.expanduser(str(audio_file)) |
|
item["label"] = "infer" |
|
item["text"] = "-" |
|
new_manifest_data.append(item) |
|
|
|
new_manifest_filepath = str(Path(cfg.output_dir) / Path("temp_manifest_input.json")) |
|
write_manifest(new_manifest_filepath, new_manifest_data) |
|
return new_manifest_filepath |
|
|
|
|
|
def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: Callable) -> str: |
|
file_list = [] |
|
manifest_data = [] |
|
out_dir = Path(cfg.output_dir) / Path("features") |
|
new_manifest_filepath = str(Path(cfg.output_dir) / Path("temp_manifest_input_feature.json")) |
|
|
|
if Path(new_manifest_filepath).is_file(): |
|
logging.info("Features already exist in output_dir, skipping feature extraction.") |
|
return new_manifest_filepath |
|
|
|
has_feat = False |
|
with open(manifest_filepath, 'r', encoding='utf-8') as fin: |
|
        for line in fin:
|
item = json.loads(line.strip()) |
|
manifest_data.append(item) |
|
file_list.append(Path(item['audio_filepath']).stem) |
|
if "feature_file" in item: |
|
has_feat = True |
|
if has_feat: |
|
logging.info("Features already exist in manifest, skipping feature extraction.") |
|
return manifest_filepath |
|
|
|
out_dir.mkdir(parents=True, exist_ok=True) |
|
torch.set_grad_enabled(False) |
|
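    # Reuse the preprocessor of a pretrained VAD model to compute the audio features.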
vad_model = EncDecClassificationModel.from_pretrained("vad_multilingual_marblenet") |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
vad_model = vad_model.to(device) |
|
vad_model.eval() |
|
vad_model.setup_test_data( |
|
test_data_config={ |
|
'batch_size': 1, |
|
'vad_stream': False, |
|
'sample_rate': cfg.sample_rate, |
|
'manifest_filepath': manifest_filepath, |
|
'labels': ['infer',], |
|
'num_workers': cfg.num_workers, |
|
'shuffle': False, |
|
} |
|
) |
|
|
|
logging.info(f"Extracting features on {len(file_list)} audio files...") |
|
with record_fn("feat_extract_loop"): |
|
for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))): |
|
test_batch = [x.to(vad_model.device) for x in test_batch] |
|
with autocast(): |
|
with record_fn("feat_extract_infer"): |
|
processed_signal, processed_signal_length = vad_model.preprocessor( |
|
input_signal=test_batch[0], length=test_batch[1], |
|
) |
|
with record_fn("feat_extract_other"): |
|
processed_signal = processed_signal.squeeze(0)[:, :processed_signal_length] |
|
processed_signal = processed_signal.cpu() |
|
outpath = os.path.join(out_dir, file_list[i] + ".pt") |
|
outpath = str(Path(outpath).absolute()) |
|
torch.save(processed_signal, outpath) |
|
manifest_data[i]["feature_file"] = outpath |
|
del test_batch |
|
|
|
logging.info(f"Features saved at: {out_dir}") |
|
write_manifest(new_manifest_filepath, manifest_data) |
|
return new_manifest_filepath |
|
|
|
|
|
def run_vad_inference(manifest_filepath: str, cfg: DictConfig, record_fn: Callable) -> str: |
|
logging.info("Start VAD inference pipeline...") |
|
vad_model = init_vad_model(cfg.vad_model) |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
vad_model = vad_model.to(device) |
|
vad_model.eval() |
|
|
|
vad_yaml = Path(cfg.vad_config) |
|
if not vad_yaml.is_file(): |
|
raise ValueError(f"VAD config file not found: {cfg.vad_config}") |
|
|
|
with vad_yaml.open("r") as fp: |
|
vad_cfg = yaml.safe_load(fp) |
|
vad_cfg = DictConfig(vad_cfg) |
|
|
|
test_data_config = { |
|
'vad_stream': True, |
|
'manifest_filepath': manifest_filepath, |
|
'labels': ['infer',], |
|
'num_workers': cfg.num_workers, |
|
'shuffle': False, |
|
'window_length_in_sec': vad_cfg.vad.parameters.window_length_in_sec, |
|
'shift_length_in_sec': vad_cfg.vad.parameters.shift_length_in_sec, |
|
} |
|
vad_model.setup_test_data(test_data_config=test_data_config, use_feat=True) |
|
|
|
pred_dir = Path(cfg.output_dir) / Path("vad_frame_pred") |
|
if pred_dir.is_dir(): |
|
logging.info(f"VAD frame-level prediction already exists: {pred_dir}, skipped") |
|
else: |
|
logging.info("Generating VAD frame-level prediction") |
|
pred_dir.mkdir(parents=True) |
|
t0 = time.time() |
|
pred_dir = generate_vad_frame_pred( |
|
vad_model=vad_model, |
|
window_length_in_sec=vad_cfg.vad.parameters.window_length_in_sec, |
|
shift_length_in_sec=vad_cfg.vad.parameters.shift_length_in_sec, |
|
manifest_vad_input=manifest_filepath, |
|
out_dir=str(pred_dir), |
|
use_feat=True, |
|
record_fn=record_fn, |
|
) |
|
t1 = time.time() |
|
logging.info(f"Time elapsed: {t1 - t0: .2f} seconds") |
|
logging.info( |
|
f"Finished generating VAD frame level prediction with window_length_in_sec={vad_cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={vad_cfg.vad.parameters.shift_length_in_sec}" |
|
) |
|
|
|
frame_length_in_sec = vad_cfg.vad.parameters.shift_length_in_sec |
|
|
|
if vad_cfg.vad.parameters.smoothing: |
|
|
|
|
|
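        # Generate predictions with overlapping input segments; a smoothing filter then decides the label for frames spanned by multiple segments.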
logging.info("Generating predictions with overlapping input segments") |
|
t0 = time.time() |
|
smoothing_pred_dir = generate_overlap_vad_seq( |
|
frame_pred_dir=pred_dir, |
|
smoothing_method=vad_cfg.vad.parameters.smoothing, |
|
overlap=vad_cfg.vad.parameters.overlap, |
|
window_length_in_sec=vad_cfg.vad.parameters.window_length_in_sec, |
|
shift_length_in_sec=vad_cfg.vad.parameters.shift_length_in_sec, |
|
num_workers=cfg.num_workers, |
|
out_dir=vad_cfg.smoothing_out_dir, |
|
) |
|
logging.info( |
|
f"Finish generating predictions with overlapping input segments with smoothing_method={vad_cfg.vad.parameters.smoothing} and overlap={vad_cfg.vad.parameters.overlap}" |
|
) |
|
t1 = time.time() |
|
logging.info(f"Time elapsed: {t1 - t0: .2f} seconds") |
|
pred_dir = smoothing_pred_dir |
|
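        # Smoothed predictions are produced at a fixed 10 ms frame resolution.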
frame_length_in_sec = 0.01 |
|
|
|
|
|
logging.info(f"Generating segment tables with postprocessing params: {vad_cfg.vad.parameters.postprocessing}") |
|
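    # Encode the post-processing parameters into the output directory name.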
segment_dir_name = "vad_rttm" |
|
for key, val in vad_cfg.vad.parameters.postprocessing.items(): |
|
if key == "use_rttm": |
|
continue |
|
segment_dir_name = segment_dir_name + "-" + str(key) + str(val) |
|
|
|
segment_dir = Path(cfg.output_dir) / Path(segment_dir_name) |
|
if segment_dir.is_dir(): |
|
logging.info(f"VAD speech segments already exists: {segment_dir}, skipped") |
|
else: |
|
segment_dir.mkdir(parents=True) |
|
t0 = time.time() |
|
vad_cfg.vad.parameters.postprocessing.use_rttm = True |
|
segment_dir = generate_vad_segment_table( |
|
vad_pred_dir=pred_dir, |
|
postprocessing_params=vad_cfg.vad.parameters.postprocessing, |
|
frame_length_in_sec=frame_length_in_sec, |
|
num_workers=cfg.num_workers, |
|
out_dir=segment_dir, |
|
) |
|
t1 = time.time() |
|
logging.info(f"Time elapsed: {t1 - t0: .2f} seconds") |
|
logging.info("Finished generating RTTM files from VAD predictions.") |
|
|
|
rttm_map = {} |
|
for filepath in Path(segment_dir).glob("*.rttm"): |
|
rttm_map[filepath.stem] = str(filepath.absolute()) |
|
|
|
manifest_data = read_manifest(manifest_filepath) |
|
for i in range(len(manifest_data)): |
|
key = Path(manifest_data[i]["audio_filepath"]).stem |
|
manifest_data[i]["rttm_file"] = rttm_map[key] |
|
|
|
new_manifest_filepath = str(Path(cfg.output_dir) / Path(f"temp_manifest_{segment_dir_name}.json")) |
|
write_manifest(new_manifest_filepath, manifest_data) |
|
return new_manifest_filepath |
|
|
|
|
|
def generate_vad_frame_pred( |
|
vad_model: EncDecClassificationModel, |
|
window_length_in_sec: float, |
|
shift_length_in_sec: float, |
|
manifest_vad_input: str, |
|
out_dir: str, |
|
use_feat: bool = False, |
|
record_fn: Callable = None, |
|
) -> str: |
|
""" |
|
Generate VAD frame level prediction and write to out_dir |
|
""" |
|
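    # Each input window spans `time_unit` frames; overlapping windows are trimmed below so that
    # consecutive chunks contribute disjoint frame-level predictions.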
time_unit = int(window_length_in_sec / shift_length_in_sec) |
|
trunc = int(time_unit / 2) |
|
trunc_l = time_unit - trunc |
|
all_len = 0 |
|
|
|
data = [] |
|
with open(manifest_vad_input, 'r', encoding='utf-8') as fin: |
|
        for line in fin:
|
file = json.loads(line)['audio_filepath'].split("/")[-1] |
|
data.append(file.split(".wav")[0]) |
|
logging.info(f"Inference on {len(data)} audio files/json lines!") |
|
|
|
status = get_vad_stream_status(data) |
|
|
|
with record_fn("vad_infer_loop"): |
|
for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))): |
|
test_batch = [x.to(vad_model.device) for x in test_batch] |
|
with autocast(): |
|
with record_fn("vad_infer_model"): |
|
if use_feat: |
|
log_probs = vad_model(processed_signal=test_batch[0], processed_signal_length=test_batch[1]) |
|
else: |
|
log_probs = vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1]) |
|
|
|
with record_fn("vad_infer_other"): |
|
probs = torch.softmax(log_probs, dim=-1) |
|
pred = probs[:, 1] |
|
|
|
if status[i] == 'start': |
|
to_save = pred[:-trunc] |
|
elif status[i] == 'next': |
|
to_save = pred[trunc:-trunc_l] |
|
elif status[i] == 'end': |
|
to_save = pred[trunc_l:] |
|
else: |
|
to_save = pred |
|
|
|
all_len += len(to_save) |
|
outpath = os.path.join(out_dir, data[i] + ".frame") |
|
with open(outpath, "a", encoding='utf-8') as fout: |
|
for f in range(len(to_save)): |
|
fout.write('{0:0.4f}\n'.format(to_save[f])) |
|
|
|
del test_batch |
|
if status[i] == 'end' or status[i] == 'single': |
|
all_len = 0 |
|
return out_dir |
|
|
|
|
|
def init_asr_model(model_path: str) -> ASRModel: |
|
if model_path.endswith('.nemo'): |
|
logging.info(f"Using local ASR model from {model_path}") |
|
asr_model = ASRModel.restore_from(restore_path=model_path) |
|
    elif model_path.endswith('.ckpt'):
        logging.info(f"Using local ASR checkpoint from {model_path}")
        asr_model = ASRModel.load_from_checkpoint(checkpoint_path=model_path)
|
else: |
|
logging.info(f"Using NGC ASR model {model_path}") |
|
asr_model = ASRModel.from_pretrained(model_name=model_path) |
|
return asr_model |
|
|
|
|
|
def run_asr_inference(manifest_filepath, cfg, record_fn) -> str: |
|
logging.info("Start ASR inference pipeline...") |
|
asr_model = init_asr_model(cfg.asr_model) |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
asr_model = asr_model.to(device) |
|
asr_model.eval() |
|
|
|
|
|
decode_function = None |
|
if hasattr(asr_model, 'change_decoding_strategy'): |
|
|
|
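        # Models with a `joint` module are Transducers (RNNT); all others are treated as CTC.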
if hasattr(asr_model, 'joint'): |
|
cfg.rnnt_decoding.fused_batch_size = -1 |
|
cfg.rnnt_decoding.compute_langs = cfg.compute_langs |
|
asr_model.change_decoding_strategy(cfg.rnnt_decoding) |
|
decode_function = asr_model.decoding.rnnt_decoder_predictions_tensor |
|
else: |
|
asr_model.change_decoding_strategy(cfg.ctc_decoding) |
|
decode_function = asr_model.decoding.ctc_decoder_predictions_tensor |
|
else: |
|
raise ValueError(f"Only support CTC or RNNT models that have `change_decoding_strategy()` implemented.") |
|
|
|
|
|
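    # Derive an output filename that encodes the model and the masking/normalization settings.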
if cfg.output_filename is None: |
|
|
|
if cfg.pred_name_postfix is not None: |
|
cfg.output_filename = cfg.manifest_filepath.replace('.json', f'_{cfg.pred_name_postfix}.json') |
|
else: |
|
tag = f"{cfg.normalize}_{cfg.normalize_type}" |
|
if cfg.use_rttm: |
|
vad_tag = Path(manifest_filepath).stem |
|
vad_tag = vad_tag[len("temp_manifest_vad_rttm_") :] |
|
tag += f"-mask{cfg.feat_mask_val}-{vad_tag}" |
|
cfg.output_filename = cfg.manifest_filepath.replace('.json', f'-{Path(cfg.asr_model).stem}-{tag}.json') |
|
cfg.output_filename = Path(cfg.output_dir) / Path(cfg.output_filename).name |
|
|
|
logging.info("Setting up dataloader for ASR...") |
|
data_config = { |
|
"manifest_filepath": manifest_filepath, |
|
"normalize": cfg.normalize, |
|
"normalize_type": cfg.normalize_type, |
|
"use_rttm": cfg.use_rttm, |
|
"feat_mask_val": cfg.feat_mask_val, |
|
"frame_unit_time_secs": cfg.frame_unit_time_secs, |
|
} |
|
logging.info(f"use_rttm = {cfg.use_rttm}") |
|
if hasattr(asr_model, "tokenizer"): |
|
dataset = feature_to_text_dataset.get_bpe_dataset(config=data_config, tokenizer=asr_model.tokenizer) |
|
else: |
|
data_config["labels"] = asr_model.decoder.vocabulary |
|
dataset = feature_to_text_dataset.get_char_dataset(config=data_config) |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
dataset=dataset, |
|
batch_size=cfg.batch_size, |
|
collate_fn=dataset._collate_fn, |
|
drop_last=False, |
|
shuffle=False, |
|
num_workers=cfg.get('num_workers', 0), |
|
pin_memory=cfg.get('pin_memory', False), |
|
) |
|
|
|
logging.info("Start transcribing...") |
|
hypotheses = [] |
|
all_hypotheses = [] |
|
t0 = time.time() |
|
with autocast(): |
|
with torch.no_grad(): |
|
with record_fn("asr_infer_loop"): |
|
for test_batch in tqdm(dataloader, desc="Transcribing"): |
|
with record_fn("asr_infer_model"): |
|
outputs = asr_model.forward( |
|
processed_signal=test_batch[0].to(device), |
|
processed_signal_length=test_batch[1].to(device), |
|
) |
|
with record_fn("asr_infer_other"): |
|
logits, logits_len = outputs[0], outputs[1] |
|
|
|
current_hypotheses, all_hyp = decode_function(logits, logits_len, return_hypotheses=False,) |
|
|
|
hypotheses += current_hypotheses |
|
if all_hyp is not None: |
|
all_hypotheses += all_hyp |
|
else: |
|
all_hypotheses += current_hypotheses |
|
|
|
del logits |
|
del test_batch |
|
t1 = time.time() |
|
logging.info(f"Time elapsed: {t1 - t0: .2f} seconds") |
|
|
|
logging.info("Finished transcribing.") |
|
|
|
input_manifest_data = read_manifest(manifest_filepath) |
|
manifest_data = read_manifest(cfg.manifest_filepath) |
|
groundtruth = [] |
|
for i in range(len(manifest_data)): |
|
groundtruth.append(manifest_data[i]["text"]) |
|
manifest_data[i]["pred_text"] = hypotheses[i] |
|
manifest_data[i]["feature_file"] = input_manifest_data[i]["feature_file"] |
|
if "rttm_file" in input_manifest_data[i]: |
|
manifest_data[i]["feature_file"] = input_manifest_data[i]["feature_file"] |
|
|
|
write_manifest(cfg.output_filename, manifest_data) |
|
|
|
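    # With pure-noise input there is no reference text; report the size of any hallucinated output instead of WER.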
if cfg.use_pure_noise: |
|
hypotheses = " ".join(hypotheses) |
|
words = hypotheses.split() |
|
chars = "".join(words) |
|
logging.info("-----------------------------------------") |
|
logging.info(f"Number of hallucinated characters={len(chars)}") |
|
logging.info(f"Number of hallucinated words={len(words)}") |
|
logging.info(f"Concatenated predictions: {hypotheses}") |
|
logging.info("-----------------------------------------") |
|
else: |
|
wer_score = word_error_rate(hypotheses=hypotheses, references=groundtruth) |
|
logging.info("-----------------------------------------") |
|
logging.info(f"WER={wer_score*100:.2f}") |
|
logging.info("-----------------------------------------") |
|
|
|
logging.info(f"ASR output saved at {cfg.output_filename}") |
|
return cfg.output_filename |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|