import dataclasses

from pytorch_ie import AnnotationLayer, annotation_field
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
)


@dataclasses.dataclass(eq=True, frozen=True)
class RelatedRelation(BinaryRelation):
    """A binary relation that carries two further relations as extra fields.

    Both extra fields are declared with ``compare=False``, so they do not take
    part in the dataclass-generated equality comparison.
    """

    # Relation that shares an argument (head or tail) with this relation; used
    # to derive ``reference_span``. Defaults to ``None``, but ``__post_init__``
    # rejects ``None``, so a value has to be supplied in practice.
    link_relation: BinaryRelation = dataclasses.field(default=None, compare=False)
    # NOTE(review): presumably the relation this one is related to; it is not
    # read anywhere in this class -- confirm against callers.
    relation: BinaryRelation = dataclasses.field(default=None, compare=False)

    def __post_init__(self):
        """Run the parent checks and validate ``link_relation`` eagerly.

        Raises:
            ValueError: If ``reference_span`` cannot be determined (see the
                property for the exact conditions).
        """
        super().__post_init__()
        # Accessing the property is done only for its validation side effect:
        # it raises if link_relation is missing or not attached to this relation.
        self.reference_span

    @property
    def reference_span(self):
        """Return the argument of ``link_relation`` that is not shared with this relation.

        The shared argument is looked up in a fixed order: first this
        relation's ``head`` is compared against the head and then the tail of
        ``link_relation``, afterwards this relation's ``tail`` is compared in
        the same way. The first match decides the result, which is the
        opposite argument of ``link_relation``.

        Raises:
            ValueError: If ``link_relation`` is ``None`` or if neither of its
                arguments equals ``head`` or ``tail`` of this relation.
        """
        link = self.link_relation
        if link is None:
            raise ValueError("No link_relation available, cannot return reference_span")

        # Order matters when more than one comparison would match, so the
        # candidates are checked strictly one after another.
        if link.head == self.head:
            return link.tail
        if link.tail == self.head:
            return link.head
        if link.head == self.tail:
            return link.tail
        if link.tail == self.tail:
            return link.head

        raise ValueError(
            "The link_relation is neither linked to head nor tail of the current relation"
        )


@dataclasses.dataclass
class TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations(
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
):
    """Document type that adds a ``related_relations`` annotation layer to its base class."""

    # Layer of RelatedRelation annotations; it targets the
    # "labeled_multi_spans" and "binary_relations" layers.
    related_relations: AnnotationLayer[RelatedRelation] = annotation_field(
        targets=["labeled_multi_spans", "binary_relations"]
    )