import json
from argparse import ArgumentParser
from collections import defaultdict

import numpy as np
from tqdm import tqdm
from nltk.corpus import framenet as fn

from sftp import SpanPredictor


def run(model_path, data_path, device, use_ontology=False):
    # Load the evaluation data: one JSON object per line.
    data = list(map(json.loads, open(data_path).readlines()))

    # Map each FrameNet lexical unit name to the frames it can evoke.
    lu2frame = defaultdict(list)
    for lu in fn.lus():
        lu2frame[lu.name].append(lu.frame.name)

    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    # Mapping from frame name to its index in the model's span-label vocabulary.
    frame2idx = predictor._model.vocab.get_token_to_index_vocabulary('span_label')
    all_frames = [fr.name for fr in fn.frames()]

    n_positive = n_total = 0
    with tqdm(total=len(data)) as bar:
        for sent in data:
            bar.update()
            for point in sent['annotations']:
                # Force-decode the annotated target span and take its label distribution.
                model_output = predictor.force_decode(
                    sent['tokens'], child_spans=[(point['span'][0], point['span'][-1])]
                ).distribution[0]
                # With the ontology, restrict candidates to the frames the LU can evoke;
                # otherwise consider every FrameNet frame.
                if use_ontology:
                    candidate_frames = lu2frame[point['lu']]
                else:
                    candidate_frames = all_frames
                candidate_prob = [-1.0 for _ in candidate_frames]
                for idx_can, fr in enumerate(candidate_frames):
                    if fr in frame2idx:
                        candidate_prob[idx_can] = model_output[frame2idx[fr]]
                # Predict the highest-scoring candidate and compare against the gold frame.
                if len(candidate_prob) > 0:
                    pred_frame = candidate_frames[int(np.argmax(candidate_prob))]
                    if pred_frame == point['label']:
                        n_positive += 1
                n_total += 1
                bar.set_description(f'acc={n_positive/n_total*100:.3f}')
    print(f'acc={n_positive/n_total*100:.3f}')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('model', metavar="MODEL")
    parser.add_argument('data', metavar="DATA")
    parser.add_argument('-d', default=-1, type=int, help='Device')
    parser.add_argument('-o', action='store_true', help='Flag to use ontology.')
    args = parser.parse_args()
    run(args.model, args.data, args.d, args.o)
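

# Usage sketch (illustrative, not part of the original script; the file name
# frame_id_eval.py and the example values below are assumptions):
#
#   python frame_id_eval.py /path/to/model.tar.gz /path/to/data.jsonl -d 0 -o
#
# Based on the fields read above, each line of DATA is expected to be a JSON
# object roughly of the form:
#
#   {"tokens": ["He", "bought", "a", "car"],
#    "annotations": [{"span": [1, 1], "lu": "buy.v", "label": "Commerce_buy"}]}
#
# where "span" holds the start and end token indices of the target, "lu" is the
# FrameNet lexical unit name (used only with -o), and "label" is the gold frame.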