TwT-6's picture
Upload 2667 files
256a159 verified
import argparse
from collections import Counter
def main():
# Parse command line args
args = parse_args()
# Open hypothesis and reference m2 files and split into chunks
hyp_m2 = open(args.hyp).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.hyp).read().strip().split("\n\n")
ref_m2 = open(args.ref).read().strip().split("\n\n")[args.start:args.end] if args.start is not None and args.end is not None else open(args.ref).read().strip().split("\n\n")
# Make sure they have the same number of sentences
assert len(hyp_m2) == len(ref_m2), print(len(hyp_m2), len(ref_m2))
# Store global corpus level best counts here
best_dict = Counter({"tp":0, "fp":0, "fn":0})
best_cats = {}
# Process each sentence
sents = zip(hyp_m2, ref_m2)
for sent_id, sent in enumerate(sents):
# Simplify the edits into lists of lists
# if "A1" in sent[0] or "A1" in sent[1] or sent_id in sent_id_cons:
# sent_id_cons.append(sent_id)
src = sent[0].split("\n")[0]
hyp_edits = simplify_edits(sent[0], args.max_answer_num)
ref_edits = simplify_edits(sent[1], args.max_answer_num)
# Process the edits for detection/correction based on args
hyp_dict = process_edits(hyp_edits, args)
ref_dict = process_edits(ref_edits, args)
if args.reference_num is None or len(ref_dict.keys()) == args.reference_num:
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
count_dict, cat_dict = evaluate_edits(src,
hyp_dict, ref_dict, best_dict, sent_id, args)
# Merge these dicts with best_dict and best_cats
best_dict += Counter(count_dict)
best_cats = merge_dict(best_cats, cat_dict)
# Print results
print_results(best_dict, best_cats, args)
# Parse command line args
def parse_args():
parser = argparse.ArgumentParser(
description="Calculate F-scores for error detection and/or correction.\n"
"Flags let you evaluate at different levels of granularity.",
formatter_class=argparse.RawTextHelpFormatter,
usage="%(prog)s [options] -hyp HYP -ref REF")
parser.add_argument(
"-hyp",
help="A hypothesis M2 file.",
required=True)
parser.add_argument(
"-ref",
help="A reference M2 file.",
required=True)
parser.add_argument(
"--start",
type=int,
default=None
)
parser.add_argument(
"--end",
type=int,
default=None
)
parser.add_argument(
"--max_answer_num",
type=int,
default=None
)
parser.add_argument(
"--reference_num",
type=int,
default=None
)
parser.add_argument(
"-b",
"--beta",
help="Value of beta in F-score. (default: 0.5)",
default=0.5,
type=float)
parser.add_argument(
"-v",
"--verbose",
help="Print verbose output.",
action="store_true")
eval_type = parser.add_mutually_exclusive_group()
eval_type.add_argument(
"-dt",
help="Evaluate Detection in terms of Tokens.",
action="store_true")
eval_type.add_argument(
"-ds",
help="Evaluate Detection in terms of Spans.",
action="store_true")
eval_type.add_argument(
"-cs",
help="Evaluate Correction in terms of Spans. (default)",
action="store_true")
eval_type.add_argument(
"-cse",
help="Evaluate Correction in terms of Spans and Error types.",
action="store_true")
parser.add_argument(
"-single",
help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
action="store_true")
parser.add_argument(
"-multi",
help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
action="store_true")
parser.add_argument(
"-multi_hyp_avg",
help="When get multiple hypotheses for a sentence, calculate their average F-scores for this sentence.",
action="store_true") # For IAA calculation
parser.add_argument(
"-multi_hyp_max",
help="When get multiple hypotheses for a sentence, calculate their F-scores and select the max one for this sentence.",
action="store_true") # For multiple hypotheses system evaluation
parser.add_argument(
"-filt",
help="Do not evaluate the specified error types.",
nargs="+",
default=[])
parser.add_argument(
"-cat",
help="Show error category scores.\n"
"1: Only show operation tier scores; e.g. R.\n"
"2: Only show main tier scores; e.g. NOUN.\n"
"3: Show all category scores; e.g. R:NOUN.",
choices=[1, 2, 3],
type=int)
args = parser.parse_args()
return args
# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def simplify_edits(sent, max_answer_num):
out_edits = []
# Get the edit lines from an m2 block.
edits = sent.split("\n")
# Loop through the edits
for edit in edits:
# Preprocessing
if edit.startswith("A "):
edit = edit[2:].split("|||") # Ignore "A " then split.
span = edit[0].split()
start = int(span[0])
end = int(span[1])
cat = edit[1]
cor = edit[2].replace(" ", "")
coder = int(edit[-1])
out_edit = [start, end, cat, cor, coder]
out_edits.append(out_edit)
# return [edit for edit in out_edits if edit[-1] in [0,1]]
if max_answer_num is None:
return out_edits
elif max_answer_num == 1:
return [edit for edit in out_edits if edit[-1] == 0]
elif max_answer_num == 2:
return [edit for edit in out_edits if edit[-1] in [0, 1]]
elif max_answer_num == 3:
return [edit for edit in out_edits if edit[-1] in [0, 1, 2]]
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def process_edits(edits, args):
coder_dict = {}
# Add an explicit noop edit if there are no edits.
if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
# Loop through the edits
for edit in edits:
# Name the edit elements for clarity
start = edit[0]
end = edit[1]
cat = edit[2]
cor = edit[3]
coder = edit[4]
# Add the coder to the coder_dict if necessary
if coder not in coder_dict: coder_dict[coder] = {}
# Optionally apply filters based on args
# 1. UNK type edits are only useful for detection, not correction.
if not args.dt and not args.ds and cat == "UNK": continue
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
if args.multi and end-start < 2 and len(cor.split()) < 2: continue
# 4. If there is a filter, ignore the specified error types
if args.filt and cat in args.filt: continue
# Token Based Detection
if args.dt:
# Preserve noop edits.
if start == -1:
if (start, start) in coder_dict[coder].keys():
coder_dict[coder][(start, start)].append(cat)
else:
coder_dict[coder][(start, start)] = [cat]
# Insertions defined as affecting the token on the right
elif start == end and start >= 0:
if (start, start+1) in coder_dict[coder].keys():
coder_dict[coder][(start, start+1)].append(cat)
else:
coder_dict[coder][(start, start+1)] = [cat]
# Edit spans are split for each token in the range.
else:
for tok_id in range(start, end):
if (tok_id, tok_id+1) in coder_dict[coder].keys():
coder_dict[coder][(tok_id, tok_id+1)].append(cat)
else:
coder_dict[coder][(tok_id, tok_id+1)] = [cat]
# Span Based Detection
elif args.ds:
if (start, end) in coder_dict[coder].keys():
coder_dict[coder][(start, end)].append(cat)
else:
coder_dict[coder][(start, end)] = [cat]
# Span Based Correction
else:
# With error type classification
if args.cse:
if (start, end, cat, cor) in coder_dict[coder].keys():
coder_dict[coder][(start, end, cat, cor)].append(cat)
else:
coder_dict[coder][(start, end, cat, cor)] = [cat]
# Without error type classification
else:
if (start, end, cor) in coder_dict[coder].keys():
coder_dict[coder][(start, end, cor)].append(cat)
else:
coder_dict[coder][(start, end, cor)] = [cat]
return coder_dict
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
# Store the best sentence level scores and hyp+ref combination IDs
# best_f is initialised as -1 cause 0 is a valid result.
best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
best_cat = {}
# skip not annotatable sentence
if len(ref_dict.keys()) == 1:
ref_id = list(ref_dict.keys())[0]
if len(ref_dict[ref_id].keys()) == 1:
cat = list(ref_dict[ref_id].values())[0][0]
if cat == "NA":
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
# Compare each hyp and ref combination
for hyp_id in hyp_dict.keys():
for ref_id in ref_dict.keys():
# Get the local counts for the current combination.
tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
# Compute the local sentence scores (for verbose output only)
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
# Compute the global sentence scores
p, r, f = computeFScore(
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
# Save the scores if they are better in terms of:
# 1. Higher F-score
# 2. Same F-score, higher TP
# 3. Same F-score and TP, lower FP
# 4. Same F-score, TP and FP, lower FN
if (f > best_f) or \
(f == best_f and tp > best_tp) or \
(f == best_f and tp == best_tp and fp < best_fp) or \
(f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
best_tp, best_fp, best_fn = tp, fp, fn
best_f, best_hyp, best_ref = f, hyp_id, ref_id
best_cat = cat_dict
# Verbose output
if args.verbose:
# Prepare verbose output edits.
hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
ref_verb = list(sorted(ref_dict[ref_id].keys()))
# Ignore noop edits
if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
# Print verbose info
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+src[1:])
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
print("HYPOTHESIS EDITS :", hyp_verb)
print("REFERENCE EDITS :", ref_verb)
print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
# Verbose output: display the best hyp+ref combination
if args.verbose:
print('{:-^40}'.format(""))
print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def compareEdits(hyp_edits, ref_edits):
tp = 0 # True Positives
fp = 0 # False Positives
fn = 0 # False Negatives
cat_dict = {} # {cat: [tp, fp, fn], ...}
for h_edit, h_cats in hyp_edits.items():
# noop hyp edits cannot be TP or FP
if h_cats[0] == "noop": continue
# TRUE POSITIVES
if h_edit in ref_edits.keys():
# On occasion, multiple tokens at same span.
for h_cat in ref_edits[h_edit]: # Use ref dict for TP
tp += 1
# Each dict value [TP, FP, FN]
if h_cat in cat_dict.keys():
cat_dict[h_cat][0] += 1
else:
cat_dict[h_cat] = [1, 0, 0]
# FALSE POSITIVES
else:
# On occasion, multiple tokens at same span.
for h_cat in h_cats:
fp += 1
# Each dict value [TP, FP, FN]
if h_cat in cat_dict.keys():
cat_dict[h_cat][1] += 1
else:
cat_dict[h_cat] = [0, 1, 0]
for r_edit, r_cats in ref_edits.items():
# noop ref edits cannot be FN
if r_cats[0] == "noop": continue
# FALSE NEGATIVES
if r_edit not in hyp_edits.keys():
# On occasion, multiple tokens at same span.
for r_cat in r_cats:
fn += 1
# Each dict value [TP, FP, FN]
if r_cat in cat_dict.keys():
cat_dict[r_cat][2] += 1
else:
cat_dict[r_cat] = [0, 0, 1]
return tp, fp, fn, cat_dict
# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
p = float(tp)/(tp+fp) if fp else 1.0
r = float(tp)/(tp+fn) if fn else 1.0
f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
return round(p, 4), round(r, 4), round(f, 4)
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
for cat, stats in dict2.items():
if cat in dict1.keys():
dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
else:
dict1[cat] = stats
return dict1
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
# Otherwise, do some processing.
proc_cat_dict = {}
for cat, cnt in cat_dict.items():
if cat == "UNK":
proc_cat_dict[cat] = cnt
continue
# M, U, R or UNK combined only.
if setting == 1:
if cat[0] in proc_cat_dict.keys():
proc_cat_dict[cat[0]] = [x+y for x, y in zip(proc_cat_dict[cat[0]], cnt)]
else:
proc_cat_dict[cat[0]] = cnt
# Everything without M, U or R.
elif setting == 2:
if cat[2:] in proc_cat_dict.keys():
proc_cat_dict[cat[2:]] = [x+y for x, y in zip(proc_cat_dict[cat[2:]], cnt)]
else:
proc_cat_dict[cat[2:]] = cnt
# All error category combinations
else:
return cat_dict
return proc_cat_dict
# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def print_results(best, best_cats, args):
# Prepare output title.
if args.dt: title = " Token-Based Detection "
elif args.ds: title = " Span-Based Detection "
elif args.cse: title = " Span-Based Correction + Classification "
else: title = " Span-Based Correction "
# Category Scores
if args.cat:
best_cats = processCategories(best_cats, args.cat)
print("")
print('{:=^66}'.format(title))
print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
"P".ljust(8), "R".ljust(8), "F"+str(args.beta))
for cat, cnts in sorted(best_cats.items()):
cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
# Print the overall results.
print("")
print('{:=^46}'.format(title))
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
print("\t".join(map(str, [best["tp"], best["fp"],
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
print('{:=^46}'.format(""))
print("")
if __name__ == "__main__":
# Run the program
main()