Spaces:
Runtime error
Runtime error
File size: 6,657 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
from __future__ import annotations
import asyncio
from asyncio import InvalidStateError, Task
from typing import (
TYPE_CHECKING,
AsyncIterator,
Iterator,
List,
Optional,
Sequence,
Tuple,
)
from langchain_core.stores import ByteStore
from langchain_community.utilities.cassandra import SetupMode, aexecute_cql
if TYPE_CHECKING:
from cassandra.cluster import Session
from cassandra.query import PreparedStatement
# CQL text used by CassandraByteStore. Every template is completed with
# str.format(keyspace=..., table=...); where present, "?" is a bind marker
# whose value is supplied when the statement is executed.

# Schema: a text key column (row_id) and a blob value column (body_blob).
CREATE_TABLE_CQL_TEMPLATE = """
CREATE TABLE IF NOT EXISTS {keyspace}.{table}
(row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));
"""

# Fetch the rows whose key is in the bound collection.
SELECT_TABLE_CQL_TEMPLATE = (
    "SELECT row_id, body_blob "
    "FROM {keyspace}.{table} "
    "WHERE row_id IN ?;"
)

# Fetch every row of the table.
SELECT_ALL_TABLE_CQL_TEMPLATE = "SELECT row_id, body_blob FROM {keyspace}.{table};"

# Write one (row_id, body_blob) pair.
INSERT_TABLE_CQL_TEMPLATE = (
    "INSERT INTO {keyspace}.{table} "
    "(row_id, body_blob) "
    "VALUES (?, ?);"
)

