File size: 5,257 Bytes
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import logging
from itertools import islice
from typing import Optional, Sequence, Type

from langchain_core.documents import Document as LCDocument
from pie_datasets import Dataset, IterableDataset
from pytorch_ie import Document, WithDocumentTypeMixin
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.documents import TextBasedDocument

from src.langchain_modules import DocumentAwareSpanRetriever

# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)


class DummyTaskmodule(WithDocumentTypeMixin):
    """Minimal taskmodule stand-in whose only job is to report a document type.

    Args:
        document_type: The document class that ``document_type`` will report.
    """

    def __init__(self, document_type: Type[Document]):
        # Kept in a private attribute and exposed read-only via the property below.
        self._document_type = document_type

    @property
    def document_type(self) -> Optional[Type[Document]]:
        """The document type given at construction time."""
        stored_type = self._document_type
        return stored_type


class SpanRetrievalBasedRelationExtractionPipeline:
    """Pipeline for adding binary relations between spans based on span retrieval within the same document.

    This pipeline retrieves spans for all existing spans as query and adds binary relations between the
    query spans (head) and the retrieved spans (tail). The new relations are appended to the
    *predictions* of the configured relation layer.

    Args:
        retriever: The span retriever to use for retrieving spans. It must have
            ``retrieve_from_same_document`` enabled.
        relation_label: The label to use for the binary relations.
        relation_layer_name: The name of the annotation layer to add the binary relations to.
        use_predicted_annotations: Passed through to the retriever when adding documents and when
            fetching the span (base) layer; presumably selects predicted instead of gold spans.
        load_store_path: If provided, the retriever store(s) will be loaded from this path before processing.
        save_store_path: If provided, the retriever store(s) will be saved to this path after processing.
            If not provided, each document is removed from the retriever store after it was processed.
        fast_dev_run: Whether to run the pipeline in fast dev mode, i.e. only processing the first 2 documents.

    Raises:
        NotImplementedError: If the retriever does not retrieve from the same document.
    """

    def __init__(
        self,
        retriever: DocumentAwareSpanRetriever,
        relation_label: str,
        relation_layer_name: str = "binary_relations",
        use_predicted_annotations: bool = False,
        load_store_path: Optional[str] = None,
        save_store_path: Optional[str] = None,
        fast_dev_run: bool = False,
    ):
        self.retriever = retriever
        if not self.retriever.retrieve_from_same_document:
            raise NotImplementedError("Retriever must retrieve from the same document")
        self.relation_label = relation_label
        self.relation_layer_name = relation_layer_name
        self.use_predicted_annotations = use_predicted_annotations
        self.load_store_path = load_store_path
        self.save_store_path = save_store_path
        if self.load_store_path is not None:
            self.retriever.load_from_directory(path=self.load_store_path)

        self.fast_dev_run = fast_dev_run

    # to make auto-conversion work: we request documents of type pipeline.taskmodule.document_type
    # from the dataset
    @property
    def taskmodule(self) -> DummyTaskmodule:
        return DummyTaskmodule(self.retriever.pie_document_type)

    def _construct_similarity_relations(
        self,
        query_results: list[LCDocument],
        query_span: LabeledSpan,
    ) -> list[BinaryRelation]:
        """Create one relation from ``query_span`` to each retrieved span.

        Args:
            query_results: Retrieved langchain documents. Each must carry the retrieved span under
                ``metadata["attached_span"]`` and a score under ``metadata["relevance_score"]``.
            query_span: The span that was used as the query; becomes the head of every relation.

        Returns:
            One ``BinaryRelation`` per retrieved document, labeled with ``self.relation_label`` and
            scored with the retrieval relevance score (converted to float).
        """
        return [
            BinaryRelation(
                head=query_span,
                tail=lc_doc.metadata["attached_span"],
                label=self.relation_label,
                score=float(lc_doc.metadata["relevance_score"]),
            )
            for lc_doc in query_results
        ]

    def _process_single_document(
        self,
        document: Document,
    ) -> TextBasedDocument:
        """Add similarity relations between the spans of a single document.

        Args:
            document: The document to process. Must be a ``TextBasedDocument`` that has a layer
                named ``self.relation_layer_name``.

        Returns:
            The same document object, with the new relations appended to the predictions of
            the relation layer.

        Raises:
            ValueError: If the document is not a ``TextBasedDocument`` or lacks the relation layer.
        """
        if not isinstance(document, TextBasedDocument):
            raise ValueError("Document must be a TextBasedDocument")
        # Validate before touching the retriever, so an invalid document fails fast without
        # doing retrieval work and without ever being added to the store.
        if self.relation_layer_name not in document:
            raise ValueError(f"Document does not have a layer named {self.relation_layer_name}")

        self.retriever.add_pie_documents(
            [document], use_predicted_annotations=self.use_predicted_annotations
        )
        try:
            all_new_rels = []
            spans = self.retriever.get_base_layer(
                document, use_predicted_annotations=self.use_predicted_annotations
            )
            span_id2idx = self.retriever.get_span_id2idx_from_doc(document.id)
            for span_id, span_idx in span_id2idx.items():
                query_span = spans[span_idx]
                query_result = self.retriever.invoke(input=span_id)
                all_new_rels.extend(
                    self._construct_similarity_relations(query_result, query_span=query_span)
                )

            document[self.relation_layer_name].predictions.extend(all_new_rels)
        finally:
            # Retrieval is restricted to the same document, so unless the store is going to be
            # saved, the document is not needed afterwards. Remove it even if processing failed,
            # so a failed document does not linger in the store.
            if self.retriever.retrieve_from_same_document and self.save_store_path is None:
                self.retriever.delete_documents([document.id])

        return document

    def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
        """Add similarity relations to all given documents.

        Args:
            documents: The documents to process. Inputs that are not already a pie ``Dataset`` or
                ``IterableDataset`` are converted with ``Dataset.from_documents``.
            inplace: In-place processing is not supported; must be False.

        Returns:
            The result of mapping ``_process_single_document`` over the (dataset of) documents.

        Raises:
            NotImplementedError: If ``inplace`` is True.
        """
        if inplace:
            raise NotImplementedError("Inplace processing is not supported yet")

        if self.fast_dev_run:
            logger.warning("Fast dev run enabled, only processing the first 2 documents")
            if isinstance(documents, IterableDataset):
                # Streaming datasets cannot be sliced with ``[:2]``; take the first documents by
                # iterating instead. The resulting list is converted to a Dataset below.
                documents = list(islice(documents, 2))
            else:
                documents = documents[:2]

        if not isinstance(documents, (Dataset, IterableDataset)):
            documents = Dataset.from_documents(documents)

        mapped_documents = documents.map(self._process_single_document)

        if self.save_store_path is not None:
            self.retriever.save_to_directory(path=self.save_store_path)

        return mapped_documents