# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to compute the Word Error Rate (WER) or Character Error Rate (CER) of a given ASR model on a given dataset.
The dataset manifest must conform to the standard NeMo ASR manifest format, with each line containing at least
`audio_filepath` and `text` (the ground truth transcript).

Note: This script depends on the `transcribe_speech.py` script, and therefore both scripts should be located in the
same directory during execution.
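
For illustration, each manifest line is a single JSON object; the path below is made up, and additional fields
such as `duration` may also be present:

    {"audio_filepath": "/data/wavs/sample_0001.wav", "text": "hello world"}

A manifest that has already been transcribed additionally carries a `pred_text` field on every line.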

# Arguments

<< All arguments of `transcribe_speech.py` are inherited by this script, so please refer to `transcribe_speech.py`
for the full list of arguments >>

    dataset_manifest: Required - path to the dataset JSON manifest file (in NeMo format).
    output_filename: Optional - output filename where the transcriptions will be written.

    use_cer: Bool, whether to report CER instead of WER as the primary metric (both are always computed and logged).
    tolerance: Optional float, maximum WER/CER allowed; if the computed metric exceeds this value, the script
        raises an error (see the third usage example below).

    only_score_manifest: Bool, when set, skips audio transcription and only scores the provided manifest,
        which must already contain `pred_text`.

    separate_punctuation: Bool, whether to separate the punctuation marks `.`, `,` and `?` from the ground truth
        words before scoring (default: True).
    do_lowercase: Bool, whether to lowercase both the ground truth and the predicted text before scoring.
    rm_punctuation: Bool, whether to remove the punctuation marks `.`, `,` and `?` from both texts before scoring.

# Usage

## To score a dataset with a manifest file that does not contain previously transcribed `pred_text`

python speech_to_text_eval.py \
    model_path=null \
    pretrained_name=null \
    dataset_manifest=<Mandatory: Path to an ASR dataset manifest file> \
    output_filename=<Optional: Some output filename which will hold the transcribed text as a manifest> \
    batch_size=32 \
    amp=True \
    use_cer=False

## To score a manifest file which has previously been augmented with transcriptions as `pred_text`

This is useful when `transcribe_speech_parallel.py` is used to transcribe larger datasets, and the results are
written to a manifest with the two keys `text` (the ground truth) and `pred_text` (the model's transcription).

python speech_to_text_eval.py \
    dataset_manifest=<Mandatory: Path to an ASR dataset manifest file> \
    use_cer=False \
    only_score_manifest=True
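
## To additionally fail the run when the error rate exceeds a threshold (the threshold value below is illustrative)

python speech_to_text_eval.py \
    dataset_manifest=<Mandatory: Path to an ASR dataset manifest file> \
    only_score_manifest=True \
    use_cer=True \
    tolerance=0.15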
"""

import json
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional

import torch
import transcribe_speech
from omegaconf import MISSING, OmegaConf, open_dict

from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization
from nemo.core.config import hydra_runner
from nemo.utils import logging


@dataclass
class EvaluationConfig(transcribe_speech.TranscriptionConfig):
    dataset_manifest: str = MISSING
    output_filename: Optional[str] = "evaluation_transcripts.json"
    use_cer: bool = False
    tolerance: Optional[float] = None
    only_score_manifest: bool = False
    separate_punctuation: bool = True
    do_lowercase: bool = False
    rm_punctuation: bool = False


@hydra_runner(config_name="EvaluationConfig", schema=EvaluationConfig)
def main(cfg: EvaluationConfig):
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.audio_dir is not None:
        raise RuntimeError(
            "Evaluation script requires ground truth labels to be passed via a manifest file. "
            "If manifest file is available, submit it via `dataset_manifest` argument."
        )

    if not os.path.exists(cfg.dataset_manifest):
        raise FileNotFoundError(f"The dataset manifest file could not be found at path : {cfg.dataset_manifest}")

    if not cfg.only_score_manifest:
        # Transcribe speech into an output manifest
        transcription_cfg = transcribe_speech.main(cfg)  # type: EvaluationConfig

        # Release GPU memory if it was used during transcription
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logging.info("Finished transcribing speech dataset. Computing ASR metrics..")
    else:
        cfg.output_filename = cfg.dataset_manifest
        transcription_cfg = cfg
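
    # Read the output manifest and collect reference (`text`) / hypothesis (`pred_text`) pairs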
    ground_truth_text = []
    predicted_text = []
    invalid_manifest = False

    with open(transcription_cfg.output_filename, 'r') as f:
        for line in f:
            data = json.loads(line)

            if 'pred_text' not in data:
                invalid_manifest = True
                break

            ground_truth_text.append(data['text'])
            predicted_text.append(data['pred_text'])
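
    # Optionally normalize punctuation and casing of the reference and predicted text before scoring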
    pc = PunctuationCapitalization('.,?')
    if cfg.separate_punctuation:
        ground_truth_text = pc.separate_punctuation(ground_truth_text)
    if cfg.do_lowercase:
        ground_truth_text = pc.do_lowercase(ground_truth_text)
        predicted_text = pc.do_lowercase(predicted_text)
    if cfg.rm_punctuation:
        ground_truth_text = pc.rm_punctuation(ground_truth_text)
        predicted_text = pc.rm_punctuation(predicted_text)

    # Test for invalid manifest supplied
    if invalid_manifest:
        raise ValueError(
            f"Invalid manifest provided: {transcription_cfg.output_filename} does not "
            f"contain value for `pred_text`."
        )

    # Compute both WER and CER so that both can be reported
    cer = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=True)
    wer = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=False)

    if cfg.use_cer:
        metric_name = 'CER'
        metric_value = cer
    else:
        metric_name = 'WER'
        metric_value = wer

    if cfg.tolerance is not None:
        if metric_value > cfg.tolerance:
            raise ValueError(f"Got {metric_name} of {metric_value}, which was higher than tolerance={cfg.tolerance}")

        logging.info(f'Got {metric_name} of {metric_value}. Tolerance was {cfg.tolerance}')
    else:
        logging.info(f'Got {metric_name} of {metric_value}')

    logging.info(f'Dataset WER/CER ' + str(round(100 * wer, 2)) + "%/" + str(round(100 * cer, 2)) + "%")

    # Inject the metric name and score into the config, and return the entire config
    with open_dict(cfg):
        cfg.metric_name = metric_name
        cfg.metric_value = metric_value

    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter