# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
import argparse
import glob
import json
import sys
# Summarise evaluation logs: for every log file matching --log_path, pick the
# result line to report, print its per-language scores as one comma-separated
# row, then print the column-wise average over all files.

# QA tasks report an "f1/exact_match" pair per column; the other tasks report
# a single score per column.
QA_METRICS = ("f1", "exact_match")

# Tasks whose best line inside a log is chosen by validation accuracy.
ACC_TASKS = ("xnli", "pawsx")
# Tasks whose best line inside a log is chosen by dev F1.
TAGGING_TASKS = ("panx", "udpos")
# QA tasks whose runs are compared across files, and the entry whose F1 is
# used for that comparison.
QA_RANKING_KEY = {"tydiqa": "dev_en", "mlqa": "dev_avg"}


def make_spec(order, metrics, split="test", sep="_", scale=1, digits=2,
              averages=("test_avg", "dev_avg"), gap_digits=None,
              skip_missing=False):
    """Bundle the reporting rules of one task into a dict.

    Args:
        order: languages, in the order their columns are printed.
        metrics: metric names read from every result entry.
        split: key prefix of the per-language entries that are reported.
        sep: separator between the split name and the language in a key.
        scale: factor applied to every metric before rounding.
        digits: number of decimals kept for scores and averages.
        averages: keys of the aggregate entries appended after the languages.
        gap_digits: when not None, append the mean difference between the
            "en" score and every reported language, rounded to this many
            decimals.
        skip_missing: drop languages absent from the result instead of
            failing on them.
    """
    return {
        "order": order,
        "metrics": metrics,
        "split": split,
        "sep": sep,
        "scale": scale,
        "digits": digits,
        "averages": averages,
        "gap_digits": gap_digits,
        "skip_missing": skip_missing,
    }


TASK_SPECS = {
    "xnli": make_spec(
        ["en", "ar", "bg", "de", "el", "es", "fr", "hi", "ru", "sw", "th",
         "tr", "ur", "vi", "zh"],
        ("acc",), sep="-", scale=100, digits=2,
        averages=("test_avg", "valid_avg"), gap_digits=2,
    ),
    "pawsx": make_spec(
        ["en", "de", "es", "fr", "ja", "ko", "zh"],
        ("acc",), sep="-", scale=100, digits=2,
        averages=("test_avg", "valid_avg"),
    ),
    "panx": make_spec(
        ["en", "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa",
         "fi", "fr", "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk",
         "ko", "ml", "mr", "ms", "my", "nl", "pt", "ru", "sw", "ta", "te",
         "th", "tl", "tr", "ur", "vi", "yo", "zh"],
        ("f1",), scale=100, digits=1,
    ),
    "udpos": make_spec(
        ["af", "ar", "bg", "de", "el", "en", "es", "et", "eu", "fa", "fi",
         "fr", "he", "hi", "hu", "id", "it", "ja", "kk", "ko", "mr", "nl",
         "pt", "ru", "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"],
        ("f1",), scale=100, digits=1, gap_digits=1, skip_missing=True,
    ),
    "mlqa": make_spec(
        ["en", "ar", "de", "es", "hi", "vi", "zh"],
        QA_METRICS, digits=1, gap_digits=2,
    ),
    "xquad": make_spec(
        ["en", "ar", "de", "el", "es", "hi", "ru", "th", "tr", "vi", "zh"],
        QA_METRICS, split="dev", digits=2, averages=("dev_avg",),
    ),
    "tydiqa": make_spec(
        ["en", "ar", "bn", "fi", "id", "ko", "ru", "sw", "te"],
        QA_METRICS, split="dev", digits=2, averages=("dev_avg",),
    ),
    # squad reports the single "dev_en" entry and no aggregate.
    "squad": make_spec(["en"], QA_METRICS, split="dev", digits=2, averages=()),
}


def select_result(lines, task):
    """Pick the result to report from the lines of one log.

    Each non-blank line is tab-separated; the JSON result sits in the first
    column for squad and in the second column for every other task.

    Args:
        lines: iterable of log lines (an open file works).
        task: lower-case task name, a key of TASK_SPECS.

    Returns:
        ``(result, score)``. ``result`` is the chosen JSON object. ``score``
        is the value used to compare this log with the other logs, or None
        for tasks that are not compared (xquad, squad).

    Raises:
        ValueError: no line was selected, or a line holds invalid JSON.
        LookupError: a line lacks the JSON column or an expected key.
    """
    json_column = 0 if task == "squad" else 1
    best = 0.0
    chosen = None
    for raw in lines:
        line = raw.strip()
        if not line:
            # Blank lines carry no result.
            continue
        record = json.loads(line.split("\t")[json_column])
        if task in ACC_TASKS:
            if record["valid_avg"]["acc"] > best:
                best, chosen = record["valid_avg"]["acc"], record
        elif task in TAGGING_TASKS:
            if "dev_avg" not in record:
                # A line without a dev score replaces the current choice but
                # leaves the best score untouched.
                chosen = record
            elif record["dev_avg"]["f1"] > best:
                best, chosen = record["dev_avg"]["f1"], record
        else:
            # QA tasks report the last line of the log.
            chosen = record
    if chosen is None:
        raise ValueError("no usable result line")
    if task in ACC_TASKS or task in TAGGING_TASKS:
        return chosen, best
    ranking_key = QA_RANKING_KEY.get(task)
    if ranking_key is None:
        return chosen, None
    return chosen, chosen[ranking_key]["f1"]


