File size: 2,577 Bytes
4daa863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import json
import os
import csv
# Directory holding the per-model MBPP evaluation result files (one JSON per model).
# NOTE(review): hard-coded absolute Windows path; adjust for your machine.
input_dir = 'E:/python-testn/pythonProject3/hh_2/evaluate_result_mbpp'

# Every regular file in input_dir is treated as one model's result file.
# Sorted so the CSV rows come out in a stable order (os.listdir order is
# arbitrary); sub-directories are skipped because they cannot be opened and
# parsed as JSON result files.
files = sorted(
    name for name in os.listdir(input_dir)
    if os.path.isfile(os.path.join(input_dir, name))
)

# Create (or truncate) the summary CSV and write its header row; one row per
# model is appended later. utf-8 matches the encoding used for those appends.
with open("cata_result.csv", "w", newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["Model", "Array", "String", "Math", "Other"])

# Category labels, in the same order as the CSV header columns.
CATEGORIES = ("Array", "String", "Math", "Other")


def _model_name(file_name):
    """Extract the model name from a result file name.

    The name is the text between the first '_' and the last '-'. When the
    file name has no '-', the text runs to the end of the name (slicing with
    rfind()'s -1 would otherwise silently drop the last character).
    """
    start = file_name.find('_') + 1  # 0 when there is no '_'
    end = file_name.rfind('-')
    if end == -1:
        end = len(file_name)
    return file_name[start:end]


def _percent(total, count):
    """Return total / count as a percentage rounded to 2 decimals, or 0 if count is 0."""
    if count == 0:
        return 0
    return round(total / count * 100, 2)


# The task -> category file is the same for every model, so load it once and
# index it by task_id. setdefault keeps the FIRST entry for a duplicated
# task_id, matching what a linear "first match" search would return.
with open("mbpp_with_cata.json", "r", encoding="utf-8") as cata_file:
    data2 = json.load(cata_file)
item_by_task_id = {}
for entry in data2:
    item_by_task_id.setdefault(entry["task_id"], entry)

for file_name in files:
    # Build the full path of this model's result file.
    input_file_path = os.path.join(input_dir, file_name)
    model_name = _model_name(file_name)
    print(model_name)
    with open(input_file_path, "r", encoding="utf-8") as result_file:
        data1 = json.load(result_file)

    # Per-category running sum of pass@1 values and number of tasks seen.
    sums = dict.fromkeys(CATEGORIES, 0)
    counts = dict.fromkeys(CATEGORIES, 0)

    for item1 in data1:
        # presumably each result entry carries "task_id" and a numeric
        # "pass@1" -- TODO confirm against the evaluation output format
        task_id = item1["task_id"]
        value = item1["pass@1"]

        item2 = item_by_task_id.get(task_id)
        if item2 is None:
            continue  # task has no category entry: not counted anywhere

        cata = item2["cata"]
        if cata in CATEGORIES:  # any other label is ignored
            sums[cata] += value
            counts[cata] += 1

    # Mean pass@1 per category as a percentage; 0 for an empty category
    # (previously only "Other" was guarded and the rest raised ZeroDivisionError).
    means = [_percent(sums[c], counts[c]) for c in CATEGORIES]

    print("count_result!!")
    print(*(counts[c] for c in CATEGORIES))
    print(*means)
    with open("cata_result.csv", mode='a', newline='', encoding='utf-8') as out_file:
        writer = csv.writer(out_file)
        writer.writerow([model_name, *means])