Spaces:
Running
Running
import json
import os
import csv

# For every evaluation-result file, compute the mean pass@1 (as a percentage,
# rounded to 2 decimals) over the dataset problems that belong to each problem
# category, and write one CSV row per model.

# Raw string: the original non-raw literal relied on invalid escape sequences
# ("\p", "\h", "\e"), which emit SyntaxWarning on Python 3.12+.
INPUT_DIR = r'E:\python-testn\pythonProject3\hh_1\evaluate_result'
CATEGORY_FILE = "humaneval_with_cata.json"
OUTPUT_CSV = "cata_result.csv"

# (CSV column header, tag spellings that put a problem in that column).
# A tag matches when `tag in item["answer"]` is true -- a substring test if
# "answer" is a string, an exact-element test if it is a list.
CATEGORIES = (
    ("String", ("String",)),
    ("Math", ("Math",)),
    ("Array", ("Array",)),
    ("Sorting", ("Sorting",)),
    # The header says "Hash Table" but the original check looked for
    # "Hash table"; accept both so matching problems are not silently dropped.
    ("Hash Table", ("Hash Table", "Hash table")),
    ("Stack", ("Stack",)),
    ("Search", ("Search",)),
    ("Matrix", ("Matrix",)),
)


def parse_model_name(file_name):
    """Return the model name embedded in a result file name.

    The name is the text between the first '_' and the last '-'. A missing
    '_' means "start of string"; a missing '-' means "end of string" (the
    original sliced to index -1 there and dropped the last character).
    """
    start = file_name.find('_') + 1  # -1 + 1 == 0 when there is no '_'
    end = file_name.rfind('-')
    if end == -1:
        end = len(file_name)
    return file_name[start:end]


def category_means(pass_at_1, cata_items, categories=CATEGORIES):
    """Average pass@1 per category.

    Args:
        pass_at_1: sequence of (index, value) pairs, aligned position-by-
            position with ``cata_items``.
        cata_items: sequence of dicts, each with an "answer" entry holding the
            problem's category tags.
        categories: (column name, tag spellings) pairs; defaults to CATEGORIES.

    Returns:
        (counts, means): per-category problem counts, and the mean value times
        100 rounded to 2 decimals, or None for a category with no problems
        (the original raised ZeroDivisionError there).
    """
    sums = [0] * len(categories)
    counts = [0] * len(categories)
    for (_index, value), item in zip(pass_at_1, cata_items):
        answer = item["answer"]
        for i, (_name, tags) in enumerate(categories):
            if any(tag in answer for tag in tags):
                sums[i] += value
                counts[i] += 1
    means = [round(s / c * 100, 2) if c else None for s, c in zip(sums, counts)]
    return counts, means


def main(input_dir=INPUT_DIR, category_file=CATEGORY_FILE, output_csv=OUTPUT_CSV):
    """Write one row of per-category pass@1 means per result file to ``output_csv``."""
    # The category annotations are identical for every model: load them once.
    with open(category_file, "r", encoding="utf-8") as fh:
        cata_items = json.load(fh)

    with open(output_csv, "w", newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Model"] + [name for name, _tags in CATEGORIES])
        # sorted(): os.listdir order is arbitrary; keep the CSV reproducible.
        for file_name in sorted(os.listdir(input_dir)):
            input_file_path = os.path.join(input_dir, file_name)
            if not os.path.isfile(input_file_path):
                continue  # skip sub-directories and other non-files
            model_name = parse_model_name(file_name)
            print(model_name)
            with open(input_file_path, "r", encoding="utf-8") as fh:
                results = json.load(fh)
            pass_at_1 = results["humaneval"]["pass@1"]
            if len(pass_at_1) != len(cata_items):
                # zip() stops at the shorter sequence; make the mismatch visible.
                print("warning:", file_name, "has", len(pass_at_1),
                      "results but", len(cata_items), "categorised problems")
            counts, means = category_means(pass_at_1, cata_items)
            print(*counts)
            print(*means)
            # csv.writer renders None (empty category) as an empty cell.
            writer.writerow([model_name] + means)


if __name__ == "__main__":
    main()