import json
import os
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
from allennlp.data import Vocabulary
from tqdm import tqdm

from sftp import SpanPredictor, Span
from sftp.utils import VIRTUAL_ROOT


def read_framenet(path: str) -> List[Tuple[List[str], Span]]:
    """Load a JSON-lines file and return (tokens, annotation tree) pairs."""
    ret = list()
    with open(path) as fp:
        for line in map(json.loads, fp):
            ret.append((line['tokens'], Span.from_json(line['annotations'])))
    return ret
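
# Minimal usage sketch (the path is hypothetical). Each input line must be a
# JSON object with a 'tokens' field (a list of strings) and an 'annotations'
# field that Span.from_json can parse into an annotation tree:
#
#   sentences = read_framenet('data/framenet/train.jsonl')
#   tokens, annotation_root = sentences[0]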


def co_occur(
        predictor: SpanPredictor,
        sentences: List[Tuple[List[str], Span]],
        event_list: List[str],
        arg_list: List[str],
):
    """Count co-occurrences between target-ontology labels and the source
    model's predicted label distributions, obtained by force decoding."""
    idx2label = predictor.vocab.get_index_to_token_vocabulary('span_label')
    event_count = np.zeros([len(event_list), len(idx2label)], np.float64)
    arg_count = np.zeros([len(arg_list), len(idx2label)], np.float64)
    for sent, vr in tqdm(sentences):
        # Force-decode the gold event boundaries to obtain, for every event,
        # a distribution over the source model's span labels.
        _, _, event_dist = predictor.force_decode(sent, child_spans=[event.boundary for event in vr])
        for event, one_event_dist in zip(vr, event_dist):
            event_count[event_list.index(event.label)] += one_event_dist
            # Condition on the most probable source event label, then
            # force-decode this event's argument boundaries.
            parent_label = idx2label[int(one_event_dist.argmax())]
            arg_spans = [child.boundary for child in event]
            _, _, arg_dist = predictor.force_decode(
                sent, event.boundary, parent_label, arg_spans
            )
            for arg, dist in zip(event, arg_dist):
                arg_count[arg_list.index(arg.label)] += dist
    return event_count, arg_count
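
# The returned matrices are indexed [target label, source vocabulary label]:
# event_count[i, j] (and arg_count likewise) accumulates the probability mass
# the source model assigns to its j-th label across spans whose gold label is
# the i-th entry of event_list (resp. arg_list).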


def create_vocab(events: List[str], args: List[str]) -> Vocabulary:
    """Build a target-ontology vocabulary, registering the virtual root first."""
    vocab = Vocabulary()
    vocab.add_token_to_namespace(VIRTUAL_ROOT, 'span_label')
    for event in events:
        vocab.add_token_to_namespace(event, 'span_label')
    for arg in args:
        vocab.add_token_to_namespace(arg, 'span_label')
    return vocab
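
# For example (with hypothetical labels), create_vocab(['Arrest'], ['Suspect'])
# registers VIRTUAL_ROOT, 'Arrest', and 'Suspect' under the 'span_label'
# namespace; ontology_map below saves this vocabulary alongside the mapping.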


def count_data(annotations: Iterable[Span]):
    """Count how often each event label and each argument label occurs."""
    event_cnt, arg_cnt = defaultdict(int), defaultdict(int)
    for sent in annotations:
        for event in sent:
            event_cnt[event.label] += 1
            for arg in event:
                arg_cnt[arg.label] += 1
    return dict(event_cnt), dict(arg_cnt)
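
# Worked sketch on loaded data (the path and labels are hypothetical):
#
#   _, roots = zip(*read_framenet('data/target/train.jsonl'))
#   event_cnt, arg_cnt = count_data(roots)
#   # e.g. event_cnt == {'Arrest': 12, ...}; arg_cnt == {'Suspect': 9, ...}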


def gen_mapping(
        src_label: List[str], src_count: Dict[str, int],
        tgt_onto: List[str], tgt_label: List[str],
        cooccur_count: np.ndarray
):
    """
    :param src_label: Source label list, covering both events and args.
    :param src_count: Source label counts, for events or for args.
    :param tgt_onto: Target ontology label list, events only or args only.
    :param tgt_label: Target label list, covering the full target vocabulary.
    :param cooccur_count: Co-occurrence count table of shape
        [len(tgt_onto), len(src_label)], as produced by co_occur.
    :return: Mapping dict from each source label to a weight vector over tgt_label.
    """
    # One-hot matrix that scatters ontology rows into vocabulary columns.
    onto2label = np.zeros([len(tgt_onto), len(tgt_label)], dtype=np.float64)
    for onto_idx, onto_tag in enumerate(tgt_onto):
        onto2label[onto_idx, tgt_label.index(onto_tag)] = 1.0
    ret = dict()
    for src_tag, src_freq in src_count.items():
        if src_tag in src_label:
            src_idx = src_label.index(src_tag)
            # Normalize the source label's co-occurrence column by its corpus
            # frequency, then project ontology indices onto vocabulary indices.
            ret[src_tag] = list((cooccur_count[:, src_idx] / src_freq) @ onto2label)
    return ret
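
# Equivalently: for a source label s with frequency f_s, the weight assigned
# to target label t is cooccur_count[row(t), col(s)] / f_s when t appears in
# tgt_onto, and 0 otherwise; the onto2label product merely re-indexes that
# vector from ontology order into full-vocabulary order.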


def ontology_map(
        model_path: str,
        src_data: List[Tuple[List[str], Span]],
        tgt_data: List[Tuple[List[str], Span]],
        device: int,
        dst_path: str,
        meta: Optional[dict] = None,
) -> None:
    ret = {'meta': meta or {}}
    data = {'src': {}, 'tgt': {}}
    # Gather sentences, annotation trees, and label statistics for both sides.
    for name, datasets in [('src', src_data), ('tgt', tgt_data)]:
        d = data[name]
        d['sentences'], d['annotations'] = zip(*datasets)
        d['event_cnt'], d['arg_cnt'] = count_data(d['annotations'])
        d['event'], d['arg'] = list(d['event_cnt']), list(d['arg_cnt'])

    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    tgt_vocab = create_vocab(data['tgt']['event'], data['tgt']['arg'])
    for name, vocab in [('src', predictor.vocab), ('tgt', tgt_vocab)]:
        data[name]['label'] = [
            vocab.get_index_to_token_vocabulary('span_label')[i]
            for i in range(vocab.get_vocab_size('span_label'))
        ]

    # Co-occurrence tables between target labels and source-model predictions.
    data['event'], data['arg'] = co_occur(
        predictor, tgt_data, data['tgt']['event'], data['tgt']['arg']
    )
    mapping = {}
    for layer in ['event', 'arg']:
        mapping[layer] = gen_mapping(
            data['src']['label'], data['src'][layer + '_cnt'],
            data['tgt'][layer], data['tgt']['label'], data[layer]
        )

    for key, name in [('source', 'src'), ('target', 'tgt')]:
        ret[key] = {
            'label': data[name]['label'],
            'event': data[name]['event'],
            'argument': data[name]['arg']
        }
    ret['mapping'] = {
        'event': mapping['event'],
        'argument': mapping['arg']
    }

    # Serialize the mapping, a TSV view of the target ontology (one header row
    # of events, then one row of arguments per event), and the target vocabulary.
    os.makedirs(dst_path, exist_ok=True)
    with open(os.path.join(dst_path, 'ontology_mapping.json'), 'w') as fp:
        json.dump(ret, fp)
    with open(os.path.join(dst_path, 'ontology.tsv'), 'w') as fp:
        to_dump = list()
        to_dump.append('\t'.join([VIRTUAL_ROOT] + ret['target']['event']))
        for event in ret['target']['event']:
            to_dump.append('\t'.join([event] + ret['target']['argument']))
        fp.write('\n'.join(to_dump))
    tgt_vocab.save_to_files(os.path.join(dst_path, 'vocabulary'))
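
# End-to-end driver sketch (paths and device id below are hypothetical):
#
#   src = read_framenet('data/framenet/train.jsonl')
#   tgt = read_framenet('data/target/train.jsonl')
#   ontology_map('ckpt/model.tar.gz', src, tgt, device=0, dst_path='out/map')
#
# This writes out/map/ontology_mapping.json, out/map/ontology.tsv, and an
# AllenNLP vocabulary directory at out/map/vocabulary.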