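"""
Evaluate frame identification accuracy of an SFTP SpanPredictor on FrameNet-style data.

For each annotated target span, the model is forced to decode that span and the
highest-scoring candidate frame is compared against the gold label. With the -o
flag, candidate frames are restricted to those the target's lexical unit can evoke.
"""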
import json
from argparse import ArgumentParser
from collections import defaultdict
import numpy as np

from tqdm import tqdm
from nltk.corpus import framenet as fn

from sftp import SpanPredictor


def run(model_path, data_path, device, use_ontology=False):
    # Load the evaluation data: one JSON-serialized sentence per line.
    with open(data_path) as data_file:
        data = [json.loads(line) for line in data_file]
    # Map each lexical unit name to the frames it can evoke, per FrameNet.
    lu2frame = defaultdict(list)
    for lu in fn.lus():
        lu2frame[lu.name].append(lu.frame.name)
    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    # Index of each frame label in the model's span-label vocabulary.
    frame2idx = predictor._model.vocab.get_token_to_index_vocabulary('span_label')
    all_frames = [fr.name for fr in fn.frames()]
    n_positive = n_total = 0
    with tqdm(total=len(data)) as bar:
        for sent in data:
            bar.update()
            for point in sent['annotations']:
                # Force-decode the gold target span and take the model's
                # probability distribution over frame labels for that span.
                model_output = predictor.force_decode(
                    sent['tokens'], child_spans=[(point['span'][0], point['span'][-1])]
                ).distribution[0]
                # Candidate frames: either the frames this lexical unit can
                # evoke (ontology) or every frame in FrameNet.
                if use_ontology:
                    candidate_frames = lu2frame[point['lu']]
                else:
                    candidate_frames = all_frames
                # Score each candidate; frames unknown to the model keep -1.0.
                candidate_prob = [-1.0 for _ in candidate_frames]
                for idx_can, fr in enumerate(candidate_frames):
                    if fr in frame2idx:
                        candidate_prob[idx_can] = model_output[frame2idx[fr]]
                if len(candidate_prob) > 0:
                    pred_frame = candidate_frames[int(np.argmax(candidate_prob))]
                    if pred_frame == point['label']:
                        n_positive += 1
                n_total += 1
            if n_total > 0:  # guard against sentences seen before any annotation
                bar.set_description(f'acc={n_positive/n_total*100:.3f}')
    print(f'acc={n_positive/n_total*100:.3f}')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('model', metavar="MODEL")
    parser.add_argument('data', metavar="DATA")
    parser.add_argument('-d', default=-1, type=int, help='Device')
    parser.add_argument('-o', action='store_true', help='Flag to use ontology.')
    args = parser.parse_args()
    run(args.model, args.data, args.d, args.o)
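
# Example invocation (script name and paths are illustrative placeholders):
#   python evaluate_frame_id.py /path/to/model.tar.gz /path/to/test.jsonl -d 0 -o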