"""Official evaluation script for CAIL-2021.

The code is based partially on CoQA evaluation script.
"""
import json
import sys

from collections import Counter


class CJRCEvaluator:
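    """Evaluator for CAIL-2021 reading comprehension predictions.

    Expected input layouts, inferred from the loaders below (a sketch of the
    CAIL/CJRC format, not an authoritative spec):

        gold file: {"data": [{"paragraphs": [{"qas": [{"id": ..., "answers": [...]}]}]}]}
        pred file: [{"id": ..., "answer": "..."}, ...]
    """
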
    def __init__(self, gold_file):
        self.gold_data = CJRCEvaluator.gold_answers_to_dict(gold_file)

    @staticmethod
    def gold_answers_to_dict(gold_file):
        with open(gold_file, mode="r", encoding="utf-8") as f:
            dataset = json.load(f)
        gold_dict = {}
        for story in dataset['data']:
            qas = story["paragraphs"][0]["qas"]
            for qa in qas:
                qid = qa['id']
                gold_answers = []
                answers = qa["answers"]
                if len(answers) == 0:
                    gold_answers = ['']
                else:
                    for answer in qa["answers"]:
                        if isinstance(answer, dict):
                            gold_answers.append(answer["text"])
                        elif isinstance(answer, list):
                            # Multi-span answers are joined into a single string.
                            gold_answers.append("".join([a["text"] for a in answer]))
                if qid in gold_dict:
                    sys.stderr.write("Gold file has duplicate question ids: {}\n".format(qid))
                gold_dict[qid] = gold_answers
        return gold_dict

    @staticmethod
    def preds_to_dict(pred_file):
        with open(pred_file, mode="r", encoding="utf-8") as f:
            preds = json.load(f)
        pred_dict = {}
        for pred in preds:
            pred_dict[pred['id']] = "".join(pred['answer'])
        return pred_dict

    @staticmethod
    def normalize_answer(s):
        """Lower text and remove punctuation, storys and extra whitespace."""

        def remove_punc(text):
            return "".join(ch for ch in text if ch.isdigit() or ch.isalpha())

        def lower(text):
            return text.lower()
    
        return remove_punc(lower(s))

    @staticmethod
    def get_tokens(s):
        if not s:
            return []
        return list(CJRCEvaluator.normalize_answer(s))

    @staticmethod
    def compute_exact(a_gold, a_pred):
        return int(CJRCEvaluator.normalize_answer(a_gold) == CJRCEvaluator.normalize_answer(a_pred))

    @staticmethod
    def compute_f1(a_gold, a_pred):
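        """Character-level F1 between a gold answer and a prediction.

        Worked example (tokens are single characters after normalization):
        gold "北京市" vs. pred "北京" gives overlap 2, precision 1.0,
        recall 2/3, F1 = 0.8.
        """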
        gold_toks = CJRCEvaluator.get_tokens(a_gold)
        pred_toks = CJRCEvaluator.get_tokens(a_pred)
        common = Counter(gold_toks) & Counter(pred_toks)
        num_same = sum(common.values())
        if len(gold_toks) == 0 or len(pred_toks) == 0:
            # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
            return int(gold_toks == pred_toks)
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(pred_toks)
        recall = 1.0 * num_same / len(gold_toks)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    @staticmethod
    def _compute_turn_score(a_gold_list, a_pred):
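        """Average EM/F1 of a_pred over the gold answers of one question.

        With multiple references the scoring is CoQA-style: for each reference,
        the prediction is compared against the remaining references
        (leave-one-out) and the per-reference maxima are averaged.
        """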
        f1_sum = 0.0
        em_sum = 0.0
        if len(a_gold_list) > 1:
            for i in range(len(a_gold_list)):
                # exclude the current answer
                gold_answers = a_gold_list[0:i] + a_gold_list[i + 1:]
                em_sum += max(CJRCEvaluator.compute_exact(a, a_pred) for a in gold_answers)
                f1_sum += max(CJRCEvaluator.compute_f1(a, a_pred) for a in gold_answers)
        else:
            em_sum += max(CJRCEvaluator.compute_exact(a, a_pred) for a in a_gold_list)
            f1_sum += max(CJRCEvaluator.compute_f1(a, a_pred) for a in a_gold_list)
        return {'em': em_sum / max(1, len(a_gold_list)), 'f1': f1_sum / max(1, len(a_gold_list))}

    def compute_turn_score(self, qid, a_pred):
        '''This is the function you are probably looking for. a_pred is the answer string your model predicted for question qid.'''
        a_gold_list = self.gold_data[qid]
        return CJRCEvaluator._compute_turn_score(a_gold_list, a_pred)

    def get_raw_scores(self, pred_data):
        '''Returns per-question exact-match and F1 scores as two dicts keyed by qid.'''
        exact_scores = {}
        f1_scores = {}
        for qid in self.gold_data:
            if qid not in pred_data:
                sys.stderr.write('Missing prediction for {}\n'.format(qid))
                continue
            a_pred = pred_data[qid]
            scores = self.compute_turn_score(qid, a_pred)
            # Take max over all gold answers
            exact_scores[qid] = scores['em']
            f1_scores[qid] = scores['f1']
        return exact_scores, f1_scores

    def get_raw_scores_human(self):
        '''
        Returns per-question exact-match and F1 scores for the human baseline,
        scoring each gold answer against the remaining gold answers.
        '''
        exact_scores = {}
        f1_scores = {}
        for qid in self.gold_data:
            f1_sum = 0.0
            em_sum = 0.0
            if len(self.gold_data[qid]) > 1:
                for i in range(len(self.gold_data[qid])):
                    # exclude the current answer
                    gold_answers = self.gold_data[qid][0:i] + self.gold_data[qid][i + 1:]
                    em_sum += max(CJRCEvaluator.compute_exact(a, self.gold_data[qid][i]) for a in gold_answers)
                    f1_sum += max(CJRCEvaluator.compute_f1(a, self.gold_data[qid][i]) for a in gold_answers)
            else:
                exit("Gold answers should be multiple: {}={}".format(qid, self.gold_data[qid]))
            exact_scores[qid] = em_sum / len(self.gold_data[qid])
            f1_scores[qid] = f1_sum / len(self.gold_data[qid])
        return exact_scores, f1_scores

    def human_performance(self):
        exact_scores, f1_scores = self.get_raw_scores_human()
        return self.get_total_scores(exact_scores, f1_scores)

    def model_performance(self, pred_data):
        exact_scores, f1_scores = self.get_raw_scores(pred_data)
        return self.get_total_scores(exact_scores, f1_scores)

    def get_total_scores(self, exact_scores, f1_scores):
        em_total, f1_total, turn_count = 0, 0, 0
        scores = {}
        for qid in self.gold_data:
            # Questions without a prediction contribute a score of 0.
            em_total += exact_scores.get(qid, 0)
            f1_total += f1_scores.get(qid, 0)
            turn_count += 1
        scores["EM"] = round(em_total / max(1, turn_count) * 100, 1)
        scores["F1"] = round(f1_total / max(1, turn_count) * 100, 1)
        return scores
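

# Minimal command-line usage sketch (not part of the original script); it
# assumes the gold and prediction file paths are passed as the first two
# command-line arguments.
if __name__ == "__main__":
    if len(sys.argv) != 3:
        sys.exit("usage: python evaluate.py <gold_file> <pred_file>")
    evaluator = CJRCEvaluator(sys.argv[1])
    pred_data = CJRCEvaluator.preds_to_dict(sys.argv[2])
    print(json.dumps(evaluator.model_performance(pred_data), ensure_ascii=False))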