# Remove the rows whose key is in the bound collection.
DELETE_TABLE_CQL_TEMPLATE = "DELETE FROM {keyspace}.{table} WHERE row_id IN ?;"
class CassandraByteStore(ByteStore):
    """Byte store that keeps each key/value pair as one row of a Cassandra table.

    Keys are written to the ``row_id`` text column and values to the
    ``body_blob`` blob column of ``{keyspace}.{table}``.

    Args:
        table: Name of the table holding the key/value rows.
        session: Cassandra driver session. When omitted (or falsy) it is
            resolved through ``cassio.config.check_resolve_session``.
        keyspace: Keyspace containing the table. When omitted (or falsy) it
            is resolved through ``cassio.config.check_resolve_keyspace``.
        setup_mode: ``SetupMode.ASYNC`` schedules the ``CREATE TABLE``
            statement as an asyncio task; any other value runs it
            synchronously inside the constructor.

    Raises:
        ImportError: If ``session`` or ``keyspace`` is missing and the
            ``cassio`` package cannot be imported.
    """

    def __init__(
        self,
        table: str,
        *,
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        if not session or not keyspace:
            # Only the import is guarded, so that an ImportError raised while
            # resolving the session/keyspace is not misreported as a missing
            # cassio installation.
            try:
                from cassio.config import check_resolve_keyspace, check_resolve_session
            except ImportError as exc:
                raise ImportError(
                    "Could not import a recent cassio package. "
                    "Please install it with `pip install --upgrade cassio`."
                ) from exc

            self.keyspace = keyspace or check_resolve_keyspace(keyspace)
            self.session = session or check_resolve_session()
        else:
            self.keyspace = keyspace
            self.session = session

        self.table = table

        # Prepared statements are created lazily by the get_*_statement methods.
        self.select_statement: Optional[PreparedStatement] = None
        self.insert_statement: Optional[PreparedStatement] = None
        self.delete_statement: Optional[PreparedStatement] = None

        create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace,
            table=self.table,
        )
        self.db_setup_task: Optional[Task[None]] = None
        if setup_mode == SetupMode.ASYNC:
            # asyncio.create_task needs a running event loop, so ASYNC mode
            # must be used from within async code.
            self.db_setup_task = asyncio.create_task(
                aexecute_cql(self.session, create_cql)
            )
        else:
            # NOTE(review): every non-ASYNC mode creates the table here; if
            # SetupMode has a "no setup" member, confirm this is intended.
            self.session.execute(create_cql)

    def ensure_db_setup(self) -> None:
        """Check that an asynchronous table setup, if any, has completed.

        Raises:
            ValueError: If the setup task exists but has not finished yet.
                An exception raised by the setup task itself is propagated
                unchanged by ``Task.result()``.
        """
        if self.db_setup_task is None:
            return
        try:
            self.db_setup_task.result()
        except InvalidStateError as exc:
            raise ValueError(
                "Asynchronous setup of the DB not finished. "
                "NB: Cassandra components sync methods shouldn't be called from the "
                "event loop. Consider using their async equivalents."
            ) from exc

    async def aensure_db_setup(self) -> None:
        """Wait for the asynchronous table setup task, if one was scheduled."""
        if self.db_setup_task is not None:
            await self.db_setup_task

    def get_select_statement(self) -> PreparedStatement:
        """Return the prepared multi-key SELECT, preparing it on first use."""
        if self.select_statement is None:
            self.select_statement = self.session.prepare(
                SELECT_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.select_statement

    def get_insert_statement(self) -> PreparedStatement:
        """Return the prepared INSERT, preparing it on first use."""
        if self.insert_statement is None:
            self.insert_statement = self.session.prepare(
                INSERT_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.insert_statement

    def get_delete_statement(self) -> PreparedStatement:
        """Return the prepared multi-key DELETE, preparing it on first use."""
        if self.delete_statement is None:
            self.delete_statement = self.session.prepare(
                DELETE_TABLE_CQL_TEMPLATE.format(
                    keyspace=self.keyspace, table=self.table
                )
            )
        return self.delete_statement

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Return the values for ``keys``, in the same order.

        Keys with no row in the table yield ``None``.
        """
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            # Nothing to look up: skip the round trip to the database.
            return []
        rows = self.session.execute(
            self.get_select_statement(), [ValueSequence(keys)]
        )
        found = {row.row_id: row.body_blob for row in rows}
        return [found.get(key) for key in keys]

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Async version of :meth:`mget`."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            # Nothing to look up: skip the round trip to the database.
            return []
        rows = await aexecute_cql(
            self.session, self.get_select_statement(), parameters=[ValueSequence(keys)]
        )
        found = {row.row_id: row.body_blob for row in rows}
        return [found.get(key) for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Insert every ``(key, value)`` pair, one statement per pair."""
        self.ensure_db_setup()
        if not key_value_pairs:
            # Avoid preparing a statement that would never be executed.
            return
        insert_statement = self.get_insert_statement()
        for key, value in key_value_pairs:
            self.session.execute(insert_statement, (key, value))

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Async version of :meth:`mset`.

        Inserts are awaited one after another, in input order, so that a key
        appearing more than once ends up with the value of its last pair.
        """
        await self.aensure_db_setup()
        if not key_value_pairs:
            # Avoid preparing a statement that would never be executed.
            return
        insert_statement = self.get_insert_statement()
        for key, value in key_value_pairs:
            await aexecute_cql(self.session, insert_statement, parameters=(key, value))

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the rows for ``keys`` with a single statement."""
        from cassandra.query import ValueSequence

        self.ensure_db_setup()
        if not keys:
            # Nothing to delete: skip the round trip to the database.
            return
        self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Async version of :meth:`mdelete`."""
        from cassandra.query import ValueSequence

        await self.aensure_db_setup()
        if not keys:
            # Nothing to delete: skip the round trip to the database.
            return
        await aexecute_cql(
            self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
        )

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield the stored keys, optionally only those starting with ``prefix``.

        The whole table is selected and the prefix filter is applied on the
        client side; a ``None`` or empty prefix yields every key.
        """
        self.ensure_db_setup()
        select_all_cql = SELECT_ALL_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace, table=self.table
        )
        for row in self.session.execute(select_all_cql):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Async version of :meth:`yield_keys`."""
        await self.aensure_db_setup()
        select_all_cql = SELECT_ALL_TABLE_CQL_TEMPLATE.format(
            keyspace=self.keyspace, table=self.table
        )
        for row in await aexecute_cql(self.session, select_all_cql):
            key = row.row_id
            if not prefix or key.startswith(prefix):
                yield key
|