# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Script to compute evaluation metrics for an audio-to-audio model on a dataset described by a manifest file.
The manifest file must include the path to the input audio and the path to the target (ground truth) audio.

Note: This script depends on the `process_audio.py` script, and therefore both scripts should be
located in the same directory during execution.

# Arguments

<< All arguments of `process_audio.py` are inherited by this script, so please refer to `process_audio.py`
for the full list of arguments >>

    dataset_manifest: Required - path to dataset JSON manifest file (in NeMo format)
    output_dir: Optional - output directory where the processed audio will be saved
    metrics: Optional - list of metrics to evaluate (available: sdr, sisdr, stoi, estoi, pesq). Defaults to [sdr,estoi]
    sample_rate: Optional - sample rate for loaded audio. Defaults to 16kHz.
    only_score_manifest: Optional - If set, processing is skipped and the processed audio is assumed to be already available in dataset_manifest

# Usage

## To process the input audio from a manifest file with a model and score the result against the target audio

python audio_to_audio_eval.py \
    model_path=null \
    pretrained_model=null \
    dataset_manifest=<Mandatory: path to a dataset manifest file> \
    output_dir=<Optional: Directory where processed audio will be saved> \
    processed_channel_selector=<Optional: list of channels to select from the processed audio file> \
    target_key=<Optional: key for the target audio in the dataset manifest. Default: target_audio_filepath> \
    target_channel_selector=<Optional: list of channels to select from the target audio file> \
    metrics=<Optional: list of metrics to evaluate. Defaults to [sdr,estoi]> \
    batch_size=32 \
    amp=True

## To score a manifest file which has been previously processed and contains both processed audio and target audio

python audio_to_audio_eval.py \
    dataset_manifest=<Mandatory: path to a dataset manifest file> \
    processed_key=<Optional: key for the processed audio in the dataset manifest. Default: processed_audio_filepath> \
    processed_channel_selector=<Optional: list of channels to select from the processed audio file> \
    target_key=<Optional: key for the target audio in the dataset manifest. Default: target_audio_filepath> \
    target_channel_selector=<Optional: list of channels to select from the target audio file> \
    metrics=<Optional: list of metrics to evaluate. Defaults to [sdr,estoi]> \
    batch_size=32 \
    amp=True
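
For illustration, when scoring a previously processed manifest (second example above), each line of the
manifest could look like the following (hypothetical paths, default keys of this script; the optional
"duration" field is carried over when present):

    {"processed_audio_filepath": "processed/utt_001.wav", "target_audio_filepath": "target/utt_001.wav", "duration": 3.15}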
"""
import json
import os
import tempfile
from dataclasses import dataclass, field, is_dataclass
from typing import List, Optional

import process_audio
import torch
from omegaconf import OmegaConf, open_dict
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
from tqdm import tqdm

from nemo.collections.asr.data import audio_to_audio_dataset
from nemo.collections.asr.metrics.audio import AudioMetricWrapper
from nemo.collections.common.parts.preprocessing import manifest
from nemo.core.config import hydra_runner
from nemo.utils import logging


@dataclass
class AudioEvaluationConfig(process_audio.ProcessConfig):
    # Processed audio config
    processed_channel_selector: Optional[List] = None
    processed_key: str = 'processed_audio_filepath'

    # Target audio configs
    target_dataset_dir: Optional[str] = None  # If not provided, defaults to dirname(cfg.dataset_manifest)
    target_channel_selector: Optional[List] = None
    target_key: str = 'target_audio_filepath'

    # Sample rate for audio evaluation
    sample_rate: int = 16000

    # Score an existing manifest without running processing
    only_score_manifest: bool = False

    # Metrics to calculate
    metrics: List[str] = field(default_factory=lambda: ['sdr', 'estoi'])


def get_evaluation_dataloader(config):
    """Prepare a dataloader for evaluation.
    """
    dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config=config)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config['batch_size'],
        collate_fn=dataset.collate_fn,
        drop_last=config.get('drop_last', False),
        shuffle=False,
        num_workers=config.get('num_workers', min(config['batch_size'], os.cpu_count() - 1)),
        pin_memory=True,
    )


def get_metrics(cfg: AudioEvaluationConfig):
    """Prepare a dictionary with metrics.
    """
    available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq']

    metrics = dict()
    for name in sorted(set(cfg.metrics)):
        name = name.lower()
        if name == 'sdr':
            metric = AudioMetricWrapper(metric=SignalDistortionRatio())
        elif name == 'sisdr':
            metric = AudioMetricWrapper(metric=ScaleInvariantSignalDistortionRatio())
        elif name == 'stoi':
            metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=False))
        elif name == 'estoi':
            metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=True))
        elif name == 'pesq':
            metric = AudioMetricWrapper(metric=PerceptualEvaluationSpeechQuality(fs=cfg.sample_rate, mode='wb'))
        else:
            raise ValueError(f'Unexpected metric: {name}. Currently available metrics: {available_metrics}')

        metrics[name] = metric

    return metrics
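
# Illustrative sketch of how the dictionary returned by `get_metrics` is consumed (this mirrors the
# evaluation loop in `main` below; `dataloader` is a placeholder):
#
#     metrics = get_metrics(cfg)
#     for processed_signal, _, target_signal, target_length in dataloader:
#         for metric in metrics.values():
#             metric.update(preds=processed_signal, target=target_signal, input_length=target_length)
#     metrics_value = {name: metric.compute().item() for name, metric in metrics.items()}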


@hydra_runner(config_name="AudioEvaluationConfig", schema=AudioEvaluationConfig)
def main(cfg: AudioEvaluationConfig):
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.audio_dir is not None:
        raise RuntimeError(
            "Evaluation script requires ground truth audio to be passed via a manifest file. "
            "If manifest file is available, submit it via `dataset_manifest` argument."
        )

    if not os.path.exists(cfg.dataset_manifest):
        raise FileNotFoundError(f'The dataset manifest file could not be found at path : {cfg.dataset_manifest}')

    if cfg.target_dataset_dir is None:
        # Assume the target data is available in the same directory as the input data
        cfg.target_dataset_dir = os.path.dirname(cfg.dataset_manifest)
    elif not os.path.isdir(cfg.target_dataset_dir):
        raise FileNotFoundError(f'Target dataset dir could not be found at path : {cfg.target_dataset_dir}')

    # Setup metrics
    metrics = get_metrics(cfg)

    # Processing
    if not cfg.only_score_manifest:
        # Process audio using the configured model and save in the output directory
        process_cfg = process_audio.main(cfg)  # type: ProcessConfig

        # Release GPU memory if it was used during processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logging.info('Finished processing audio.')
    else:
        # Score the input manifest, no need to run a model
        cfg.output_filename = cfg.dataset_manifest
        process_cfg = cfg

    # Evaluation
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Prepare a temporary manifest with processed audio and target
        temporary_manifest_filepath = os.path.join(tmp_dir, 'manifest.json')

        num_files = 0

        with open(process_cfg.output_filename, 'r') as f_processed, open(
            temporary_manifest_filepath, 'w', encoding='utf-8'
        ) as f_tmp:
            for line_processed in f_processed:
                data_processed = json.loads(line_processed)

                if cfg.processed_key not in data_processed:
                    raise ValueError(
                        f'Processed key {cfg.processed_key} not found in manifest: {process_cfg.output_filename}.'
                    )

                if cfg.target_key not in data_processed:
                    raise ValueError(
                        f'Target key {cfg.target_key} not found in manifest: {process_cfg.output_filename}.'
                    )

                item = {
                    'processed': manifest.get_full_path(
                        audio_file=data_processed[cfg.processed_key], manifest_file=process_cfg.output_filename
                    ),
                    'target': manifest.get_full_path(
                        audio_file=data_processed[cfg.target_key], data_dir=cfg.target_dataset_dir
                    ),
                    'duration': data_processed.get('duration'),
                }

                # Double-check files exist
                for key in ['processed', 'target']:
                    if not os.path.isfile(item[key]):
                        raise ValueError(f'File for key "{key}" not found at: {item[key]}.\nCurrent item: {item}')

                # Warn if we're comparing the same files
                if item['target'] == item['processed']:
                    logging.warning('Using the same file as processed and target: %s', item['target'])

                # Write the entry in the temporary manifest file
                f_tmp.write(json.dumps(item) + '\n')

                num_files += 1

        # Prepare dataloader
        config = {
            'manifest_filepath': temporary_manifest_filepath,
            'sample_rate': cfg.sample_rate,
            'input_key': 'processed',
            'input_channel_selector': cfg.processed_channel_selector,
            'target_key': 'target',
            'target_channel_selector': cfg.target_channel_selector,
            'batch_size': min(cfg.batch_size, num_files),
            'num_workers': cfg.num_workers,
        }
        temporary_dataloader = get_evaluation_dataloader(config)

        # Calculate metrics
        for eval_batch in tqdm(temporary_dataloader, desc='Evaluating'):
            processed_signal, processed_length, target_signal, target_length = eval_batch

            if not torch.equal(processed_length, target_length):
                raise RuntimeError(
                    f'Processed and target signal lengths do not match: {processed_length} vs {target_length}.'
                )

            for name, metric in metrics.items():
                metric.update(preds=processed_signal, target=target_signal, input_length=target_length)

    # Convert to a dictionary with name: value
    metrics_value = {name: metric.compute().item() for name, metric in metrics.items()}

    logging.info('Finished running evaluation.')

    # Show results
    logging.info('Summary\n')
    logging.info('Data')
    logging.info('\tmanifest:           %s', process_cfg.output_filename)
    logging.info('\ttarget_dataset_dir: %s', cfg.target_dataset_dir)
    logging.info('\tnum_files:          %s', num_files)
    logging.info('Metrics')
    for name, value in metrics_value.items():
        logging.info('\t%10s: \t%6.2f', name, value)

    # Inject the metric name and score into the config, and return the entire config
    with open_dict(cfg):
        cfg.metrics_value = metrics_value

    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter