import logging
from typing import Optional, Sequence, Type

from langchain_core.documents import Document as LCDocument
from pie_datasets import Dataset, IterableDataset
from pytorch_ie import Document, WithDocumentTypeMixin
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.documents import TextBasedDocument

from src.langchain_modules import DocumentAwareSpanRetriever

logger = logging.getLogger(__name__)


class DummyTaskmodule(WithDocumentTypeMixin):
    def __init__(self, document_type: Type[Document]):
        self._document_type = document_type

    @property
    def document_type(self) -> Optional[Type[Document]]:
        return self._document_type


class SpanRetrievalBasedRelationExtractionPipeline:
    """Pipeline for adding binary relations between spans based on span retrieval within the same
    document.

    The pipeline uses each existing span as a query, retrieves similar spans, and adds binary
    relations between the query span and the retrieved spans.

    Args:
        retriever: The span retriever to use for retrieving spans.
        relation_label: The label to use for the binary relations.
        relation_layer_name: The name of the annotation layer to add the binary relations to.
        use_predicted_annotations: Whether to use the predicted (instead of the gold) annotations
            of the input documents.
        load_store_path: If provided, the retriever store(s) will be loaded from this path before
            processing.
        save_store_path: If provided, the retriever store(s) will be saved to this path after
            processing.
        fast_dev_run: Whether to run the pipeline in fast dev mode, i.e. only processing the
            first 2 documents.
    """

    def __init__(
        self,
        retriever: DocumentAwareSpanRetriever,
        relation_label: str,
        relation_layer_name: str = "binary_relations",
        use_predicted_annotations: bool = False,
        load_store_path: Optional[str] = None,
        save_store_path: Optional[str] = None,
        fast_dev_run: bool = False,
    ):
        self.retriever = retriever
        if not self.retriever.retrieve_from_same_document:
            raise NotImplementedError("Retriever must retrieve from the same document")
        self.relation_label = relation_label
        self.relation_layer_name = relation_layer_name
        self.use_predicted_annotations = use_predicted_annotations
        self.load_store_path = load_store_path
        self.save_store_path = save_store_path
        if self.load_store_path is not None:
            self.retriever.load_from_directory(path=self.load_store_path)
        self.fast_dev_run = fast_dev_run

    # to make auto-conversion work: we request documents of type pipeline.taskmodule.document_type
    # from the dataset
    @property
    def taskmodule(self) -> DummyTaskmodule:
        return DummyTaskmodule(self.retriever.pie_document_type)

    def _construct_similarity_relations(
        self,
        query_results: list[LCDocument],
        query_span: LabeledSpan,
    ) -> list[BinaryRelation]:
        return [
            BinaryRelation(
                head=query_span,
                tail=lc_doc.metadata["attached_span"],
                label=self.relation_label,
                score=float(lc_doc.metadata["relevance_score"]),
            )
            for lc_doc in query_results
        ]

    def _process_single_document(
        self,
        document: Document,
    ) -> TextBasedDocument:
        if not isinstance(document, TextBasedDocument):
            raise ValueError("Document must be a TextBasedDocument")
        # index the document's spans in the retriever so they can be queried
        self.retriever.add_pie_documents(
            [document], use_predicted_annotations=self.use_predicted_annotations
        )
        all_new_rels = []
        spans = self.retriever.get_base_layer(
            document, use_predicted_annotations=self.use_predicted_annotations
        )
        span_id2idx = self.retriever.get_span_id2idx_from_doc(document.id)
        # use each span as a query and construct relations to the retrieved spans
        for span_id, span_idx in span_id2idx.items():
            query_span = spans[span_idx]
            query_result = self.retriever.invoke(input=span_id)
            query_rels = self._construct_similarity_relations(query_result, query_span=query_span)
            all_new_rels.extend(query_rels)
        if self.relation_layer_name not in document:
            raise ValueError(f"Document does not have a layer named {self.relation_layer_name}")
        document[self.relation_layer_name].predictions.extend(all_new_rels)
        if self.retriever.retrieve_from_same_document and self.save_store_path is None:
            # the indexed document is no longer needed if the store is not persisted
            self.retriever.delete_documents([document.id])
        return document

    def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
        if inplace:
            raise NotImplementedError("Inplace processing is not supported yet")
        if self.fast_dev_run:
            logger.warning("Fast dev run enabled, only processing the first 2 documents")
            documents = documents[:2]
        if not isinstance(documents, (Dataset, IterableDataset)):
            documents = Dataset.from_documents(documents)
        mapped_documents = documents.map(self._process_single_document)
        if self.save_store_path is not None:
            self.retriever.save_to_directory(path=self.save_store_path)
        return mapped_documents
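

# Usage sketch (illustrative only, kept as comments): a minimal example of how this pipeline
# might be wired up. How the DocumentAwareSpanRetriever is constructed is project-specific;
# `build_span_retriever()` below is a hypothetical helper standing in for that configuration,
# and the label/path values are example assumptions, not values taken from this repository.
#
#     retriever: DocumentAwareSpanRetriever = build_span_retriever()  # hypothetical helper
#     pipeline = SpanRetrievalBasedRelationExtractionPipeline(
#         retriever=retriever,
#         relation_label="similar_to",              # example label
#         relation_layer_name="binary_relations",
#         save_store_path="retriever_store",        # example path; persists the span store
#     )
#     processed_documents = pipeline(documents)     # documents: Sequence[TextBasedDocument]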