import argparse
from collections import Counter
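# Both input files are expected in M2 format: each sentence block starts with
# an "S" source line followed by one "A" line per edit. A made-up example
# (field conventions follow standard M2; exact extra fields may vary):
#   S This are a sentence .
#   A 1 2|||R:VERB|||is|||REQUIRED|||-NONE-|||0
# Of the "A" fields, this script uses the token span, the error type, the
# correction string and the final annotator (coder) id.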
def main():
# Parse command line args
args = parse_args()
    # Open hypothesis and reference M2 files, split them into sentence blocks,
    # and optionally keep only the [start:end) slice of the corpus.
    with open(args.hyp) as hyp_file, open(args.ref) as ref_file:
        hyp_m2 = hyp_file.read().strip().split("\n\n")
        ref_m2 = ref_file.read().strip().split("\n\n")
    if args.start is not None and args.end is not None:
        hyp_m2 = hyp_m2[args.start:args.end]
        ref_m2 = ref_m2[args.start:args.end]
    # Make sure both files contain the same number of sentences.
    assert len(hyp_m2) == len(ref_m2), \
        "Sentence count mismatch: %d hyp vs %d ref" % (len(hyp_m2), len(ref_m2))
# Store global corpus level best counts here
best_dict = Counter({"tp":0, "fp":0, "fn":0})
best_cats = {}
# Process each sentence
sents = zip(hyp_m2, ref_m2)
for sent_id, sent in enumerate(sents):
# Simplify the edits into lists of lists
src = sent[0].split("\n")[0]
hyp_edits = simplify_edits(sent[0], args.max_answer_num)
ref_edits = simplify_edits(sent[1], args.max_answer_num)
# Process the edits for detection/correction based on args
hyp_dict = process_edits(hyp_edits, args)
ref_dict = process_edits(ref_edits, args)
        if args.reference_num is None or len(ref_dict) == args.reference_num:
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
count_dict, cat_dict = evaluate_edits(src,
hyp_dict, ref_dict, best_dict, sent_id, args)
# Merge these dicts with best_dict and best_cats
best_dict += Counter(count_dict)
best_cats = merge_dict(best_cats, cat_dict)
# Print results
print_results(best_dict, best_cats, args)
# Parse command line args
def parse_args():
parser = argparse.ArgumentParser(
description="Calculate F-scores for error detection and/or correction.\n"
"Flags let you evaluate at different levels of granularity.",
formatter_class=argparse.RawTextHelpFormatter,
usage="%(prog)s [options] -hyp HYP -ref REF")
parser.add_argument(
"-hyp",
help="A hypothesis M2 file.",
required=True)
parser.add_argument(
"-ref",
help="A reference M2 file.",
required=True)
    parser.add_argument(
        "--start",
        help="Index of the first sentence block to evaluate (0-based; used with --end).",
        type=int,
        default=None)
    parser.add_argument(
        "--end",
        help="Index one past the last sentence block to evaluate (used with --start).",
        type=int,
        default=None)
    parser.add_argument(
        "--max_answer_num",
        help="Only keep edits from the first N annotators (coder ids 0..N-1).",
        type=int,
        default=None)
    parser.add_argument(
        "--reference_num",
        help="Only evaluate sentences that have exactly this many reference annotators.",
        type=int,
        default=None)
parser.add_argument(
"-b",
"--beta",
help="Value of beta in F-score. (default: 0.5)",
default=0.5,
type=float)
parser.add_argument(
"-v",
"--verbose",
help="Print verbose output.",
action="store_true")
eval_type = parser.add_mutually_exclusive_group()
eval_type.add_argument(
"-dt",
help="Evaluate Detection in terms of Tokens.",
action="store_true")
eval_type.add_argument(
"-ds",
help="Evaluate Detection in terms of Spans.",
action="store_true")
eval_type.add_argument(
"-cs",
help="Evaluate Correction in terms of Spans. (default)",
action="store_true")
eval_type.add_argument(
"-cse",
help="Evaluate Correction in terms of Spans and Error types.",
action="store_true")
parser.add_argument(
"-single",
help="Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1",
action="store_true")
parser.add_argument(
"-multi",
help="Only evaluate multi token edits; i.e. 2+:n or n:2+",
action="store_true")
    parser.add_argument(
        "-multi_hyp_avg",
        help="When a sentence has multiple hypotheses, use the average of their F-scores for this sentence.",
        action="store_true")  # For IAA calculation
    parser.add_argument(
        "-multi_hyp_max",
        help="When a sentence has multiple hypotheses, score each one and keep the maximum for this sentence.",
        action="store_true")  # For multiple-hypothesis system evaluation
parser.add_argument(
"-filt",
help="Do not evaluate the specified error types.",
nargs="+",
default=[])
parser.add_argument(
"-cat",
help="Show error category scores.\n"
"1: Only show operation tier scores; e.g. R.\n"
"2: Only show main tier scores; e.g. NOUN.\n"
"3: Show all category scores; e.g. R:NOUN.",
choices=[1, 2, 3],
type=int)
args = parser.parse_args()
return args
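# A typical invocation (hypothetical script and file names), running the
# default span-based correction evaluation with verbose per-sentence output:
#   python compare_m2.py -hyp system_output.m2 -ref gold.m2 -v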
# Input: An m2 format sentence with edits.
# Output: A list of lists. Each edit: [start, end, cat, cor, coder]
def simplify_edits(sent, max_answer_num):
out_edits = []
# Get the edit lines from an m2 block.
edits = sent.split("\n")
# Loop through the edits
for edit in edits:
# Preprocessing
if edit.startswith("A "):
edit = edit[2:].split("|||") # Ignore "A " then split.
span = edit[0].split()
start = int(span[0])
end = int(span[1])
cat = edit[1]
            cor = edit[2].replace(" ", "")  # Remove spaces from the correction string
coder = int(edit[-1])
out_edit = [start, end, cat, cor, coder]
out_edits.append(out_edit)
    if max_answer_num is None:
        return out_edits
    # Otherwise keep only edits from the first max_answer_num annotators
    # (coder ids 0 .. max_answer_num-1).
    return [edit for edit in out_edits if edit[-1] < max_answer_num]
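# For example (hypothetical edit line), simplify_edits applied to the block
#   'S This are a sentence .\nA 1 2|||R:VERB|||is|||REQUIRED|||-NONE-|||0'
# returns [[1, 2, "R:VERB", "is", 0]].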
# Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
# Input 2: Command line args
# Output: A dict; key is coder, value is edit dict.
def process_edits(edits, args):
coder_dict = {}
# Add an explicit noop edit if there are no edits.
if not edits: edits = [[-1, -1, "noop", "-NONE-", 0]]
# Loop through the edits
for edit in edits:
# Name the edit elements for clarity
start = edit[0]
end = edit[1]
cat = edit[2]
cor = edit[3]
coder = edit[4]
# Add the coder to the coder_dict if necessary
if coder not in coder_dict: coder_dict[coder] = {}
# Optionally apply filters based on args
# 1. UNK type edits are only useful for detection, not correction.
if not args.dt and not args.ds and cat == "UNK": continue
# 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
if args.single and (end-start >= 2 or len(cor.split()) >= 2): continue
# 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
if args.multi and end-start < 2 and len(cor.split()) < 2: continue
# 4. If there is a filter, ignore the specified error types
if args.filt and cat in args.filt: continue
        # Token Based Detection
        if args.dt:
            # Preserve noop edits.
            if start == -1:
                coder_dict[coder].setdefault((start, start), []).append(cat)
            # Insertions are defined as affecting the token on the right.
            elif start == end and start >= 0:
                coder_dict[coder].setdefault((start, start+1), []).append(cat)
            # Other edit spans are split into one unit per token in the range.
            else:
                for tok_id in range(start, end):
                    coder_dict[coder].setdefault((tok_id, tok_id+1), []).append(cat)
        # Span Based Detection
        elif args.ds:
            coder_dict[coder].setdefault((start, end), []).append(cat)
        # Span Based Correction
        else:
            # With error type classification
            if args.cse:
                coder_dict[coder].setdefault((start, end, cat, cor), []).append(cat)
            # Without error type classification
            else:
                coder_dict[coder].setdefault((start, end, cor), []).append(cat)
return coder_dict
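# For the default span-based correction setting, an edit [1, 2, "R:VERB", "is", 0]
# is stored as {0: {(1, 2, "is"): ["R:VERB"]}}: one dict per coder, keyed by
# (start, end, correction) with the list of error types as the value.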
# Input 1: The source sentence string (the M2 "S" line; for verbose output only).
# Input 2: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 3: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 4: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 5: Sentence ID (for verbose output only)
# Input 6: Command line args
# Output 1: A dict of the best TP, FP and FN counts for the input sentence.
# Output 2: The corresponding error type dict for the above counts.
def evaluate_edits(src, hyp_dict, ref_dict, best, sent_id, args):
    # Store the best sentence level scores and hyp+ref combination IDs.
    # best_f is initialised to -1 because 0 is a valid F-score.
    best_tp, best_fp, best_fn, best_f, best_hyp, best_ref = 0, 0, 0, -1, 0, 0
best_cat = {}
    # Skip sentences the annotators marked as not annotatable ("NA").
    if len(ref_dict) == 1:
        ref_id = list(ref_dict.keys())[0]
        if len(ref_dict[ref_id]) == 1:
            cat = list(ref_dict[ref_id].values())[0][0]
            if cat == "NA":
                best_dict = {"tp": best_tp, "fp": best_fp, "fn": best_fn}
                return best_dict, best_cat
# Compare each hyp and ref combination
    for hyp_id in hyp_dict:
        for ref_id in ref_dict:
# Get the local counts for the current combination.
tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
# Compute the local sentence scores (for verbose output only)
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
# Compute the global sentence scores
p, r, f = computeFScore(
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
# Save the scores if they are better in terms of:
# 1. Higher F-score
# 2. Same F-score, higher TP
# 3. Same F-score and TP, lower FP
# 4. Same F-score, TP and FP, lower FN
if (f > best_f) or \
(f == best_f and tp > best_tp) or \
(f == best_f and tp == best_tp and fp < best_fp) or \
(f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
best_tp, best_fp, best_fn = tp, fp, fn
best_f, best_hyp, best_ref = f, hyp_id, ref_id
best_cat = cat_dict
# Verbose output
if args.verbose:
# Prepare verbose output edits.
hyp_verb = list(sorted(hyp_dict[hyp_id].keys()))
ref_verb = list(sorted(ref_dict[ref_id].keys()))
# Ignore noop edits
if not hyp_verb or hyp_verb[0][0] == -1: hyp_verb = []
if not ref_verb or ref_verb[0][0] == -1: ref_verb = []
# Print verbose info
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+src[1:])
print('{:-^40}'.format(""))
print("SENTENCE "+str(sent_id)+" - HYP "+str(hyp_id)+" - REF "+str(ref_id))
print("HYPOTHESIS EDITS :", hyp_verb)
print("REFERENCE EDITS :", ref_verb)
print("Local TP/FP/FN :", str(tp), str(fp), str(fn))
print("Local P/R/F"+str(args.beta)+" :", str(loc_p), str(loc_r), str(loc_f))
print("Global TP/FP/FN :", str(tp+best["tp"]), str(fp+best["fp"]), str(fn+best["fn"]))
print("Global P/R/F"+str(args.beta)+" :", str(p), str(r), str(f))
# Verbose output: display the best hyp+ref combination
if args.verbose:
print('{:-^40}'.format(""))
print("^^ HYP "+str(best_hyp)+", REF "+str(best_ref)+" chosen for sentence "+str(sent_id))
# Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
best_dict = {"tp":best_tp, "fp":best_fp, "fn":best_fn}
return best_dict, best_cat
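# Note: the best hyp+ref combination is the one that maximises the F-score of
# the accumulated corpus-level counts (tp+best["tp"], etc.), not the local
# sentence score, so the per-sentence choice optimises the global result.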
# Input 1: A dictionary of hypothesis edits for a single system.
# Input 2: A dictionary of reference edits for a single annotator.
# Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
# Output 4: A dictionary of the error type counts.
def compareEdits(hyp_edits, ref_edits):
tp = 0 # True Positives
fp = 0 # False Positives
fn = 0 # False Negatives
cat_dict = {} # {cat: [tp, fp, fn], ...}
    for h_edit, h_cats in hyp_edits.items():
        # noop hyp edits cannot be TP or FP
        if h_cats[0] == "noop": continue
        # TRUE POSITIVES: the hyp edit exactly matches a ref edit.
        if h_edit in ref_edits:
            # There may be multiple error types at the same span.
            for h_cat in ref_edits[h_edit]:  # Use the ref categories for TP
                tp += 1
                # Each dict value is [TP, FP, FN]
                cat_dict.setdefault(h_cat, [0, 0, 0])[0] += 1
        # FALSE POSITIVES: the hyp edit matches no ref edit.
        else:
            for h_cat in h_cats:
                fp += 1
                cat_dict.setdefault(h_cat, [0, 0, 0])[1] += 1
    for r_edit, r_cats in ref_edits.items():
        # noop ref edits cannot be FN
        if r_cats[0] == "noop": continue
        # FALSE NEGATIVES: the ref edit matches no hyp edit.
        if r_edit not in hyp_edits:
            for r_cat in r_cats:
                fn += 1
                cat_dict.setdefault(r_cat, [0, 0, 0])[2] += 1
return tp, fp, fn, cat_dict
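# A minimal worked example with span-based correction keys: for
#   hyp_edits = {(1, 2, "is"): ["R:VERB"]}
#   ref_edits = {(1, 2, "is"): ["R:VERB"], (3, 4, "cat"): ["R:NOUN"]}
# compareEdits returns tp=1, fp=0, fn=1 and
# cat_dict == {"R:VERB": [1, 0, 0], "R:NOUN": [0, 0, 1]}.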
# Input 1-3: True positives, false positives, false negatives
# Input 4: Value of beta in F-score.
# Output 1-3: Precision, Recall and F-score rounded to 4dp.
def computeFScore(tp, fp, fn, beta):
p = float(tp)/(tp+fp) if fp else 1.0
r = float(tp)/(tp+fn) if fn else 1.0
f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
return round(p, 4), round(r, 4), round(f, 4)
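# This is the standard weighted F-score F_beta = (1+beta^2)*P*R / (beta^2*P + R).
# Example: computeFScore(2, 1, 1, 0.5) returns (0.6667, 0.6667, 0.6667); when
# precision equals recall, every F-beta equals both of them.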
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
    # Note: dict1 is updated in place and also returned.
    for cat, stats in dict2.items():
        if cat in dict1:
            dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
        else:
            dict1[cat] = stats
    return dict1
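# Example: merge_dict({"R:VERB": [1, 0, 0]}, {"R:VERB": [0, 1, 0], "M:DET": [1, 0, 0]})
# returns {"R:VERB": [1, 1, 0], "M:DET": [1, 0, 0]}.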
# Input 1: A dict; key is error cat, value is counts for [tp, fp, fn]
# Input 2: Integer value denoting level of error category granularity.
# 1: Operation tier; e.g. M, R, U. 2: Main tier; e.g. NOUN, VERB 3: Everything.
# Output: A dictionary of category TP, FP and FN based on Input 2.
def processCategories(cat_dict, setting):
    # Setting 3 shows every category combination, so nothing to merge.
    if setting == 3:
        return cat_dict
    # Otherwise, merge the counts at the requested tier.
    proc_cat_dict = {}
    for cat, cnt in cat_dict.items():
        # UNK has no tier structure; always keep it as-is.
        if cat == "UNK":
            proc_cat_dict[cat] = cnt
            continue
        # Operation tier only; e.g. "R:NOUN" -> "R".
        if setting == 1:
            key = cat[0]
        # Main tier only; e.g. "R:NOUN" -> "NOUN".
        else:
            key = cat[2:]
        if key in proc_cat_dict:
            proc_cat_dict[key] = [x+y for x, y in zip(proc_cat_dict[key], cnt)]
        else:
            proc_cat_dict[key] = cnt
    return proc_cat_dict
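# Example: {"R:VERB": [1, 0, 0], "R:NOUN": [0, 1, 0]} becomes {"R": [1, 1, 0]}
# under setting 1 and {"VERB": [1, 0, 0], "NOUN": [0, 1, 0]} under setting 2.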
# Input 1: A dict of global best TP, FP and FNs
# Input 2: A dict of error types and counts for those TP, FP and FNs
# Input 3: Command line args
def print_results(best, best_cats, args):
# Prepare output title.
if args.dt: title = " Token-Based Detection "
elif args.ds: title = " Span-Based Detection "
elif args.cse: title = " Span-Based Correction + Classification "
else: title = " Span-Based Correction "
# Category Scores
if args.cat:
best_cats = processCategories(best_cats, args.cat)
print("")
print('{:=^66}'.format(title))
print("Category".ljust(14), "TP".ljust(8), "FP".ljust(8), "FN".ljust(8),
"P".ljust(8), "R".ljust(8), "F"+str(args.beta))
for cat, cnts in sorted(best_cats.items()):
cat_p, cat_r, cat_f = computeFScore(cnts[0], cnts[1], cnts[2], args.beta)
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)
# Print the overall results.
print("")
print('{:=^46}'.format(title))
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
print("\t".join(map(str, [best["tp"], best["fp"],
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
print('{:=^46}'.format(""))
print("")
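# Illustrative overall output for the default setting (made-up numbers, layout
# approximate):
#   =================== Span-Based Correction ===================
#   TP      FP      FN      Prec    Rec     F0.5
#   2       1       1       0.6667  0.6667  0.6667
#   ==============================================================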
if __name__ == "__main__":
# Run the program
main()