import string

import fastwer
from nltk.metrics.distance import edit_distance

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer

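@register_scorer("cer", dataclass=FairseqDataclass)
class CERScorer(BaseScorer):
    """Character error rate (CER, in percent) over all accumulated samples, computed with fastwer."""

    def __init__(self, cfg):
        super().__init__(cfg)
        self.refs = []
        self.preds = []

    def add_string(self, ref, pred):
        self.refs.append(ref)
        self.preds.append(pred)

    def score(self):
        # Avoid calling fastwer on an empty corpus.
        if not self.refs:
            return 0.0
        # fastwer takes (hypotheses, references) in that order.
        return fastwer.score(self.preds, self.refs, char_level=True)

    def result_string(self) -> str:
        return f"CER: {self.score():.2f}"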
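

# Illustrative, hypothetical helper (not used by any scorer): a stdlib-only
# sketch of corpus-level CER, ASSUMING the definition
#     CER = 100 * (sum of character edit distances) / (sum of reference lengths),
# which is assumed here to match what fastwer.score(..., char_level=True)
# reports; verify against the fastwer version in use before relying on it.
# Useful for sanity-checking CERScorer output without the fastwer dependency.
def _reference_corpus_cer(preds, refs):
    def levenshtein(a, b):
        # Two-row dynamic programme with unit insert/delete/substitute costs.
        prev = list(range(len(b) + 1))
        for i, ca in enumerate(a, start=1):
            curr = [i]
            for j, cb in enumerate(b, start=1):
                curr.append(min(
                    prev[j] + 1,               # delete ca
                    curr[j - 1] + 1,           # insert cb
                    prev[j - 1] + (ca != cb),  # substitute (free if equal)
                ))
            prev = curr
        return prev[-1]

    total_edits = sum(levenshtein(p, r) for p, r in zip(preds, refs))
    total_chars = sum(len(r) for r in refs)
    # Edits are pooled over the corpus, not averaged per sentence.
    return 100.0 * total_edits / total_chars if total_chars else 0.0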


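@register_scorer("wpa", dataclass=FairseqDataclass)
class WPAScorer(BaseScorer):
    """WPA (in percent): the share of samples whose normalized prediction
    exactly matches the normalized reference."""

    def __init__(self, cfg):
        super().__init__(cfg)
        self.refs = []
        self.preds = []
        # Characters kept after normalization; punctuation, whitespace and
        # non-ASCII characters are all dropped.
        self.alphabet = string.digits + string.ascii_lowercase

    def filter(self, text):
        # Keep only characters from the alphabet, preserving their order.
        return "".join(ch for ch in text if ch in self.alphabet)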

    def add_string(self, ref, pred):
        # Lowercase first so uppercase letters are kept by the filter.
        self.refs.append(self.filter(ref.lower()))
        self.preds.append(self.filter(pred.lower()))

    def score(self):
        if not self.refs:
            return 0.0
        correct = sum(ref == pred for ref, pred in zip(self.refs, self.preds))
        return round(correct / len(self.refs) * 100, 2)

    def result_string(self) -> str:
        return f"WPA: {self.score():.2f}"
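

# Illustrative, hypothetical helper (not used by any scorer): a standalone
# sketch of WPAScorer's normalization. Lowercasing and then keeping only ASCII
# digits and lowercase letters also strips whitespace and punctuation, so
# "Hello, World!" and "hello world" normalize to the same string and would be
# counted as a correct prediction.
def _wpa_normalize(text):
    # Same characters as string.digits + string.ascii_lowercase.
    alphabet = set("0123456789abcdefghijklmnopqrstuvwxyz")
    return "".join(ch for ch in text.lower() if ch in alphabet)
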

@register_scorer("acc_ed", dataclass=FairseqDataclass)
class AccEDScorer(BaseScorer):
    """Exact-match accuracy (in percent) and mean per-sample edit distance."""

    def __init__(self, cfg):
        super().__init__(cfg)
        self.n_data = 0
        self.n_correct = 0
        self.ed = 0  # running sum of edit distances

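    def add_string(self, ref, pred):
        self.n_data += 1
        if ref == pred:
            self.n_correct += 1
        # nltk's edit_distance: Levenshtein distance with unit costs by default.
        self.ed += edit_distance(ref, pred)
        # self.ref / self.pred are initialized by BaseScorer.
        self.ref.append(ref)
        self.pred.append(pred)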

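    def score(self):
        if self.n_data == 0:
            return 0.0, 0.0
        acc = self.n_correct / self.n_data * 100
        # Mean edit distance per sample; it is not normalized by string length,
        # although result_string labels it "Norm ED".
        mean_ed = self.ed / self.n_data
        return acc, mean_ed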

    def result_string(self):
        acc, norm_ed = self.score()
        return f"Accuracy: {acc:.3f} Norm ED: {norm_ed:.2f}"

@register_scorer("sroie", dataclass=FairseqDataclass)
class SROIEScorer(BaseScorer):
    """Word-level precision, recall and F1 (in percent), accumulated over all samples."""

    def __init__(self, cfg):
        super().__init__(cfg)
        self.n_detected_words = 0
        self.n_gt_words = 0
        self.n_match_words = 0

    def add_string(self, ref, pred):
        pred_words = pred.split()
        ref_words = ref.split()
        self.n_gt_words += len(ref_words)
        self.n_detected_words += len(pred_words)
        # Greedy one-to-one matching: each reference word can be matched at
        # most once, so repeated words are counted as a multiset.
        for pred_w in pred_words:
            if pred_w in ref_words:
                self.n_match_words += 1
                ref_words.remove(pred_w)
        self.ref.append(ref)
        self.pred.append(pred)

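    def score(self):
        # Guard every ratio so empty predictions or references cannot divide by zero.
        if self.n_detected_words:
            prec = self.n_match_words / self.n_detected_words * 100
        else:
            prec = 0.0
        if self.n_gt_words:
            recall = self.n_match_words / self.n_gt_words * 100
        else:
            recall = 0.0
        f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
        return prec, recall, f1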

    def result_string(self):
        prec, recall, f1 = self.score()
        return f"Precision: {prec:.3f} Recall: {recall:.3f} F1: {f1:.3f}"
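

# Illustrative, hypothetical helper (not used by any scorer): the greedy
# match-and-remove loop in SROIEScorer.add_string counts the size of the
# multiset intersection of predicted and reference words, i.e. for each word
# min(count in pred, count in ref). This sketch computes that count with
# collections.Counter and derives precision/recall/F1 (in percent) for a
# single sample; SROIEScorer instead sums the counts over all samples before
# dividing.
def _word_prf(ref, pred):
    from collections import Counter

    pred_words, ref_words = pred.split(), ref.split()
    # Counter & Counter keeps the minimum count of each shared word.
    n_match = sum((Counter(pred_words) & Counter(ref_words)).values())
    prec = 100.0 * n_match / len(pred_words) if pred_words else 0.0
    recall = 100.0 * n_match / len(ref_words) if ref_words else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    return n_match, prec, recall, f1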