File size: 1,774 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import *
import torch
import json
import argparse
import os
from tqdm import tqdm

from sftp.predictor import SpanPredictor
from sftp.models import SpanModel
from sftp.data_reader import BetterDatasetReader


def predict_doc(predictor, json_path: str):
    """Run span prediction over every entry of a JSON document.

    Loads the JSON file at ``json_path``, runs ``predictor.predict_json`` on
    each entry under the top-level ``entries`` mapping, and writes the
    predicted trigger/argument spans back onto each entry in place under the
    key ``'trigger span'``.

    :param predictor: object exposing ``predict_json(entry)`` whose result is
        a dict with a ``'prediction'`` list of triggers, each trigger carrying
        ``start_idx``/``end_idx`` and a ``children`` list of argument spans.
        (Presumably a ``SpanPredictor`` — see the file's imports.)
    :param json_path: path to the source JSON file.
    :return: the top-level dict loaded from ``json_path``, mutated in place.
    """
    # Context manager fixes the original's leaked file handle
    # (`json.load(open(json_path))` never closed the descriptor).
    with open(json_path) as fp:
        src = json.load(fp)
    # Only the entry values are needed; the doc name key was unused.
    for entry in tqdm(list(src['entries'].values())):
        pred = predictor.predict_json(entry)
        triggers = []
        for trigger in pred['prediction']:
            # Each child of a trigger is one argument span: [start, end].
            children = [
                [child['start_idx'], child['end_idx']]
                for child in trigger['children']
            ]
            triggers.append({
                "span": [trigger['start_idx'], trigger['end_idx']],
                "argument": children,
            })
        entry['trigger span'] = triggers
    return src


if __name__ == '__main__':
    # CLI driver: load a trained span model and run it over every eligible
    # file under the source tree, writing augmented copies to
    # <destination>/<model name>/<original file name>.
    parser = argparse.ArgumentParser()
    # -a/-s/-d were de facto mandatory (omitting one crashed later with an
    # opaque TypeError inside os.path.join); declare them required so
    # argparse reports a clear usage error instead.
    parser.add_argument('-a', type=str, required=True, help='archive path')
    parser.add_argument('-s', type=str, required=True, help='source path')
    parser.add_argument('-d', type=str, required=True, help='destination path')
    parser.add_argument('-c', type=int, default=0, help='cuda device')
    args = parser.parse_args()
    predictor_ = SpanPredictor.from_path(os.path.join(args.a, 'model.tar.gz'), 'span', cuda_device=args.c)
    # NOTE(review): basename('') results from a trailing slash on -a;
    # pass the archive dir without a trailing slash — confirm with callers.
    model_name = os.path.basename(args.a)
    tgt_path = os.path.join(args.d, model_name)
    os.makedirs(tgt_path, exist_ok=True)
    for root, _, files in os.walk(args.s):
        for fn in files:
            # Process only files whose names end in 'json' or 'valid';
            # tuple form of endswith replaces the double-negative chain.
            if not fn.endswith(('json', 'valid')):
                continue
            processed_json = predict_doc(predictor_, os.path.join(root, fn))
            with open(os.path.join(tgt_path, fn), 'w') as fp:
                json.dump(processed_json, fp)