def build_row(task, result):
    """Turn one selected result into the list of columns to print.

    Every column is a tuple with one value per metric of the task. Columns
    are the per-language scores in the task's order, then the task's
    aggregate entries, then (for tasks that define it) the mean gap to "en".

    Raises:
        LookupError: a required language, aggregate or metric is missing.
        ZeroDivisionError: a gap is requested but only one language is
            reported.
    """
    spec = TASK_SPECS[task]
    metrics = spec["metrics"]
    scale = spec["scale"]
    digits = spec["digits"]

    def scores(entry):
        return tuple(round(entry[name] * scale, digits) for name in metrics)

    prefix = spec["split"] + spec["sep"]
    per_lang = {
        key[len(prefix):]: scores(entry)
        for key, entry in result.items()
        if key.startswith(prefix)
    }

    languages = spec["order"]
    if spec["skip_missing"]:
        languages = [lang for lang in languages if lang in per_lang]

    row = [per_lang[lang] for lang in languages]
    row.extend(scores(result[name]) for name in spec["averages"])

    gap_digits = spec["gap_digits"]
    if gap_digits is not None:
        english = per_lang["en"]
        # Accumulate sequentially so the result does not depend on the
        # summation algorithm of the running interpreter.
        gap_sums = [0.0] * len(metrics)
        for lang in languages:
            for idx, value in enumerate(per_lang[lang]):
                gap_sums[idx] += english[idx] - value
        # "en" contributes a zero gap, so it is excluded from the divisor.
        others = len(languages) - 1
        row.append(tuple(round(total / others, gap_digits) for total in gap_sums))
    return row


def format_row(row):
    """Render a row as comma-separated columns, metrics joined by '/'."""
    return ",".join("/".join(str(value) for value in term) for term in row)


def average_rows(rows):
    """Return the column-wise mean of *rows*, rounded to two decimals.

    Raises:
        ValueError: the rows do not all have the same number of columns, so
            their columns cannot be matched up.
    """
    width = len(rows[0])
    if any(len(row) != width for row in rows):
        raise ValueError("result rows have different numbers of columns; cannot average them")
    count = len(rows)
    averaged = []
    for column in zip(*rows):
        totals = [0.0] * len(column[0])
        for term in column:
            for idx, value in enumerate(term):
                totals[idx] += value
        averaged.append(tuple(round(total / count, 2) for total in totals))
    return averaged


def report_skipped(path, err):
    """Print the blank placeholder row and explain the skip on stderr."""
    print()
    print("warning: skipping {}: {!r}".format(path, err), file=sys.stderr)


def main(argv=None):
    """Print one score row per matching log file, then their average.

    Args:
        argv: argument list to parse; None means the process arguments.
    """
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--log_path",
        default=None,
        type=str,
        required=True,
        help="path to evaluation log file.",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="task name",
    )
    args = parser.parse_args(argv)

    # Normalise once so every comparison below treats the name the same way.
    task = args.task_name.lower()
    if task not in TASK_SPECS:
        parser.error("unsupported task name {!r} (expected one of: {})".format(
            args.task_name, ", ".join(sorted(TASK_SPECS))))

    rows = []
    all_best = 0.0
    all_final_res = None
    # Shorter paths first, then alphabetical.
    for path in sorted(glob.glob(args.log_path), key=lambda p: (len(p), p)):
        if task != "squad" and "squad" in path:
            continue
        try:
            with open(path) as handle:
                final_res, score = select_result(handle, task)
        except (OSError, ValueError, LookupError, TypeError) as err:
            report_skipped(path, err)
            continue
        if score is not None and score > all_best:
            all_best = score
            all_final_res = final_res
        try:
            row = build_row(task, final_res)
        except (LookupError, TypeError, AttributeError, ZeroDivisionError) as err:
            report_skipped(path, err)
            continue
        print(format_row(row))
        rows.append(row)

    print(TASK_SPECS[task]["order"])
    print(all_final_res)
    if not rows:
        raise SystemExit("no usable log file matched {!r}".format(args.log_path))
    try:
        averaged = average_rows(rows)
    except ValueError as err:
        raise SystemExit(str(err))
    print("{}-average: ".format(len(rows)) + format_row(averaged))


if __name__ == "__main__":
    main()