""" |
|
Script to compute the Word Error Rate (WER) or Character Error Rate (CER) of a given ASR model on the dataset described by a given manifest file.
|
The manifest file must conform to the standard NeMo ASR manifest format - each line is a JSON object containing `audio_filepath` and `text` as the ground truth.
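
For example, a single manifest line may look like the following (paths and values are purely illustrative):

    {"audio_filepath": "/data/audio/utt_0001.wav", "duration": 3.2, "text": "hello world"}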
|
|
|
Note: This script depends on the `transcribe_speech.py` script, and therefore both scripts should be located in the |
|
same directory during execution. |
|
|
|
# Arguments |
|
|
|
<< All arguments of `transcribe_speech.py` are inherited by this script, so please refer to `transcribe_speech.py` |
|
for the full list of arguments >>
|
|
|
dataset_manifest: Required - path to dataset JSON manifest file (in NeMo format) |
|
output_filename: Optional - output filename where the transcriptions will be written. |
|
|
|
use_cer: Bool, whether to compute CER instead of WER (default: False, i.e. WER)
|
tolerance: Float (optional), maximum WER/CER allowed; if the computed metric exceeds this value, the script raises an error.
|
|
|
only_score_manifest: Bool, when set, skips audio transcription and only computes WER/CER using the `pred_text` already present in the provided manifest.
|
|
|
# Usage |
|
|
|
## To score a dataset with a manifest file that does not contain previously transcribed `pred_text`. |
|
|
|
python speech_to_text_eval.py \ |
|
model_path=null \ |
|
pretrained_name=null \ |
|
dataset_manifest=<Mandatory: Path to an ASR dataset manifest file> \ |
|
output_filename=<Optional: Some output filename which will hold the transcribed text as a manifest> \ |
|
batch_size=32 \ |
|
amp=True \ |
|
use_cer=False |
|
|
|
## To score a manifest file which has been previously augmented with transcribed text as `pred_text` |
|
This is useful when one uses `transcribe_speech_parallel.py` to transcribe larger datasets, and the results are written

to a manifest which has the two keys `text` (for the ground truth) and `pred_text` (for the model's transcription).
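
For example (values are purely illustrative):

    {"audio_filepath": "/data/audio/utt_0001.wav", "duration": 3.2, "text": "hello world", "pred_text": "hello word"}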
|
|
|
python speech_to_text_eval.py \ |
|
dataset_manifest=<Mandatory: Path to an ASR dataset manifest file> \ |
|
use_cer=False \ |
|
only_score_manifest=True |
|
|
|
""" |
|
|
|
import json |
|
import os |
|
from dataclasses import dataclass, is_dataclass |
|
from typing import Optional |
|
|
|
import torch |
|
import transcribe_speech |
|
from omegaconf import MISSING, OmegaConf, open_dict |
|
|
|
from nemo.collections.asr.metrics.wer import word_error_rate |
|
from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization |
|
from nemo.core.config import hydra_runner |
|
from nemo.utils import logging |
|
|
|
|
|
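# Extends the transcription config of `transcribe_speech.py` with evaluation-specific
# options: metric selection, tolerance, and simple text normalization.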
@dataclass |
|
class EvaluationConfig(transcribe_speech.TranscriptionConfig): |
|
dataset_manifest: str = MISSING |
|
output_filename: Optional[str] = "evaluation_transcripts.json" |
|
|
|
use_cer: bool = False |
|
tolerance: Optional[float] = None |
|
|
|
only_score_manifest: bool = False |
|
|
|
separate_punctuation: bool = True |
|
do_lowercase: bool = False |
|
rm_punctuation: bool = False |
|
|
|
|
|
@hydra_runner(config_name="EvaluationConfig", schema=EvaluationConfig) |
|
def main(cfg: EvaluationConfig): |
|
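    # Evaluation is inference-only, so gradient tracking is disabled globally.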
torch.set_grad_enabled(False) |
|
|
|
if is_dataclass(cfg): |
|
cfg = OmegaConf.structured(cfg) |
|
|
|
if cfg.audio_dir is not None: |
|
raise RuntimeError( |
|
"Evaluation script requires ground truth labels to be passed via a manifest file. " |
|
"If manifest file is available, submit it via `dataset_manifest` argument." |
|
) |
|
|
|
if not os.path.exists(cfg.dataset_manifest): |
|
        raise FileNotFoundError(f"The dataset manifest file could not be found at path: {cfg.dataset_manifest}")
|
|
|
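    # Transcribe the dataset first unless the manifest is expected to already contain `pred_text`.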
if not cfg.only_score_manifest: |
|
|
|
transcription_cfg = transcribe_speech.main(cfg) |
|
|
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
        logging.info("Finished transcribing speech dataset. Computing ASR metrics...")
|
|
|
else: |
|
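        # The provided manifest already contains predictions, so score it directly.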
cfg.output_filename = cfg.dataset_manifest |
|
transcription_cfg = cfg |
|
|
|
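    # Collect ground-truth (`text`) and predicted (`pred_text`) transcripts from the output manifest.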
ground_truth_text = [] |
|
predicted_text = [] |
|
invalid_manifest = False |
|
with open(transcription_cfg.output_filename, 'r') as f: |
|
for line in f: |
|
data = json.loads(line) |
|
|
|
if 'pred_text' not in data: |
|
invalid_manifest = True |
|
break |
|
|
|
ground_truth_text.append(data['text']) |
|
|
|
predicted_text.append(data['pred_text']) |
|
|
|
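    # Optional text normalization before scoring: punctuation separation/removal and lowercasing.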
pc = PunctuationCapitalization('.,?') |
|
if cfg.separate_punctuation: |
|
ground_truth_text = pc.separate_punctuation(ground_truth_text) |
|
if cfg.do_lowercase: |
|
ground_truth_text = pc.do_lowercase(ground_truth_text) |
|
predicted_text = pc.do_lowercase(predicted_text) |
|
if cfg.rm_punctuation: |
|
ground_truth_text = pc.rm_punctuation(ground_truth_text) |
|
predicted_text = pc.rm_punctuation(predicted_text) |
|
|
|
|
|
if invalid_manifest: |
|
raise ValueError( |
|
f"Invalid manifest provided: {transcription_cfg.output_filename} does not " |
|
f"contain value for `pred_text`." |
|
) |
|
|
|
|
|
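    # Compute both WER and CER; `use_cer` selects which metric is reported and checked against the tolerance.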
cer = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=True) |
|
wer = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=False) |
|
|
|
if cfg.use_cer: |
|
metric_name = 'CER' |
|
metric_value = cer |
|
else: |
|
metric_name = 'WER' |
|
metric_value = wer |
|
|
|
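    # If a tolerance is provided, fail when the selected metric exceeds it.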
if cfg.tolerance is not None: |
|
if metric_value > cfg.tolerance: |
|
raise ValueError(f"Got {metric_name} of {metric_value}, which was higher than tolerance={cfg.tolerance}") |
|
|
|
logging.info(f'Got {metric_name} of {metric_value}. Tolerance was {cfg.tolerance}') |
|
else: |
|
logging.info(f'Got {metric_name} of {metric_value}') |
|
|
|
    logging.info(f'Dataset WER/CER: {100 * wer:.2f}%/{100 * cer:.2f}%')
|
|
|
|
|
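    # Attach the computed metric to the returned config so callers can inspect it programmatically.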
with open_dict(cfg): |
|
cfg.metric_name = metric_name |
|
cfg.metric_value = metric_value |
|
|
|
return cfg |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|