|
import json |
|
import os |
|
import copy |
|
from collections import defaultdict |
|
from argparse import ArgumentParser |
|
from tqdm import tqdm |
|
import random |
|
from tqdm import tqdm |
|
from scripts.predict_concrete import read_kairos |
|
|
|
from sftp import SpanPredictor |
|
|
|
|
|
# Command-line interface: three required path arguments plus decoding options.
parser = ArgumentParser()
for positional in ('aida', 'model', 'dst'):
    parser.add_argument(positional, type=str)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()
|
|
|
# Decoding setup: load the AIDA corpus, load the span-typing model, and grab
# the vocabulary that maps predicted label indices back to frame names.
k = args.topk
# Use a context manager so the corpus file handle is closed (the original
# `json.load(open(...))` leaked it).
with open(args.aida) as corpus_file:
    corpus = json.load(corpus_file)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
idx2fn = predictor._model.vocab.get_index_to_token_vocabulary('span_label')
# Fixed seed so the shuffled processing order is reproducible across runs.
random.seed(42)
random.shuffle(corpus)
|
|
|
|
|
# Stream predictions to the destination file as JSON lines. Append mode ('a')
# preserves output from earlier (possibly interrupted) runs — kept as-is.
# The `with` block fixes the original's never-closed file handle, guaranteeing
# the buffer is flushed and the descriptor released even if decoding raises.
with open(args.dst, 'a') as output_fp:
    for line in tqdm(corpus):
        tokens, ann = line['tokens'], line['annotation']
        start, end, kairos_label = ann['start_idx'], ann['end_idx'], ann['label']
        # Force-decode the annotated span to get a probability distribution
        # over the model's span labels (presumably a 1-D tensor — TODO confirm).
        prob_dist = predictor.force_decode(tokens, [(start, end)])[0]
        # Top-k label indices by descending probability.
        topk_indices = prob_dist.argsort(descending=True)[:k]
        prob = prob_dist[topk_indices].tolist()
        frames = [(idx2fn[int(idx)], p) for idx, p in zip(topk_indices, prob)]
        output_fp.write(json.dumps({
            'tokens': tokens,
            'frames': frames,
            'kairos': kairos_label
        }) + '\n')
|
|