"""
Number prediction: crime-amount extraction.
Metric: accuracy.
"""
import re
def compute_jetq(data_dict):
    """
    Compute accuracy for the crime-amount extraction task.

    For each example, the numbers appearing in the prediction are extracted and
    compared with the reference amount. A prediction is counted as correct if
    the total crime amount given in the reference appears among the numbers
    found in the prediction.
    """
    score_list, abstentions = [], 0

    for example in data_dict:
        question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
        assert answer.startswith("上文涉及到的犯罪金额:"), f"answer: {answer}, question: {question}"
        assert answer.endswith("元。"), f"answer: {answer}, question: {question}"
        answer = answer.replace("上文涉及到的犯罪金额:", "")

        assert "千元" not in answer, f"answer: {answer}, question: {question}"
        assert "δΈ‡" not in answer, f"answer: {answer}, question: {question}"

        # strip the trailing "元。" and convert the reference amount to a float
        answer = answer.replace("元。", "")
        answer = float(answer)

        # pull every number (integer or decimal) out of the free-text prediction
        prediction_digits = re.findall(r"\d+\.?\d*", prediction)
        prediction_digits = [float(digit) for digit in prediction_digits]

        if len(prediction_digits) == 0:
            # no number found in the prediction: count it as an abstention
            # (it still falls through below and is scored 0, i.e. treated as wrong)
            abstentions += 1
        if answer in prediction_digits:
            score_list.append(1)
        else:
            score_list.append(0)

    # compute the accuracy of score_list
    accuracy = sum(score_list) / len(score_list)
    return {"score": accuracy, "abstention_rate": abstentions/len(data_dict)}