|
import logging |
|
from typing import Optional, Sequence, Type |
|
|
|
from langchain_core.documents import Document as LCDocument |
|
from pie_datasets import Dataset, IterableDataset |
|
from pytorch_ie import Document, WithDocumentTypeMixin |
|
from pytorch_ie.annotations import BinaryRelation, LabeledSpan |
|
from pytorch_ie.documents import TextBasedDocument |
|
|
|
from src.langchain_modules import DocumentAwareSpanRetriever |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class DummyTaskmodule(WithDocumentTypeMixin):
    """Minimal taskmodule stand-in whose only job is to report a document type.

    Args:
        document_type: The document class that ``document_type`` will return.
    """

    def __init__(self, document_type: Type[Document]) -> None:
        # Kept on a private attribute; exposed read-only via the property below.
        self._document_type = document_type

    @property
    def document_type(self) -> Optional[Type[Document]]:
        """Return the document type that was passed to the constructor."""
        return self._document_type
|
|
|
|
|
class SpanRetrievalBasedRelationExtractionPipeline:
    """Pipeline for adding binary relations between spans based on span retrieval within the same document.

    This pipeline retrieves spans for all existing spans as query and adds binary relations between the
    query spans and the retrieved spans.

    Args:
        retriever: The span retriever to use for retrieving spans. It must be configured to
            retrieve from the same document.
        relation_label: The label to use for the binary relations.
        relation_layer_name: The name of the annotation layer to add the binary relations to.
        use_predicted_annotations: Forwarded to the retriever when documents are added to it and
            when the base (span) layer of a document is looked up.
        load_store_path: If provided, the retriever store(s) will be loaded from this path before processing.
        save_store_path: If provided, the retriever store(s) will be saved to this path after processing.
            In that case processed documents are kept in the retriever store; otherwise each
            document is removed from the store again once it has been processed.
        fast_dev_run: Whether to run the pipeline in fast dev mode, i.e. only processing the first 2 documents.

    Raises:
        NotImplementedError: If the retriever does not retrieve from the same document.
    """

    def __init__(
        self,
        retriever: DocumentAwareSpanRetriever,
        relation_label: str,
        relation_layer_name: str = "binary_relations",
        use_predicted_annotations: bool = False,
        load_store_path: Optional[str] = None,
        save_store_path: Optional[str] = None,
        fast_dev_run: bool = False,
    ):
        self.retriever = retriever
        if not self.retriever.retrieve_from_same_document:
            raise NotImplementedError("Retriever must retrieve from the same document")
        self.relation_label = relation_label
        self.relation_layer_name = relation_layer_name
        self.use_predicted_annotations = use_predicted_annotations
        self.load_store_path = load_store_path
        self.save_store_path = save_store_path
        if self.load_store_path is not None:
            self.retriever.load_from_directory(path=self.load_store_path)

        self.fast_dev_run = fast_dev_run

    @property
    def taskmodule(self) -> DummyTaskmodule:
        """A fresh ``DummyTaskmodule`` exposing the retriever's ``pie_document_type``."""
        return DummyTaskmodule(self.retriever.pie_document_type)

    def _construct_similarity_relations(
        self,
        query_results: list[LCDocument],
        query_span: LabeledSpan,
    ) -> list[BinaryRelation]:
        """Turn retrieval results for one query span into binary relations.

        Args:
            query_results: Retrieved entries; each must carry ``"attached_span"`` and
                ``"relevance_score"`` in its metadata.
            query_span: The span that was used as the query; becomes the relation head.

        Returns:
            One relation per retrieved entry, labeled with ``self.relation_label`` and scored
            with the entry's relevance score (converted to ``float``).
        """
        return [
            BinaryRelation(
                head=query_span,
                tail=lc_doc.metadata["attached_span"],
                label=self.relation_label,
                score=float(lc_doc.metadata["relevance_score"]),
            )
            for lc_doc in query_results
        ]

    def _process_single_document(
        self,
        document: Document,
    ) -> TextBasedDocument:
        """Add predicted relations between each span of ``document`` and its retrieved spans.

        The document is added to the retriever, every indexed span is used as a query, and the
        resulting relations are appended to the predictions of the configured relation layer.
        Unless a ``save_store_path`` is set, the document is removed from the retriever afterwards.

        Args:
            document: The document to process; must be a ``TextBasedDocument`` that has a layer
                named ``self.relation_layer_name``.

        Returns:
            The same document object, with the new relation predictions added.

        Raises:
            ValueError: If the document is not a ``TextBasedDocument`` or lacks the relation layer.
        """
        if not isinstance(document, TextBasedDocument):
            raise ValueError("Document must be a TextBasedDocument")

        # Fail fast: check the target layer before touching the retriever store, so an invalid
        # document neither triggers any retrieval work nor is left behind in the store.
        if self.relation_layer_name not in document:
            raise ValueError(f"Document does not have a layer named {self.relation_layer_name}")

        self.retriever.add_pie_documents(
            [document], use_predicted_annotations=self.use_predicted_annotations
        )

        spans = self.retriever.get_base_layer(
            document, use_predicted_annotations=self.use_predicted_annotations
        )
        span_id2idx = self.retriever.get_span_id2idx_from_doc(document.id)

        all_new_rels: list[BinaryRelation] = []
        for span_id, span_idx in span_id2idx.items():
            # The span id is the query; the index locates the matching span in the base layer.
            query_result = self.retriever.invoke(input=span_id)
            all_new_rels.extend(
                self._construct_similarity_relations(query_result, query_span=spans[span_idx])
            )

        document[self.relation_layer_name].predictions.extend(all_new_rels)

        # Only keep the document in the store if the store is going to be saved later.
        if self.retriever.retrieve_from_same_document and self.save_store_path is None:
            self.retriever.delete_documents([document.id])

        return document

    def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
        """Process all documents and return them with the predicted relations added.

        Args:
            documents: The documents to process. Anything that is not already a ``Dataset`` or
                ``IterableDataset`` is converted via ``Dataset.from_documents``.
            inplace: Not supported; must be ``False``.

        Returns:
            The result of mapping ``_process_single_document`` over the (dataset of) documents.

        Raises:
            NotImplementedError: If ``inplace`` is true.
        """
        if inplace:
            raise NotImplementedError("Inplace processing is not supported yet")

        if self.fast_dev_run:
            logger.warning("Fast dev run enabled, only processing the first 2 documents")
            if isinstance(documents, IterableDataset):
                # Iterable datasets cannot be sliced, so pull the first two documents by
                # iteration instead; the list is wrapped into a Dataset below.
                from itertools import islice

                documents = list(islice(documents, 2))
            else:
                documents = documents[:2]

        if not isinstance(documents, (Dataset, IterableDataset)):
            documents = Dataset.from_documents(documents)

        # NOTE(review): for an IterableDataset, `map` is presumably lazy, in which case the store
        # would be saved below before any document has actually been processed — confirm.
        mapped_documents = documents.map(self._process_single_document)

        if self.save_store_path is not None:
            self.retriever.save_to_directory(path=self.save_store_path)

        return mapped_documents
|
|