File size: 9,115 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import contextlib
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

from langchain_core.stores import BaseStore
from sqlalchemy import (
    Engine,
    LargeBinary,
    and_,
    create_engine,
    delete,
    select,
)
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import (
    Mapped,
    Session,
    declarative_base,
    mapped_column,
    sessionmaker,
)

# Declarative base for the ORM model below; SQLStore creates and drops its
# table(s) through ``Base.metadata``.
Base = declarative_base()


def items_equal(x: Any, y: Any) -> bool:
    """Compare two items with the ``==`` operator.

    Args:
        x: Left-hand operand (its ``__eq__`` is consulted first).
        y: Right-hand operand.

    Returns:
        Whatever ``x == y`` evaluates to.
    """
    outcome = x == y
    return outcome


class LangchainKeyValueStores(Base):  # type: ignore[valid-type,misc]
    """Table used to save values.

    Each row holds one binary value under the composite primary key
    ``(namespace, key)``, so several stores can share this table without
    their keys colliding.
    """

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.
    __tablename__ = "langchain_key_value_stores"

    # Store name; first half of the composite primary key.
    namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    # Key within the namespace; second half of the composite primary key.
    key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    # The stored value as a binary blob (LargeBinary), never NULL.
    value: Mapped[bytes] = mapped_column(LargeBinary, index=False, nullable=False)


