File size: 6,657 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from __future__ import annotations

import asyncio
from asyncio import InvalidStateError, Task
from typing import (
    TYPE_CHECKING,
    AsyncIterator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
)

from langchain_core.stores import ByteStore

from langchain_community.utilities.cassandra import SetupMode, aexecute_cql

if TYPE_CHECKING:
    from cassandra.cluster import Session
    from cassandra.query import PreparedStatement

# CQL templates for the key/value table used by ``CassandraByteStore``.
# Each is filled in with ``str.format(keyspace=..., table=...)``; the ``?``
# markers are bind parameters supplied when the statement is executed.
CREATE_TABLE_CQL_TEMPLATE = """
    CREATE TABLE IF NOT EXISTS {keyspace}.{table}
    (row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));
"""
SELECT_TABLE_CQL_TEMPLATE = (
    "SELECT row_id, body_blob FROM {keyspace}.{table} WHERE row_id IN ?;"
)
SELECT_ALL_TABLE_CQL_TEMPLATE = "SELECT row_id, body_blob FROM {keyspace}.{table};"
INSERT_TABLE_CQL_TEMPLATE = (
    "INSERT INTO {keyspace}.{table} (row_id, body_blob) VALUES (?, ?);"
)
DELETE_TABLE_CQL_TEMPLATE = "DELETE FROM {keyspace}.{table} WHERE row_id IN ?;"


class CassandraByteStore(ByteStore):
    """A ``ByteStore`` backed by a Cassandra table.

    Every entry is one row of ``{keyspace}.{table}`` with a ``TEXT`` key
    column ``row_id`` and a ``BLOB`` value column ``body_blob``. The table is
    created (``IF NOT EXISTS``) when the store is constructed.

    Args:
        table: Name of the table holding the key/value rows.
        session: Cassandra driver session. If it (or ``keyspace``) is not
            given, the missing value is resolved through ``cassio``'s global
            configuration.
        keyspace: Keyspace containing ``table``.
        setup_mode: With ``SetupMode.ASYNC`` the ``CREATE TABLE`` statement is
            scheduled with ``asyncio.create_task`` (so a running event loop is
            required); otherwise it is executed synchronously here.

    Raises:
        ImportError: If ``session`` or ``keyspace`` must be resolved but
            ``cassio`` cannot be imported.
    """

    def __init__(
        self,
        table: str,
        *,
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        if not session or not keyspace:
            try:
                from cassio.config import check_resolve_keyspace, check_resolve_session
            except ImportError as err:  # also covers ModuleNotFoundError
                raise ImportError(
                    "Could not import a recent cassio package. "
                    "Please install it with `pip install --upgrade cassio`."
                ) from err
            self.keyspace = keyspace or check_resolve_keyspace(keyspace)
            self.session = session or check_resolve_session()
        else:
            self.keyspace = keyspace
            self.session = session
        self.table = table
        # Prepared statements are created lazily on first use, then reused.
        self.select_statement: Optional[PreparedStatement] = None
        self.insert_statement: Optional[PreparedStatement] = None
        self.delete_statement: Optional[PreparedStatement] = None

        create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace,
            table=self.table,
        )
        self.db_setup_task: Optional[Task[None]] = None
        if setup_mode == SetupMode.ASYNC:
            self.db_setup_task = asyncio.create_task(
                aexecute_cql(self.session, create_cql)
            )
        else:
            self.session.execute(create_cql)

    def ensure_db_setup(self) -> None:
        """Check, without blocking, that an asynchronous table setup is done.

        Does nothing when setup ran synchronously. If the setup task finished
        with an exception, ``Task.result()`` re-raises it here.

        Raises:
            ValueError: If the setup task has not finished yet.
        """
        if self.db_setup_task:
            try:
                self.db_setup_task.result()
            except InvalidStateError as err:
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: AstraDB components sync methods shouldn't be called from the "
                    "event loop. Consider using their async equivalents."
                ) from err

    async def aensure_db_setup(self) -> None:
        """Wait for the asynchronous table setup, if any, to complete."""
        if self.db_setup_task:
            await self.db_setup_task

    def get_select_statement(self) -> PreparedStatement:
        """Return the (lazily prepared) multi-key ``SELECT`` statement."""
        if self.select_statement is None:
            self.select_statement = self.session.prepare(
                SELECT_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.select_statement

    def get_insert_statement(self) -> PreparedStatement:
        """Return the (lazily prepared) single-row ``INSERT`` statement."""
        if self.insert_statement is None:
            self.insert_statement = self.session.prepare(
                INSERT_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.insert_statement

    def get_delete_statement(self) -> PreparedStatement:
        """Return the (lazily prepared) multi-key ``DELETE`` statement."""
        if self.delete_statement is None:
            self.delete_statement = self.session.prepare(
                DELETE_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.delete_statement

    def _select_keys_cql(self) -> str:
        """Return a CQL query listing every key in the table."""
        # Only the key column is needed to enumerate keys; also selecting
        # ``body_blob`` would transfer every stored value for nothing.
        return f"SELECT row_id FROM {self.keyspace}.{self.table};"

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Return the values for ``keys`` in order, ``None`` for missing keys."""
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            # Avoid a round trip (and an ``IN ()`` query) for nothing.
            return []
        docs_dict = {
            row.row_id: row.body_blob
            for row in self.session.execute(
                self.get_select_statement(), [ValueSequence(keys)]
            )
        }
        return [docs_dict.get(key) for key in keys]

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Async version of :meth:`mget`."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            return []
        docs_dict = {
            row.row_id: row.body_blob
            for row in await aexecute_cql(
                self.session,
                self.get_select_statement(),
                parameters=[ValueSequence(keys)],
            )
        }
        return [docs_dict.get(key) for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Store each ``(key, value)`` pair, one ``INSERT`` per pair, in order."""
        self.ensure_db_setup()
        insert_statement = self.get_insert_statement()
        for k, v in key_value_pairs:
            self.session.execute(insert_statement, (k, v))

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Async version of :meth:`mset`; inserts are awaited one at a time."""
        await self.aensure_db_setup()
        insert_statement = self.get_insert_statement()
        for k, v in key_value_pairs:
            await aexecute_cql(self.session, insert_statement, parameters=(k, v))

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the rows for ``keys`` with a single ``DELETE ... IN`` query."""
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            return
        self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Async version of :meth:`mdelete`."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            return
        await aexecute_cql(
            self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
        )

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield every key in the table, optionally only those with ``prefix``.

        The prefix filter is applied client-side after a full-table scan.
        """
        self.ensure_db_setup()
        for row in self.session.execute(self._select_keys_cql()):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Async version of :meth:`yield_keys`."""
        await self.aensure_db_setup()
        for row in await aexecute_cql(self.session, self._select_keys_cql()):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key