File size: 5,059 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Callable,
    Iterator,
    Optional,
    Sequence,
    Union,
)

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.cassandra import aexecute_cql

# Sentinel meaning "argument not supplied". It is used as the default for the
# `query_timeout` and `query_execution_profile` constructor options so that an
# explicit `None` can still be forwarded to `session.execute`, while an omitted
# argument is left out of the call entirely.
_NOT_SET = object()

# These names are referenced only in type annotations; because of the
# `from __future__ import annotations` at the top of the file the annotations are
# not evaluated at runtime, so the imports are restricted to type checking.
if TYPE_CHECKING:
    from cassandra.cluster import Session
    from cassandra.pool import Host
    from cassandra.query import Statement


class CassandraLoader(BaseLoader):
    """Load rows from Apache Cassandra as ``Document`` objects.

    The loader runs either a caller-supplied CQL query or a ``SELECT *`` over a
    given table, and turns every returned row into one ``Document`` using the
    configured page-content and metadata mapper functions.
    """

    def __init__(
        self,
        table: Optional[str] = None,
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        query: Union[str, Statement, None] = None,
        page_content_mapper: Callable[[Any], str] = str,
        metadata_mapper: Callable[[Any], dict] = lambda _: {},
        *,
        query_parameters: Union[dict, Sequence, None] = None,
        query_timeout: Optional[float] = _NOT_SET,  # type: ignore[assignment]
        query_trace: bool = False,
        query_custom_payload: Optional[dict] = None,
        query_execution_profile: Any = _NOT_SET,
        query_paging_state: Any = None,
        query_host: Optional[Host] = None,
        query_execute_as: Optional[str] = None,
    ) -> None:
        """
        Document Loader for Apache Cassandra.

        Args:
            table: The table to load the data from.
                (do not use together with the query parameter)
            session: The cassandra driver session.
                If not provided, the cassio resolved session will be used.
            keyspace: The keyspace of the table.
                If not provided, the cassio resolved keyspace will be used.
            query: The query used to load the data.
                (do not use together with the table parameter)
            page_content_mapper: a function to convert a row to string page content.
                Defaults to the str representation of the row.
            metadata_mapper: a function to convert a row to document metadata.
                Defaults to a function returning an empty dict.
            query_parameters: The query parameters used when calling session.execute .
            query_timeout: The query timeout used when calling session.execute .
                Only forwarded when explicitly provided.
            query_trace: Whether to use tracing when calling session.execute .
            query_custom_payload: The query custom_payload used when calling
                session.execute .
            query_execution_profile: The query execution_profile used when calling
                session.execute . Only forwarded when explicitly provided.
            query_paging_state: The query paging_state used when calling
                session.execute .
            query_host: The query host used when calling session.execute .
            query_execute_as: The query execute_as used when calling session.execute .

        Raises:
            ValueError: If both or neither of ``query`` and ``table`` are given.
            ImportError: If a session or keyspace has to be resolved through
                cassio but cassio cannot be imported.
        """
        if query and table:
            raise ValueError("Cannot specify both query and table.")

        if not query and not table:
            raise ValueError("Must specify query or table.")

        # cassio is imported lazily and only when it is actually needed: to
        # resolve a missing session, or a missing keyspace for a table load.
        if not session or (table and not keyspace):
            try:
                from cassio.config import check_resolve_keyspace, check_resolve_session
            except ImportError as exc:
                # ModuleNotFoundError is a subclass of ImportError, so a single
                # except clause covers both.
                raise ImportError(
                    "Could not import a recent cassio package. "
                    "Please install it with `pip install --upgrade cassio`."
                ) from exc

        if table:
            # `check_resolve_keyspace` is only reached when `keyspace` is falsy,
            # which is exactly the case in which the import above has run.
            _keyspace = keyspace or check_resolve_keyspace(keyspace)
            self.query = f"SELECT * FROM {_keyspace}.{table};"
            # Base metadata attached to every document produced from a table load.
            self.metadata = {"table": table, "keyspace": _keyspace}
        else:
            self.query = query  # type: ignore[assignment]
            self.metadata = {}

        # Same reasoning as for the keyspace: the resolver is only called when
        # no session was passed in, in which case cassio has been imported.
        self.session = session or check_resolve_session(session)
        self.page_content_mapper = page_content_mapper
        self.metadata_mapper = metadata_mapper

        # Keyword arguments forwarded verbatim to the query execution call.
        self.query_kwargs = {
            "parameters": query_parameters,
            "trace": query_trace,
            "custom_payload": query_custom_payload,
            "paging_state": query_paging_state,
            "host": query_host,
            "execute_as": query_execute_as,
        }
        # The two options below use the `_NOT_SET` sentinel as default so that
        # they are omitted from the call unless the caller supplied a value
        # (which may legitimately be None).
        if query_timeout is not _NOT_SET:
            self.query_kwargs["timeout"] = query_timeout

        if query_execution_profile is not _NOT_SET:
            self.query_kwargs["execution_profile"] = query_execution_profile

    def _row_to_document(self, row: Any) -> Document:
        """Build a ``Document`` from one result row.

        The loader-level metadata is copied and then updated with the output of
        ``metadata_mapper`` (so mapper keys win on conflict); the page content is
        produced by ``page_content_mapper``.
        """
        metadata = self.metadata.copy()
        metadata.update(self.metadata_mapper(row))
        return Document(page_content=self.page_content_mapper(row), metadata=metadata)

    def lazy_load(self) -> Iterator[Document]:
        """Execute the query with ``session.execute`` and yield one document per row."""
        for row in self.session.execute(self.query, **self.query_kwargs):
            yield self._row_to_document(row)

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Await ``aexecute_cql`` for the query, then yield one document per row."""
        for row in await aexecute_cql(self.session, self.query, **self.query_kwargs):
            yield self._row_to_document(row)