|
import argparse |
|
from collections import Counter |
|
|
|
def _load_m2(path, start, end):
    """Read an M2 file and return its sentence blocks.

    Each sentence block is separated by a blank line. The [start:end]
    slice is applied only when BOTH bounds are given (original behavior).
    """
    with open(path) as f:
        blocks = f.read().strip().split("\n\n")
    if start is not None and end is not None:
        blocks = blocks[start:end]
    return blocks


def main():
    """Score a hypothesis M2 file against a reference M2 file.

    Reads both files, aligns them sentence by sentence, selects the best
    hypothesis/reference annotator pairing per sentence, accumulates global
    and per-category TP/FP/FN counts, and prints the final F-score tables.
    """
    args = parse_args()

    hyp_m2 = _load_m2(args.hyp, args.start, args.end)
    ref_m2 = _load_m2(args.ref, args.start, args.end)

    # Fail loudly with the counts if the files are misaligned. The original
    # `assert ..., print(...)` produced a useless `None` assertion message
    # and disappeared entirely under `python -O`.
    if len(hyp_m2) != len(ref_m2):
        raise ValueError(
            "hyp and ref M2 files have different sentence counts: "
            "{} vs {}".format(len(hyp_m2), len(ref_m2)))

    # Running global counts and per-category [tp, fp, fn] accumulator.
    best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
    best_cats = {}

    for sent_id, (hyp_sent, ref_sent) in enumerate(zip(hyp_m2, ref_m2)):
        # The first line of each block is the original "S ..." sentence.
        src = hyp_sent.split("\n")[0]
        hyp_edits = simplify_edits(hyp_sent, args.max_answer_num)
        ref_edits = simplify_edits(ref_sent, args.max_answer_num)

        hyp_dict = process_edits(hyp_edits, args)
        ref_dict = process_edits(ref_edits, args)

        # Optionally restrict scoring to sentences with exactly
        # `reference_num` reference annotators.
        if args.reference_num is None or len(ref_dict) == args.reference_num:
            count_dict, cat_dict = evaluate_edits(
                src, hyp_dict, ref_dict, best_dict, sent_id, args)
            best_dict += Counter(count_dict)
            best_cats = merge_dict(best_cats, cat_dict)

    print_results(best_dict, best_cats, args)
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse sys.argv.

    Returns:
        argparse.Namespace with all scorer options.
    """
    p = argparse.ArgumentParser(
        description="Calculate F-scores for error detection and/or correction.\n"
                    "Flags let you evaluate at different levels of granularity.",
        formatter_class=argparse.RawTextHelpFormatter,
        usage="%(prog)s [options] -hyp HYP -ref REF")

    # Required input files.
    p.add_argument("-hyp", help="A hypothesis M2 file.", required=True)
    p.add_argument("-ref", help="A reference M2 file.", required=True)

    # Optional sentence-range slicing and annotator limits.
    p.add_argument("--start", type=int, default=None)
    p.add_argument("--end", type=int, default=None)
    p.add_argument("--max_answer_num", type=int, default=None)
    p.add_argument("--reference_num", type=int, default=None)

    p.add_argument("-b", "--beta",
                   help="Value of beta in F-score. (default: 0.5)",
                   default=0.5, type=float)
    p.add_argument("-v", "--verbose",
                   help="Print verbose output.",
                   action="store_true")

    # The four evaluation granularities are mutually exclusive.
    mode = p.add_mutually_exclusive_group()
    mode.add_argument("-dt",
                      help="Evaluate Detection in terms of Tokens.",
                      action="store_true")
    mode.add_argument("-ds",
                      help="Evaluate Detection in terms of Spans.",
                      action="store_true")
    mode.add_argument("-cs",
                      help="Evaluate Correction in terms of Spans. (default)",
                      action="store_true")
    mode.add_argument("-cse",
                      help="Evaluate Correction in terms of Spans and Error types.",
                      action="store_true")

    # Edit-size and category filters.
    p.add_argument("-single",
                   help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
                   action="store_true")
    p.add_argument("-multi",
                   help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
                   action="store_true")
    p.add_argument("-multi_hyp_avg",
                   help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.",
                   action="store_true")
    p.add_argument("-multi_hyp_max",
                   help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.",
                   action="store_true")
    p.add_argument("-filt",
                   help="Do not evaluate the specified error types.",
                   nargs="+", default=[])
    p.add_argument("-cat",
                   help="Show error category scores.\n"
                        "1: Only show operation tier scores; e.g. R.\n"
                        "2: Only show main tier scores; e.g. NOUN.\n"
                        "3: Show all category scores; e.g. R:NOUN.",
                   choices=[1, 2, 3], type=int)

    return p.parse_args()
|
|
|
|
|
|
|
def simplify_edits(sent, max_answer_num):
    """Parse the annotation lines of one M2 sentence block.

    Args:
        sent: One M2 block; the first line is "S ..." and each edit line
            starts with "A " followed by "|||"-separated fields.
        max_answer_num: If not None, keep only edits from annotator ids
            0 .. max_answer_num-1.

    Returns:
        A list of [start, end, category, correction, coder] edits.
    """
    out_edits = []
    for line in sent.split("\n"):
        # Only "A ..." lines carry edits; the "S ..." line is skipped.
        if not line.startswith("A "):
            continue
        fields = line[2:].split("|||")
        span = fields[0].split()
        start = int(span[0])
        end = int(span[1])
        cat = fields[1]
        # Corrections are compared with all spaces removed.
        cor = fields[2].replace(" ", "")
        coder = int(fields[-1])
        out_edits.append([start, end, cat, cor, coder])

    if max_answer_num is None:
        return out_edits
    # Keep the first `max_answer_num` annotators (ids 0..n-1). This
    # generalizes the original hard-coded cases for 1, 2 and 3, which
    # silently returned None for any larger value.
    return [edit for edit in out_edits if edit[-1] < max_answer_num]
|
|
|
|
|
|
|
|
|
def process_edits(edits, args):
    """Group edits by annotator into lookup dicts keyed by edit identity.

    The key depends on the evaluation mode:
      * args.dt  : per-token (tok_id, tok_id+1) spans (detection by token)
      * args.ds  : (start, end) spans (detection by span)
      * args.cse : (start, end, category, correction)
      * default  : (start, end, correction) (correction by span)

    Args:
        edits: List of [start, end, cat, cor, coder] from simplify_edits.
        args: Parsed options (uses dt, ds, cse, single, multi, filt).

    Returns:
        {coder_id: {edit_key: [categories]}}.
    """
    coder_dict = {}

    # A sentence with no edits still needs a "noop" placeholder so that
    # leaving the source unchanged can be matched between hyp and ref.
    if not edits:
        edits = [[-1, -1, "noop", "-NONE-", 0]]

    for start, end, cat, cor, coder in edits:
        # Ensure the coder has an entry even if all its edits are filtered.
        coder_dict.setdefault(coder, {})

        # UNK (detected but uncorrected) edits only count in detection modes.
        if not args.dt and not args.ds and cat == "UNK":
            continue
        # Optional size filters: single-token vs multi-token edits.
        if args.single and (end - start >= 2 or len(cor.split()) >= 2):
            continue
        if args.multi and end - start < 2 and len(cor.split()) < 2:
            continue
        # Optional category blacklist.
        if args.filt and cat in args.filt:
            continue

        if args.dt:
            # Token-based detection: expand each edit to per-token spans.
            if start == -1:
                # noop edits keep their (-1, -1) key.
                coder_dict[coder].setdefault((start, start), []).append(cat)
            elif start == end and start >= 0:
                # Zero-width edit (insertion): attach to the next token.
                coder_dict[coder].setdefault((start, start + 1), []).append(cat)
            else:
                for tok_id in range(start, end):
                    coder_dict[coder].setdefault(
                        (tok_id, tok_id + 1), []).append(cat)
        elif args.ds:
            # Span-based detection: the span alone identifies the edit.
            coder_dict[coder].setdefault((start, end), []).append(cat)
        elif args.cse:
            # Correction + classification: category is part of the identity.
            coder_dict[coder].setdefault((start, end, cat, cor), []).append(cat)
        else:
            # Span-based correction (default).
            coder_dict[coder].setdefault((start, end, cor), []).append(cat)
    return coder_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
    """Pick, for one sentence, the hypothesis/reference annotator pairing
    that maximises the *running global* F-score and return its counts.

    Args:
        src: First line of the sentence's M2 block (the "S ..." line).
        hyp_dict: {coder_id: {edit_key: [categories]}} for the hypothesis.
        ref_dict: {coder_id: {edit_key: [categories]}} for the reference.
        best: Running global counts with keys "tp", "fp" and "fn".
        sent_id: Sentence index (used only in verbose output).
        args: Parsed options (uses args.beta and args.verbose).

    Returns:
        (best_dict, best_cat): best_dict is {"tp", "fp", "fn"} for the
        chosen pairing; best_cat maps category -> [tp, fp, fn].
    """
    # best_f starts at -1 so the first candidate pairing always wins.
    best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
    best_cat = {}

    # Special case: a single reference annotator whose single edit is
    # categorised "NA" - return zero counts without comparing anything.
    if len(ref_dict.keys()) == 1:
        ref_id = list(ref_dict.keys())[0]
        if len(ref_dict[ref_id].keys()) == 1:
            cat = list(ref_dict[ref_id].values())[0][0]
            if cat == "NA":
                best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
                return best_dict, best_cat

    # Try every hypothesis annotator against every reference annotator.
    for hyp_id in hyp_dict.keys():
        for ref_id in ref_dict.keys():
            # Counts for this pairing only.
            tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
            # Sentence-local P/R/F (only reported in verbose mode).
            loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
            # Global P/R/F as if this pairing were added to the running
            # totals; the pairing is selected on this global F.
            p, r, f = computeFScore(
                tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
            # Keep the pairing with the highest global F; ties broken by
            # higher TP, then lower FP, then lower FN.
            if (f > best_f) or \
                    (f == best_f and tp > best_tp) or \
                    (f == best_f and tp == best_tp and fp < best_fp) or \
                    (f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
                best_tp, best_fp, best_fn = tp, fp, fn
                best_f, best_hyp, best_ref = f, hyp_id, ref_id
                best_cat = cat_dict

            if args.verbose:
                # Sorted edit keys for display; lists whose first key starts
                # at -1 (noop-only) are shown as empty.
                hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
                ref_verb = list(sorted(ref_dict[ref_id].keys()))
                if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
                if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
                print('{:-^40}'.format(""))
                # src starts with "S", so src[1:] is the raw sentence text.
                print("SENTENCE "+str(sent_id)+src[1:])
                print('{:-^40}'.format(""))
                print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
                print("HYPOTHESIS EDITS :", hyp_verb)
                print("REFERENCE EDITS :", ref_verb)
                print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
                print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
                print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
                print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))

    if args.verbose:
        print('{:-^40}'.format(""))
        print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))

    best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
    return best_dict, best_cat
|
|
|
|
|
|
|
|
|
|
|
def compareEdits(hyp_edits, ref_edits):
    """Compare one hypothesis edit dict against one reference edit dict.

    "noop" entries (the no-edit placeholder) are ignored on both sides.
    A hypothesis edit whose key appears in the reference contributes one
    TP per reference category for that key; otherwise one FP per of its
    own categories. Reference edits absent from the hypothesis contribute
    one FN per category.

    Returns:
        (tp, fp, fn, cat_dict) with cat_dict mapping
        category -> [tp, fp, fn].
    """
    tp = fp = fn = 0
    cat_dict = {}

    def bump(category, slot):
        # Increment one of the [tp, fp, fn] slots for `category`.
        cat_dict.setdefault(category, [0, 0, 0])[slot] += 1

    # Hypothesis side: matches are TPs, the rest are FPs.
    for edit, cats in hyp_edits.items():
        if cats[0] == "noop":
            continue
        if edit in ref_edits:
            for category in ref_edits[edit]:
                tp += 1
                bump(category, 0)
        else:
            for category in cats:
                fp += 1
                bump(category, 1)

    # Reference side: anything the hypothesis missed is an FN.
    for edit, cats in ref_edits.items():
        if cats[0] == "noop":
            continue
        if edit not in hyp_edits:
            for category in cats:
                fn += 1
                bump(category, 2)

    return tp, fp, fn, cat_dict
|
|
|
|
|
|
|
|
|
def computeFScore(tp, fp, fn, beta):
    """Compute precision, recall and F_beta, each rounded to 4 decimals.

    By convention precision is 1.0 when there are no false positives and
    recall is 1.0 when there are no false negatives, even with zero TPs;
    F is 0.0 when precision + recall is zero.
    """
    precision = 1.0 if not fp else float(tp) / (tp + fp)
    recall = 1.0 if not fn else float(tp) / (tp + fn)
    beta_sq = beta ** 2
    if precision + recall:
        fscore = float((1 + beta_sq) * precision * recall) / (beta_sq * precision + recall)
    else:
        fscore = 0.0
    return round(precision, 4), round(recall, 4), round(fscore, 4)
|
|
|
|
|
|
|
def merge_dict(dict1, dict2):
    """Merge category counts from dict2 into dict1 in place; return dict1.

    Values are [tp, fp, fn] lists; overlapping categories are summed
    element-wise. New categories are copied rather than aliased - the
    original assigned dict2's list object directly, so any later in-place
    update through dict1 would silently mutate dict2 as well.
    """
    for cat, stats in dict2.items():
        if cat in dict1:
            dict1[cat] = [x + y for x, y in zip(dict1[cat], stats)]
        else:
            dict1[cat] = list(stats)  # defensive copy, no shared list
    return dict1
|
|
|
|
|
|
|
|
|
|
|
def processCategories(cat_dict, setting):
    """Re-bucket per-category [tp, fp, fn] counts by error-type tier.

    setting 1 keeps only the operation tier (first character, e.g. "R"),
    setting 2 keeps only the main tier (everything after "X:", e.g.
    "NOUN"); any other setting returns cat_dict unchanged. "UNK" is
    always carried through as-is.
    """
    merged = {}
    for cat, cnt in cat_dict.items():
        # UNK has no tier structure; keep it untouched.
        if cat == "UNK":
            merged[cat] = cnt
            continue

        if setting == 1:
            key = cat[0]
        elif setting == 2:
            key = cat[2:]
        else:
            # Setting 3 (or anything unrecognised): no re-bucketing at all.
            return cat_dict

        if key in merged:
            merged[key] = [x + y for x, y in zip(merged[key], cnt)]
        else:
            merged[key] = cnt
    return merged
|
|
|
|
|
|
|
|
|
def print_results(best, best_cats, args):
    """Print the final global (and optional per-category) score tables.

    Args:
        best: dict with overall "tp", "fp" and "fn" totals.
        best_cats: {category: [tp, fp, fn]} accumulated over all sentences.
        args: Parsed options (evaluation-mode flags, beta, cat).
    """
    # Choose the table title from the evaluation mode.
    if args.dt:
        title = " Token-Based Detection "
    elif args.ds:
        title = " Span-Based Detection "
    elif args.cse:
        title = " Span-Based Correction + Classification "
    else:
        title = " Span-Based Correction "

    # Optional per-category table, re-bucketed by the requested tier.
    if args.cat:
        best_cats = processCategories(best_cats, args.cat)
        print("")
        print('{:=^66}'.format(title))
        print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
              "P".ljust(8), "R".ljust(8), "F"+str(args.beta))
        for cat, cnts in sorted(best_cats.items()):
            cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
            print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
                  str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)

    # Overall results table.
    overall = computeFScore(best["tp"], best["fp"], best["fn"], args.beta)
    print("")
    print('{:=^46}'.format(title))
    print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
    print("\t".join(map(str, [best["tp"], best["fp"], best["fn"]] + list(overall))))
    print('{:=^46}'.format(""))
    print("")
|
|
|
if __name__ == "__main__":
    # Entry point when executed as a script:
    #   python <this file> -hyp HYP.m2 -ref REF.m2 [options]
    main()
|
|