import json
import logging
import os
from collections import defaultdict, namedtuple
from typing import Dict, Iterable, List

from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance

from .span_reader import SpanReader
from ..utils import Span

logger = logging.getLogger(__name__)
SpanTuple = namedtuple('SpanTuple', ['start', 'end'])


@DatasetReader.register('better')
class BetterDatasetReader(SpanReader):
    def __init__(
        self,
        eval_type,
        consolidation_strategy='first',
        span_set_type='single',
        max_argument_ss_size=1,
        use_ref_events=False,
        **extra
    ):
        super().__init__(**extra)
        self.eval_type = eval_type
        assert self.eval_type in ['abstract', 'basic']

        self.consolidation_strategy = consolidation_strategy
        self.unitary_spans = span_set_type == 'single'

        self.max_arg_spans = max_argument_ss_size
        self.use_ref_events = use_ref_events

        # Diagnostic counters, logged and reset after each file is read.
        self.n_overlap_arg = 0
        self.n_overlap_trigger = 0
        self.n_skip = 0
        self.n_too_long = 0

    @staticmethod
    def post_process_basic_span(predicted_span, basic_entry):
        """Map a predicted token-level span back onto the original segment.

        Uses the entry's `tok2char` alignment to recover character offsets,
        then slices out both the raw string and its tokenized form.
        """
        start_idx = predicted_span['start_idx']
        end_idx = predicted_span['end_idx']

        char_start_idx = basic_entry['tok2char'][start_idx][0]
        char_end_idx = basic_entry['tok2char'][end_idx][-1] + 1

        span_text = basic_entry['segment-text'][char_start_idx:char_end_idx]
        span_text_tok = basic_entry['segment-text-tok'][start_idx:end_idx + 1]

        span = {'string': span_text,
                'start': char_start_idx,
                'end': char_end_idx,
                'start-token': start_idx,
                'end-token': end_idx,
                'string-tok': span_text_tok,
                'label': predicted_span['label'],
                'predicted': True}
        return span
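
    # A minimal usage sketch for `post_process_basic_span` on hypothetical
    # data (not from the BETTER corpus): `tok2char` maps each token index to
    # its character offsets in `segment-text`.
    #
    #   entry = {'tok2char': [[0, 1, 2], [4, 5, 6, 7]],
    #            'segment-text': 'The fire',
    #            'segment-text-tok': ['The', 'fire']}
    #   pred = {'start_idx': 1, 'end_idx': 1, 'label': 'anchor'}
    #   BetterDatasetReader.post_process_basic_span(pred, entry)['string']
    #   # -> 'fire'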

    @staticmethod
    def _get_shortest_span(spans):
        """Order spans by string length, shortest first (ties keep the
        original order), and return the reordered span dicts."""
        return [s[-1] for s in sorted([(len(span['string']), ix, span) for ix, span in enumerate(spans)])]

    @staticmethod
    def _get_first_span(spans):
        """Order spans by start offset, preferring the longer string on ties,
        and return the reordered span dicts."""
        keyed = [(span['start'], -len(span['string']), ix, span) for ix, span in enumerate(spans)]
        return [s[-1] for s in sorted(keyed)]

    @staticmethod
    def _get_longest_span(spans):
        """Order spans by string length, longest first, and return the
        reordered span dicts."""
        return [s[-1] for s in sorted([(len(span['string']), ix, span) for ix, span in enumerate(spans)], reverse=True)]
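
    # Sketch of how the three orderings differ on a toy span set (hypothetical
    # span dicts carrying only the fields these helpers read):
    #
    #   spans = [{'start': 7, 'string': 'war'},
    #            {'start': 0, 'string': 'the war'},
    #            {'start': 0, 'string': 'the'}]
    #   _get_first_span(spans)    # 'the war', 'the', 'war'
    #   _get_shortest_span(spans) # 'war', 'the', 'the war'
    #   _get_longest_span(spans)  # 'the war', 'the', 'war'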

    @staticmethod
    def _subfinder(text, pattern):
        """Find every token-level occurrence of `pattern` in `text`, returned
        as SpanTuples of inclusive token indices."""
        matches = []
        pattern_length = len(pattern)
        if pattern_length == 0:
            return matches
        for i, token in enumerate(text):
            if token == pattern[0] and text[i:i + pattern_length] == pattern:
                matches.append(SpanTuple(start=i, end=i + pattern_length - 1))
        return matches
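
    # Usage sketch: matches are inclusive token-index spans.
    #
    #   BetterDatasetReader._subfinder(['the', 'big', 'dog'], ['big', 'dog'])
    #   # -> [SpanTuple(start=1, end=2)]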

    def consolidate_span_set(self, spans):
        """Reorder a span set by the configured strategy, then truncate it to
        one span (unitary span sets) or to at most `max_arg_spans` spans."""
        if self.consolidation_strategy == 'first':
            spans = BetterDatasetReader._get_first_span(spans)
        elif self.consolidation_strategy == 'shortest':
            spans = BetterDatasetReader._get_shortest_span(spans)
        elif self.consolidation_strategy == 'longest':
            spans = BetterDatasetReader._get_longest_span(spans)
        else:
            raise NotImplementedError(f"Unknown consolidation strategy: {self.consolidation_strategy}")

        if self.unitary_spans:
            spans = [spans[0]]
        else:
            spans = spans[:self.max_arg_spans]

        return spans
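
    # Usage sketch (hypothetical reader configuration; `SpanReader` may
    # require further constructor arguments via `**extra`):
    #
    #   reader = BetterDatasetReader(eval_type='basic',
    #                                consolidation_strategy='longest',
    #                                span_set_type='multiple',
    #                                max_argument_ss_size=2)
    #   reader.consolidate_span_set(spans)  # -> the two longest span dicts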

    def get_mention_spans(self, text: List[str], span_sets: Dict):
        """Resolve each span set to token-level mention spans.

        Abstract annotations carry no token offsets, so the consolidated
        span's tokens are matched back against `text`; basic annotations
        provide gold token offsets, which are used directly.
        """
        mention_spans = defaultdict(list)
        for span_set_id in span_sets.keys():
            spans = span_sets[span_set_id]['spans']
            consolidated_spans = self.consolidate_span_set(spans)

            if self.eval_type == 'abstract':
                span = consolidated_spans[0]
                span_tokens = span['string-tok']

                span_indices = BetterDatasetReader._subfinder(text=text, pattern=span_tokens)

                # If the tokens never match, drop the span set; if they match
                # more than once, keep only the first occurrence.
                if len(span_indices) == 0:
                    continue

                mention_spans[span_set_id] = [span_indices[0]]
            else:
                for span in consolidated_spans:
                    mention_spans[span_set_id].append(SpanTuple(start=span['start-token'], end=span['end-token']))

        return mention_spans
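
    # Result shape sketch (hypothetical ids and offsets): each span-set id
    # maps to a list of inclusive token-index spans.
    #
    #   {'ss-1': [SpanTuple(start=3, end=5)],
    #    'ss-2': [SpanTuple(start=0, end=0), SpanTuple(start=9, end=10)]}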

    def _read_single_file(self, file_path):
        with open(file_path) as fp:
            json_content = json.load(fp)
        if 'entries' in json_content:
            for doc_name, entry in json_content['entries'].items():
                instance = self.text_to_instance(entry, 'train' in file_path)
                yield instance
        else:
            # Files without an 'entries' wrapper are treated as training data.
            for doc_name, entry in json_content.items():
                instance = self.text_to_instance(entry, True)
                yield instance

        logger.warning(f'{self.n_overlap_arg} overlapping arguments detected!')
        logger.warning(f'{self.n_overlap_trigger} overlapping triggers detected!')
        logger.warning(f'{self.n_skip} entries skipped!')
        logger.warning(f'{self.n_too_long} entries skipped for being too long!')
        self.n_overlap_arg = self.n_skip = self.n_too_long = self.n_overlap_trigger = 0

    def _read(self, file_path: str) -> Iterable[Instance]:
        if os.path.isdir(file_path):
            for fn in os.listdir(file_path):
                if not fn.endswith('.json'):
                    logger.info(f'Skipping {fn}')
                    continue
                logger.info(f'Loading from {fn}')
                yield from self._read_single_file(os.path.join(file_path, fn))
        else:
            yield from self._read_single_file(file_path)
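
    # Input layout sketch: only the keys this reader touches, abbreviated
    # from the BETTER annotation format.
    #
    #   {"entries": {
    #       "<doc-id>": {
    #           "segment-text": "...",
    #           "segment-text-tok": ["...", ...],
    #           "annotation-sets": {
    #               "basic-events": {"span-sets": {...}, "events": {...}}}}}}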

    def text_to_instance(self, entry, is_training=False):
        word_tokens = entry['segment-text-tok']

        # Resolve every annotated span set to token-level mentions.
        spans = self.get_mention_spans(
            word_tokens, entry['annotation-sets'][f'{self.eval_type}-events']['span-sets']
        )

        all_trigger_idxs = set()
        input_spans = []

        self._local_child_overlap = 0
        self._local_child_total = 0

        better_events = entry['annotation-sets'][f'{self.eval_type}-events']['events']

        skipped_events = set()

        # First pass: claim anchor (trigger) tokens, skipping any event whose
        # anchor overlaps one claimed by an earlier event.
        for event_id, event in better_events.items():
            assert event['anchors'] in spans

            anchor_start, anchor_end = spans[event['anchors']][0]

            if any(ix in all_trigger_idxs for ix in range(anchor_start, anchor_end + 1)):
                logger.warning(
                    f"Skipped {event_id} with anchor span {event['anchors']}, overlaps a previously found event trigger/anchor")
                self.n_overlap_trigger += 1
                skipped_events.add(event_id)
                continue

            all_trigger_idxs.update(range(anchor_start, anchor_end + 1))

        # Second pass: build one Span per surviving event and attach its
        # argument spans as children.
        for event_id, event in better_events.items():
            if event_id in skipped_events:
                continue

            local_arg_idxs = set()

            anchor_start, anchor_end = spans[event['anchors']][0]

            event_span = Span(anchor_start, anchor_end, event['event-type'], True)
            input_spans.append(event_span)

            def add_a_child(span_id, label):
                assert span_id in spans
                for child_span in spans[span_id]:
                    self._local_child_total += 1
                    arg_start, arg_end = child_span

                    # Arguments of the same event must not overlap each other.
                    if any(ix in local_arg_idxs for ix in range(arg_start, arg_end + 1)):
                        self.n_overlap_arg += 1
                        self._local_child_overlap += 1
                        continue

                    local_arg_idxs.update(range(arg_start, arg_end + 1))
                    event_span.add_child(Span(arg_start, arg_end, label, False))

            for agent in event['agents']:
                add_a_child(agent, 'agent')
            for patient in event['patients']:
                add_a_child(patient, 'patient')

            if self.use_ref_events:
                for ref_event in event['ref-events']:
                    if ref_event in skipped_events:
                        continue
                    ref_event_anchor_id = better_events[ref_event]['anchors']
                    add_a_child(ref_event_anchor_id, 'ref-event')

        fields = self.prepare_inputs(word_tokens, spans=input_spans)
        if self._local_child_overlap > 0:
            logger.warning(
                f"Skipped {self._local_child_overlap} / {self._local_child_total} argument spans due to overlaps")
        return Instance(fields)
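

# Event layout sketch: only the keys `text_to_instance` reads; the ids and
# the event type shown here are hypothetical.
#
#   "events": {
#       "event-1": {
#           "event-type": "Protest",
#           "anchors": "ss-1",        # span-set id of the trigger
#           "agents": ["ss-2"],       # span-set ids of agent arguments
#           "patients": ["ss-3"],     # span-set ids of patient arguments
#           "ref-events": []          # event ids, used when use_ref_events
#       }
#   }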