import logging
from typing import Optional, Sequence, TypeVar, Union

from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule
from pie_modules.taskmodules.cross_text_binary_coref import (
    DocumentType,
    SpanDoesNotFitIntoAvailableWindow,
    TaskEncodingType,
)
from pie_modules.utils.tokenization import SpanNotAlignedWithTokenException
from pytorch_ie.annotations import Span
from pytorch_ie.core import TaskEncoding, TaskModule

logger = logging.getLogger(__name__)

S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    """Return a copy of the span shifted by the given (possibly negative) character offset."""
    return span.copy(start=span.start + offset, end=span.end + offset)


@TaskModule.register()
class CrossTextBinaryCorefTaskModuleWithOptionalContext(CrossTextBinaryCorefTaskModule):
    """Same as CrossTextBinaryCorefTaskModule, but:

    - optionally without context.
    """

    def __init__(
        self,
        without_context: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.without_context = without_context

    def encode_input(
        self,
        document: DocumentType,
        is_training: bool = False,
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        if self.without_context:
            return self.encode_input_without_context(document)
        else:
            return super().encode_input(document)

    def encode_input_without_context(
        self, document: DocumentType
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        """Encode only the head and tail argument spans themselves, i.e. without any
        surrounding context from the documents."""
        self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
        tokenizer_kwargs = dict(
            padding=False,
            truncation=False,
            add_special_tokens=False,
        )

        task_encodings = []
        for coref_rel in document.binary_coref_relations:
            # TODO: This can miss instances if both texts are the same. We could check that
            #  coref_rel.head is in document.labeled_spans (same for the tail), but would this
            #  slow down the encoding?
            if not (
                coref_rel.head.target == document.text
                or coref_rel.tail.target == document.text_pair
            ):
                raise ValueError(
                    f"It is expected that coref relations go from (head) spans over 'text' "
                    f"to (tail) spans over 'text_pair', but this is not the case for this "
                    f"relation (i.e. it points into the other direction): {coref_rel.resolve()}"
                )
            # Tokenize only the argument span texts; padding and special tokens are added later.
            encoding = self.tokenizer(text=str(coref_rel.head), **tokenizer_kwargs)
            encoding_pair = self.tokenizer(text=str(coref_rel.tail), **tokenizer_kwargs)

            try:
                # The spans are shifted to start at 0 because the encodings contain only the
                # argument texts themselves.
                current_encoding, token_span = self.truncate_encoding_around_span(
                    encoding=encoding,
                    char_span=shift_span(coref_rel.head, -coref_rel.head.start),
                )
                current_encoding_pair, token_span_pair = self.truncate_encoding_around_span(
                    encoding=encoding_pair,
                    char_span=shift_span(coref_rel.tail, -coref_rel.tail.start),
                )
            except SpanNotAlignedWithTokenException as e:
                logger.warning(
                    f"Could not get token offsets for argument ({e.span}) of coref relation: "
                    f"{coref_rel.resolve()}. Skipping it."
                )
                self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel)
                continue
            except SpanDoesNotFitIntoAvailableWindow as e:
                logger.warning(
                    f"Argument span [{e.span}] does not fit into available token window "
                    f"({self.available_window}). Skipping it."
                )
                self.collect_relation(
                    kind="skipped_span_does_not_fit_into_window", relation=coref_rel
                )
                continue

            task_encodings.append(
                TaskEncoding(
                    document=document,
                    inputs={
                        "encoding": current_encoding,
                        "encoding_pair": current_encoding_pair,
                        "pooler_start_indices": token_span.start,
                        "pooler_end_indices": token_span.end,
                        "pooler_pair_start_indices": token_span_pair.start,
                        "pooler_pair_end_indices": token_span_pair.end,
                    },
                    metadata={"candidate_annotation": coref_rel},
                )
            )
            self.collect_relation("used", coref_rel)
        return task_encodings
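
# Usage sketch (illustrative only): with `without_context=True`, only the argument spans
# themselves are encoded instead of windows of surrounding document text. The constructor
# argument `tokenizer_name_or_path` below is an assumption about the base
# CrossTextBinaryCorefTaskModule signature; adjust it to the actual base-class parameters.
#
# taskmodule = CrossTextBinaryCorefTaskModuleWithOptionalContext(
#     tokenizer_name_or_path="bert-base-cased",
#     without_context=True,
# )
# taskmodule.prepare(documents)
# task_encodings = taskmodule.encode(documents)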