|
import json |
|
import os |
|
import copy |
|
from collections import defaultdict |
|
from argparse import ArgumentParser |
|
from tqdm import tqdm |
|
import random |
|
from tqdm import tqdm |
|
from scripts.predict_concrete import read_kairos |
|
|
|
from sftp import SpanPredictor |
|
|
|
|
|
# Command-line setup: evaluate a SpanPredictor checkpoint on an AIDA corpus,
# optionally mapping FrameNet frame labels onto KAIROS event types.
parser = ArgumentParser(description='Evaluate span predictions on an AIDA corpus.')
parser.add_argument('aida', type=str, help='path to the AIDA corpus JSON file')
parser.add_argument('model', type=str, help='path to the serialized SpanPredictor model')
# BUG FIX: a bare positional argument is always required, so the original
# `default=None` was unreachable dead code. `nargs='?'` makes the mapping file
# genuinely optional while remaining backward compatible for callers that pass it.
parser.add_argument('fn2kairos', type=str, nargs='?', default=None,
                    help='optional FrameNet-to-KAIROS label mapping file')
parser.add_argument('--device', type=int, default=3, help='CUDA device index')
args = parser.parse_args()

corpus = json.load(open(args.aida))
mapping = read_kairos(args.fn2kairos)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
# Fixed seed so the shuffle (and therefore the batch composition) is reproducible.
random.seed(42)
random.shuffle(corpus)
batch_size = 128
|
|
|
|
|
def batchify(a_list, size=None):
    """Yield successive chunks of *a_list*.

    Args:
        a_list: iterable of items to split into batches.
        size: number of items per chunk; defaults to the module-level
            ``batch_size`` (keeps the original call sites working unchanged).

    Yields:
        Lists of at most ``size`` items, in input order; the final chunk
        may be shorter than ``size``.
    """
    if size is None:
        # Fall back to the script-wide batch size for backward compatibility.
        size = batch_size
    cur = []
    for item in a_list:
        cur.append(item)
        if len(cur) == size:
            yield cur
            cur = []
    if cur:  # emit the trailing partial batch, if any
        yield cur
|
|
|
|
|
batches = list(batchify(corpus))

# Accumulate exact-match (span + label) and span-only hits over the corpus.
n_total = n_pos = n_span_match = 0
for lines in tqdm(batches, total=len(batches)):
    # BUG FIX: the final batch may hold fewer than `batch_size` instances, so
    # counting `batch_size` per batch inflated the denominator and deflated
    # both reported precisions. Count the actual batch length instead.
    n_total += len(lines)
    prediction_lines = predictor.predict_batch_sentences(
        [line['tokens'] for line in lines], max_tokens=1024, ontology_mapping=mapping
    )
    for preds, ann in zip(prediction_lines, lines):
        ann = ann['annotation']
        preds = preds['prediction']
        for pred in preds:
            # A span match needs identical boundaries; an exact match
            # additionally requires the predicted label to agree.
            if pred['start_idx'] == ann['start_idx'] and pred['end_idx'] == ann['end_idx']:
                n_span_match += 1
                if pred['label'] == ann['label']:
                    n_pos += 1

print(f'exact match precision: {n_pos * 100 / n_total:.3f}')
print(f'span only precision: {n_span_match * 100 / n_total:.3f}')
|
|