# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer
@dataclass
class WerScorerConfig(FairseqDataclass):
    """Configuration options for the WER scorer.

    ``WerScorer`` forwards each of these fields to ``EvaluationTokenizer``,
    which normalizes both the reference and the hypothesis before they are
    compared.

    The ``@dataclass`` decorator is required: without it the ``field(...)``
    calls below remain plain class attributes holding ``Field`` objects
    instead of becoming config fields with defaults.
    """

    # Tokenizer type passed to EvaluationTokenizer(tokenizer_type=...);
    # allowed values are EvaluationTokenizer.ALL_TOKENIZER_TYPES.
    wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"}
    )
    # Passed to EvaluationTokenizer(punctuation_removal=...).
    wer_remove_punct: bool = field(
        default=False, metadata={"help": "remove punctuation"}
    )
    # Passed to EvaluationTokenizer(character_tokenization=...).
    wer_char_level: bool = field(
        default=False, metadata={"help": "evaluate at character level"}
    )
    # Passed to EvaluationTokenizer(lowercase=...).
    wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"})
@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer):
    """Corpus-level word error rate (WER) scorer.

    Edit distances and reference token counts are summed over every
    ``add_string`` call. The reported score is therefore
    ``100 * total_edits / total_reference_tokens``, not an average of
    per-sentence WERs.
    """

    def __init__(self, cfg):
        """Build the scorer from a ``WerScorerConfig``-style ``cfg``.

        Raises:
            ImportError: if the optional ``editdistance`` package is missing.
        """
        super().__init__(cfg)
        self.reset()
        # editdistance is imported lazily, so it is only needed when this
        # scorer is actually instantiated.
        try:
            import editdistance as ed
        except ImportError as err:
            raise ImportError(
                "Please install editdistance to use WER scorer"
            ) from err
        self.ed = ed
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=self.cfg.wer_tokenizer,
            lowercase=self.cfg.wer_lowercase,
            punctuation_removal=self.cfg.wer_remove_punct,
            character_tokenization=self.cfg.wer_char_level,
        )

    def reset(self):
        """Clear the accumulated edit distance and reference token count."""
        self.distance = 0
        self.ref_length = 0

    def add_string(self, ref, pred):
        """Accumulate the token-level edit distance between ``ref`` and ``pred``.

        Both strings go through the configured tokenizer and are then split on
        whitespace. The number of reference tokens is added to the denominator
        used by ``score``.
        """
        ref_items = self.tokenizer.tokenize(ref).split()
        pred_items = self.tokenizer.tokenize(pred).split()
        self.distance += self.ed.eval(ref_items, pred_items)
        self.ref_length += len(ref_items)

    def result_string(self):
        """Return the score formatted as ``"WER: <value>"`` with two decimals."""
        return f"WER: {self.score():.2f}"

    def score(self):
        """Return the corpus WER as a percentage.

        Returns 0.0 when no reference tokens have been accumulated. This avoids
        dividing by zero and keeps the return type a float in every case.
        """
        if self.ref_length <= 0:
            return 0.0
        return 100.0 * self.distance / self.ref_length