""" |
|
Script to compute evaluation metrics for an audio-to-audio model on a dataset described by a manifest file.
|
The manifest file must include the path to the input audio and the path to the target (ground truth) audio.
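
Each manifest line is a single JSON object. A minimal illustrative entry is shown below; the name of the
input-audio key depends on how `process_audio.py` is configured and is only assumed here, while the target
key defaults to `target_audio_filepath` (the optional `duration` field is also read when present):

    {"input_filepath": "/data/noisy/utt_001.wav", "target_audio_filepath": "/data/clean/utt_001.wav", "duration": 4.25}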
|
|
|
Note: This script depends on the `process_audio.py` script, and therefore both scripts should be
|
located in the same directory during execution. |
|
|
|
# Arguments |
|
|
|
<< All arguments of `process_audio.py` are inherited by this script, so please refer to `process_audio.py` |
|
for the full list of arguments >>
|
|
|
dataset_manifest: Required - path to dataset JSON manifest file (in NeMo format) |
|
output_dir: Optional - output directory where the processed audio will be saved |
|
metrics: Optional - list of metrics to evaluate. Defaults to [sdr,estoi] |
|
sample_rate: Optional - sample rate for loaded audio. Defaults to 16kHz. |
|
only_score_manifest: Optional - If set, processing will be skipped and it is assumed the processed audio is available in dataset_manifest |
|
|
|
# Usage |
|
|
|
## To process and score a dataset using a manifest file that contains the input audio (to be processed) and the target audio
|
|
|
python audio_to_audio_eval.py \ |
|
model_path=null \ |
|
pretrained_model=null \ |
|
dataset_manifest=<Mandatory: path to a dataset manifest file> \ |
|
output_dir=<Optional: Directory where processed audio will be saved> \ |
|
processed_channel_selector=<Optional: list of channels to select from the processed audio file> \ |
|
target_key=<Optional: key for the target audio in the dataset manifest. Default: target_audio_filepath> \ |
|
target_channel_selector=<Optional: list of channels to select from the target audio file> \ |
|
    metrics=<Optional: list of metrics to evaluate. Defaults to [sdr,estoi]> \
|
batch_size=32 \ |
|
amp=True |
|
|
|
## To score audio that has already been processed, using a manifest file that contains both the processed audio and the target audio
|
|
|
python audio_to_audio_eval.py \ |
|
dataset_manifest=<Mandatory: path to a dataset manifest file> \ |
|
    processed_key=<Optional: key for the processed audio in the dataset manifest. Default: processed_audio_filepath> \
|
processed_channel_selector=<Optional: list of channels to select from the processed audio file> \ |
|
target_key=<Optional: key for the target audio in the dataset manifest. Default: target_audio_filepath> \ |
|
target_channel_selector=<Optional: list of channels to select from the target audio file> \ |
|
    metrics=<Optional: list of metrics to evaluate. Defaults to [sdr,estoi]> \
|
batch_size=32 \ |
|
amp=True |
|
""" |
|
import json |
|
import os |
|
import tempfile |
|
from dataclasses import dataclass, field, is_dataclass |
|
from typing import List, Optional |
|
|
|
import process_audio |
|
import torch |
|
from omegaconf import OmegaConf, open_dict |
|
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality |
|
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio |
|
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility |
|
from tqdm import tqdm |
|
|
|
from nemo.collections.asr.data import audio_to_audio_dataset |
|
from nemo.collections.asr.metrics.audio import AudioMetricWrapper |
|
from nemo.collections.common.parts.preprocessing import manifest |
|
from nemo.core.config import hydra_runner |
|
from nemo.utils import logging |
|
|
|
|
|
@dataclass |
|
class AudioEvaluationConfig(process_audio.ProcessConfig): |
|
|
|
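    # Processed audio (model output): channel selection and manifest key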
processed_channel_selector: Optional[List] = None |
|
processed_key: str = 'processed_audio_filepath' |
|
|
|
|
|
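    # Target (ground truth) audio: base directory, channel selection, and manifest key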
target_dataset_dir: Optional[str] = None |
|
target_channel_selector: Optional[List] = None |
|
target_key: str = 'target_audio_filepath' |
|
|
|
|
|
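    # Sample rate used when loading audio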
sample_rate: int = 16000 |
|
|
|
|
|
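    # If True, skip processing and score the manifest as-is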
only_score_manifest: bool = False |
|
|
|
|
|
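    # Metrics to compute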
metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi']) |
|
|
|
|
|
def get_evaluation_dataloader(config): |
|
"""Prepare a dataloader for evaluation. |
|
""" |
|
dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config) |
|
|
|
return torch.utils.data.DataLoader( |
|
dataset=dataset, |
|
batch_size=config['batch_size'], |
|
collate_fn=dataset.collate_fn, |
|
drop_last=config.get('drop_last', False), |
|
shuffle=False, |
|
num_workers=config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)), |
|
pin_memory=True, |
|
) |
|
|
|
|
|
def get_metrics(cfg: AudioEvaluationConfig): |
|
"""Prepare a dictionary with metrics. |
|
""" |
|
available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq'] |
|
|
|
metrics = dict() |
|
for name in sorted(set(cfg.metrics)): |
|
name = name.lower() |
|
if name == 'sdr': |
|
metric = AudioMetricWrapper(metric=SignalDistortionRatio()) |
|
elif name == 'sisdr': |
|
metric = AudioMetricWrapper(metric=ScaleInvariantSignalDistortionRatio()) |
|
elif name == 'stoi': |
|
metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=False)) |
|
elif name == 'estoi': |
|
metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=True)) |
|
elif name == 'pesq': |
|
metric = AudioMetricWrapper(metric=PerceptualEvaluationSpeechQuality(fs=cfg.sample_rate, mode='wb')) |
|
else: |
|
raise ValueError(f'Unexpected metric: {name}. Currently available metrics: {available_metrics}') |
|
|
|
metrics[name] = metric |
|
|
|
return metrics |
|
|
|
|
|
@hydra_runner(config_name="AudioEvaluationConfig", schema=AudioEvaluationConfig) |
|
def main(cfg: AudioEvaluationConfig): |
|
torch.set_grad_enabled(False) |
|
|
|
if is_dataclass(cfg): |
|
cfg = OmegaConf.structured(cfg) |
|
|
|
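    # Ground-truth targets come from the manifest, so scoring a bare audio directory is not supported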
if cfg.audio_dir is not None: |
|
        raise RuntimeError(
            "Evaluation script requires ground truth audio to be passed via a manifest file. "
            "If a manifest file is available, provide it via the `dataset_manifest` argument."
        )
|
|
|
if not os.path.exists(cfg.dataset_manifest): |
|
        raise FileNotFoundError(f'The dataset manifest file could not be found at path: {cfg.dataset_manifest}')
|
|
|
if cfg.target_dataset_dir is None: |
|
|
|
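        # If not given, assume target audio paths are relative to the manifest's directory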
cfg.target_dataset_dir = os.path.dirname(cfg.dataset_manifest) |
|
elif not os.path.isdir(cfg.target_dataset_dir): |
|
        raise FileNotFoundError(f'Target dataset dir could not be found at path: {cfg.target_dataset_dir}')
|
|
|
|
|
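    # Instantiate the requested metrics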
metrics = get_metrics(cfg) |
|
|
|
|
|
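    # Run the processing script unless the manifest already contains processed audio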
if not cfg.only_score_manifest: |
|
|
|
process_cfg = process_audio.main(cfg) |
|
|
|
|
|
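        # Release GPU memory that was used during processing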
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
logging.info('Finished processing audio.') |
|
else: |
|
|
|
cfg.output_filename = cfg.dataset_manifest |
|
process_cfg = cfg |
|
|
|
|
|
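    # Build a temporary manifest that pairs each processed file with its target file, using absolute paths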
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
|
temporary_manifest_filepath = os.path.join(tmp_dir, 'manifest.json') |
|
|
|
num_files = 0 |
|
|
|
with open(process_cfg.output_filename, 'r') as f_processed, open( |
|
temporary_manifest_filepath, 'w', encoding='utf-8' |
|
) as f_tmp: |
|
for line_processed in f_processed: |
|
data_processed = json.loads(line_processed) |
|
|
|
if cfg.processed_key not in data_processed: |
|
raise ValueError( |
|
f'Processed key {cfg.processed_key} not found in manifest: {process_cfg.output_filename}.' |
|
) |
|
|
|
if cfg.target_key not in data_processed: |
|
raise ValueError( |
|
f'Target key {cfg.target_key} not found in manifest: {process_cfg.output_filename}.' |
|
) |
|
|
|
item = { |
|
'processed': manifest.get_full_path( |
|
audio_file=data_processed[cfg.processed_key], manifest_file=process_cfg.output_filename |
|
), |
|
'target': manifest.get_full_path( |
|
audio_file=data_processed[cfg.target_key], data_dir=cfg.target_dataset_dir |
|
), |
|
'duration': data_processed.get('duration'), |
|
} |
|
|
|
|
|
for key in ['processed', 'target']: |
|
if not os.path.isfile(item[key]): |
|
raise ValueError(f'File for key "{key}" not found at: {item[key]}.\nCurrent item: {item}') |
|
|
|
|
|
if item['target'] == item['processed']: |
|
logging.warning('Using the same file as processed and target: %s', item['target']) |
|
|
|
|
|
f_tmp.write(json.dumps(item) + '\n') |
|
|
|
num_files += 1 |
|
|
|
|
|
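        # Prepare a dataloader over the temporary manifest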
config = { |
|
'manifest_filepath': temporary_manifest_filepath, |
|
'sample_rate': cfg.sample_rate, |
|
'input_key': 'processed', |
|
'input_channel_selector': cfg.processed_channel_selector, |
|
'target_key': 'target', |
|
'target_channel_selector': cfg.target_channel_selector, |
|
'batch_size': min(cfg.batch_size, num_files), |
|
'num_workers': cfg.num_workers, |
|
} |
|
temporary_dataloader = get_evaluation_dataloader(config) |
|
|
|
|
|
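        # Accumulate metric statistics batch by batch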
for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'): |
|
processed_signal, processed_length, target_signal, target_length = eval_batch |
|
|
|
if not torch.equal(processed_length, target_length): |
|
                raise RuntimeError(f'Mismatch between processed and target lengths: {processed_length} vs {target_length}')
|
|
|
for name, metric in metrics.items(): |
|
metric.update(preds=processed_signal, target=target_signal, input_length=target_length) |
|
|
|
|
|
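    # Compute the final value of each metric over the whole dataset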
metrics_value = {name: metric.compute().item() for name, metric in metrics.items()} |
|
|
|
logging.info('Finished running evaluation.') |
|
|
|
|
|
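    # Log a summary of the evaluated data and metric values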
logging.info('Summary\n') |
|
logging.info('Data') |
|
    logging.info('\tmanifest: %s', process_cfg.output_filename)
|
logging.info('\ttarget_dataset_dir: %s', cfg.target_dataset_dir) |
|
logging.info('\tnum_files: %s', num_files) |
|
logging.info('Metrics') |
|
for name, value in metrics_value.items(): |
|
logging.info('\t%10s: \t%6.2f', name, value) |
|
|
|
|
|
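    # Return the metric values as part of the config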
with open_dict(cfg): |
|
cfg.metrics_value = metrics_value |
|
|
|
return cfg |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|