# This is a patched copy of the original SQLStore.
# It can be removed once the corresponding PR is merged.
class SQLStore(BaseStore[str, bytes]):
    """BaseStore interface that works on an SQL database.

    Values are persisted in the ``langchain_key_value_stores`` table, one row
    per ``(namespace, key)`` pair, so several stores can share the same table.

    The store runs in exactly one mode, decided by its engine: with an
    ``AsyncEngine`` only the ``a*`` methods may be used, with a sync
    ``Engine`` only the sync ones. Calling a method of the other mode raises
    ``ValueError``.

    Examples:
        Create a SQLStore instance and perform operations on it:

        .. code-block:: python

            from langchain_rag.storage import SQLStore

            # Instantiate the SQLStore on an in-memory SQLite database
            sql_store = SQLStore(namespace="test", db_url="sqlite:///:memory:")

            # Create the backing table (once per database)
            sql_store.create_schema()

            # Set values for keys
            sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

            # Get values for keys
            values = sql_store.mget(["key1", "key2"])  # Returns [b"value1", b"value2"]

            # Delete keys
            sql_store.mdelete(["key1"])

            # Iterate over keys
            for key in sql_store.yield_keys():
                print(key)

    """

    def __init__(
        self,
        *,
        namespace: str,
        db_url: Optional[Union[str, Path]] = None,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: Optional[bool] = None,
    ):
        """Create the store.

        Args:
            namespace: Namespace that scopes every key of this store.
            db_url: Database URL used to build a new engine. Mutually
                exclusive with ``engine``.
            engine: Pre-built sync or async engine. Mutually exclusive with
                ``db_url``.
            engine_kwargs: Extra keyword arguments forwarded to
                ``create_engine``/``create_async_engine`` (only used with
                ``db_url``).
            async_mode: With ``db_url``, build an async engine when true
                (default: sync). Ignored when ``engine`` is given; the mode
                then follows the engine's type.

        Raises:
            ValueError: If neither or both of ``db_url`` and ``engine`` are
                given.
        """
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")

        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        _engine: Union[Engine, AsyncEngine]
        if db_url is not None:
            kwargs = engine_kwargs or {}
            if async_mode:
                _engine = create_async_engine(url=str(db_url), **kwargs)
            else:
                _engine = create_engine(url=str(db_url), **kwargs)
        elif engine is not None:
            _engine = engine
        else:
            # Unreachable: the checks above guarantee exactly one is set.
            raise AssertionError("Something went wrong with configuration of engine.")

        # The engine type alone decides sync vs. async mode.
        _session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
        if isinstance(_engine, AsyncEngine):
            self.async_mode = True
            _session_maker = async_sessionmaker(bind=_engine)
        else:
            self.async_mode = False
            _session_maker = sessionmaker(bind=_engine)

        self.engine = _engine
        self.dialect = _engine.dialect.name
        self.session_maker = _session_maker
        self.namespace = namespace

    def create_schema(self) -> None:
        """Create the key/value table if missing (sync engine only).

        Raises:
            ValueError: If the store was built on an ``AsyncEngine``.
        """
        if isinstance(self.engine, AsyncEngine):
            raise ValueError(
                "create_schema() cannot be used with an AsyncEngine. "
                "Please use acreate_schema() instead."
            )
        Base.metadata.create_all(self.engine)

    async def acreate_schema(self) -> None:
        """Create the key/value table if missing (async engine only).

        Raises:
            ValueError: If the store was built on a sync ``Engine``.
        """
        if not isinstance(self.engine, AsyncEngine):
            raise ValueError(
                "acreate_schema() requires an AsyncEngine. "
                "Please use create_schema() instead."
            )
        # engine.begin() commits the DDL when the block exits successfully.
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    def drop(self) -> None:
        """Drop the key/value table (sync engine only).

        Note that this removes the rows of *every* namespace, not just this
        store's, because all namespaces share one table.

        Raises:
            ValueError: If the store was built on an ``AsyncEngine``.
        """
        if isinstance(self.engine, AsyncEngine):
            raise ValueError(
                "drop() cannot be used with an AsyncEngine. "
                "Please use adrop() instead."
            )
        # Bind the engine, not a bare Connection: with a Connection the DDL
        # is never committed (it is rolled back) and the connection leaks.
        Base.metadata.drop_all(bind=self.engine)

    async def adrop(self) -> None:
        """Drop the key/value table (async engine only).

        Like ``drop``, this removes the rows of every namespace.

        Raises:
            ValueError: If the store was built on a sync ``Engine``.
        """
        if not isinstance(self.engine, AsyncEngine):
            raise ValueError(
                "adrop() requires an AsyncEngine. Please use drop() instead."
            )
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Return the values for ``keys`` in order; missing keys give ``None``."""
        result: Dict[str, bytes] = {}
        async with self._make_async_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                and_(
                    LangchainKeyValueStores.key.in_(keys),
                    LangchainKeyValueStores.namespace == self.namespace,
                )
            )
            for v in await session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Return the values for ``keys`` in order; missing keys give ``None``."""
        result: Dict[str, bytes] = {}
        with self._make_sync_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                and_(
                    LangchainKeyValueStores.key.in_(keys),
                    LangchainKeyValueStores.namespace == self.namespace,
                )
            )
            for v in session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Insert or overwrite the given pairs in a single transaction.

        If a key appears more than once, the last value wins (same as
        ``mset``).
        """
        # Deduplicate first: inserting two rows with the same primary key in
        # one flush would fail.
        values: Dict[str, bytes] = dict(key_value_pairs)
        async with self._make_async_session() as session:
            # Upsert as delete-then-insert inside the same transaction.
            await self._amdelete(list(values), session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            await session.commit()

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Insert or overwrite the given pairs in a single transaction.

        If a key appears more than once, the last value wins.
        """
        # Deduplicate first: inserting two rows with the same primary key in
        # one flush would fail.
        values: Dict[str, bytes] = dict(key_value_pairs)
        with self._make_sync_session() as session:
            # Upsert as delete-then-insert inside the same transaction.
            self._mdelete(list(values), session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            session.commit()

    def _mdelete(self, keys: Sequence[str], session: Session) -> None:
        """Issue a DELETE for ``keys`` in this namespace; the caller commits."""
        stmt = delete(LangchainKeyValueStores).filter(
            and_(
                LangchainKeyValueStores.key.in_(keys),
                LangchainKeyValueStores.namespace == self.namespace,
            )
        )
        session.execute(stmt)

    async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
        """Async variant of ``_mdelete``; the caller commits."""
        stmt = delete(LangchainKeyValueStores).filter(
            and_(
                LangchainKeyValueStores.key.in_(keys),
                LangchainKeyValueStores.namespace == self.namespace,
            )
        )
        await session.execute(stmt)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete ``keys`` from this namespace; unknown keys are ignored."""
        with self._make_sync_session() as session:
            self._mdelete(keys, session)
            session.commit()

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Delete ``keys`` from this namespace; unknown keys are ignored."""
        async with self._make_async_session() as session:
            await self._amdelete(keys, session)
            await session.commit()

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield the keys of this namespace, optionally filtered by ``prefix``.

        The prefix test is done in Python (case-sensitive ``str.startswith``)
        rather than with SQL ``LIKE``, whose case sensitivity varies by
        database.
        """
        # Select only the key column so the binary values are never loaded.
        stmt = select(LangchainKeyValueStores.key).filter(
            LangchainKeyValueStores.namespace == self.namespace
        )
        with self._make_sync_session() as session:
            for key in session.scalars(stmt):
                if str(key).startswith(prefix or ""):
                    yield str(key)

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Async variant of ``yield_keys``."""
        # Select only the key column so the binary values are never loaded.
        stmt = select(LangchainKeyValueStores.key).filter(
            LangchainKeyValueStores.namespace == self.namespace
        )
        async with self._make_async_session() as session:
            for key in await session.scalars(stmt):
                if str(key).startswith(prefix or ""):
                    yield str(key)

    @contextlib.contextmanager
    def _make_sync_session(self) -> Generator[Session, None, None]:
        """Make a sync session.

        Raises:
            ValueError: If the store is in async mode.
        """
        if self.async_mode:
            raise ValueError(
                "Attempting to use a sync method when async mode is turned on. "
                "Please use the corresponding async method instead."
            )
        with cast(Session, self.session_maker()) as session:
            yield cast(Session, session)

    @contextlib.asynccontextmanager
    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Make an async session.

        Raises:
            ValueError: If the store is in sync mode.
        """
        if not self.async_mode:
            raise ValueError(
                "Attempting to use an async method when sync mode is turned on. "
                "Please use the corresponding sync method instead."
            )
        async with cast(AsyncSession, self.session_maker()) as session:
            yield cast(AsyncSession, session)