File size: 3,361 Bytes
2cc87ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ced4316
2cc87ec
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import abc
import logging
from copy import copy
from typing import Iterator, List, Optional, Sequence, Tuple

import pandas as pd
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore
from pytorch_ie.documents import TextBasedDocument

from .serializable_store import SerializableStore

logger = logging.getLogger(__name__)


class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
    """Abstract base class for document stores specialized in storing and retrieving pie documents.

    Pie documents (:class:`TextBasedDocument`) are transported inside langchain
    :class:`LCDocument` wrappers: the pie document is stashed in the langchain
    document's metadata under :attr:`METADATA_KEY_PIE_DOCUMENT` while the
    ``page_content`` stays empty. Subclasses supply the storage backend by
    implementing :meth:`mget`, :meth:`mset`, :meth:`mdelete`, and
    :meth:`yield_keys`.
    """

    METADATA_KEY_PIE_DOCUMENT: str = "pie_document"
    """Key for the pie document in the (langchain) document metadata."""

    def wrap(self, pie_document: TextBasedDocument, **metadata) -> LCDocument:
        """Wrap the pie document in an LCDocument.

        Args:
            pie_document: The pie document to wrap; its ``id`` becomes the
                langchain document id.
            **metadata: Additional metadata entries merged into the langchain
                document's metadata alongside the pie document itself.

        Returns:
            An LCDocument with empty page content carrying the pie document
            in its metadata.
        """
        return LCDocument(
            id=pie_document.id,
            page_content="",
            metadata={self.METADATA_KEY_PIE_DOCUMENT: pie_document, **metadata},
        )

    def unwrap(self, document: LCDocument) -> TextBasedDocument:
        """Get the pie document from the langchain document."""
        return document.metadata[self.METADATA_KEY_PIE_DOCUMENT]

    def unwrap_with_metadata(self, document: LCDocument) -> Tuple[TextBasedDocument, dict]:
        """Get the pie document and the remaining metadata from the langchain document.

        Returns:
            A tuple of the pie document and a shallow copy of the metadata
            with the pie document entry removed (the stored document's
            metadata is not mutated).
        """
        metadata = copy(document.metadata)
        pie_document = metadata.pop(self.METADATA_KEY_PIE_DOCUMENT)
        return pie_document, metadata

    @abc.abstractmethod
    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Get the documents for the given keys, in the order of the keys."""
        pass

    @abc.abstractmethod
    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs."""
        pass

    @abc.abstractmethod
    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the entries for the given keys."""
        pass

    @abc.abstractmethod
    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the store, optionally restricted to a prefix."""
        pass

    def __len__(self):
        """Return the number of documents in the store."""
        # Count lazily instead of materializing all keys in a list.
        return sum(1 for _ in self.yield_keys())

    def overview(self, layer_captions: dict, use_predictions: bool = False) -> pd.DataFrame:
        """Get an overview of the document store, including the number of items in each layer for each document
        in the store.

        Args:
            layer_captions: A dictionary mapping layer names to captions.
            use_predictions: Whether to also count the layers' predictions.

        Returns:
            DataFrame: A pandas DataFrame containing the overview, one row per
            document with a "doc_id" column and one "num_<caption>" column per layer.
        """
        keys = list(self.yield_keys())
        # Fetch all documents with a single batched call instead of one
        # mget round-trip per key; mget returns values in key order.
        documents = self.mget(keys)
        rows = []
        for doc_id, document in zip(keys, documents):
            pie_document = self.unwrap(document)
            layers = {
                caption: pie_document[layer_name] for layer_name, caption in layer_captions.items()
            }
            layer_sizes = {
                f"num_{caption}": len(layer) + (len(layer.predictions) if use_predictions else 0)
                for caption, layer in layers.items()
            }
            rows.append({"doc_id": doc_id, **layer_sizes})
        df = pd.DataFrame(rows)
        return df

    def as_dict(self, document: LCDocument) -> dict:
        """Convert the langchain document to a dictionary of the pie document's
        serialized form plus the remaining metadata."""
        pie_document, metadata = self.unwrap_with_metadata(document)
        return {self.METADATA_KEY_PIE_DOCUMENT: pie_document.asdict(), "metadata": metadata}