Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import asyncio | |
import inspect | |
from asyncio import InvalidStateError, Task | |
from enum import Enum | |
from typing import TYPE_CHECKING, Awaitable, Optional, Union | |
if TYPE_CHECKING: | |
from astrapy.db import ( | |
AstraDB, | |
AsyncAstraDB, | |
) | |
class SetupMode(Enum):
    """Enumerate how the Astra DB collection is provisioned at construction.

    The value is consumed by ``_AstraDBCollectionEnvironment.__init__`` to
    decide whether (and how) the collection gets created.
    """

    # Create the collection with blocking calls inside the constructor.
    SYNC = 1
    # Schedule collection creation as an asyncio task from the constructor.
    ASYNC = 2
    # Perform no collection setup in the constructor.
    OFF = 3
class _AstraDBEnvironment: | |
def __init__( | |
self, | |
token: Optional[str] = None, | |
api_endpoint: Optional[str] = None, | |
astra_db_client: Optional[AstraDB] = None, | |
async_astra_db_client: Optional[AsyncAstraDB] = None, | |
namespace: Optional[str] = None, | |
) -> None: | |
self.token = token | |
self.api_endpoint = api_endpoint | |
astra_db = astra_db_client | |
async_astra_db = async_astra_db_client | |
self.namespace = namespace | |
try: | |
from astrapy.db import ( | |
AstraDB, | |
AsyncAstraDB, | |
) | |
except (ImportError, ModuleNotFoundError): | |
raise ImportError( | |
"Could not import a recent astrapy python package. " | |
"Please install it with `pip install --upgrade astrapy`." | |
) | |
# Conflicting-arg checks: | |
if astra_db_client is not None or async_astra_db_client is not None: | |
if token is not None or api_endpoint is not None: | |
raise ValueError( | |
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to " | |
"AstraDBEnvironment if passing 'token' and 'api_endpoint'." | |
) | |
if token and api_endpoint: | |
astra_db = AstraDB( | |
token=token, | |
api_endpoint=api_endpoint, | |
namespace=self.namespace, | |
) | |
async_astra_db = AsyncAstraDB( | |
token=token, | |
api_endpoint=api_endpoint, | |
namespace=self.namespace, | |
) | |
if astra_db: | |
self.astra_db = astra_db | |
if async_astra_db: | |
self.async_astra_db = async_astra_db | |
else: | |
self.async_astra_db = AsyncAstraDB( | |
token=self.astra_db.token, | |
api_endpoint=self.astra_db.base_url, | |
api_path=self.astra_db.api_path, | |
api_version=self.astra_db.api_version, | |
namespace=self.astra_db.namespace, | |
) | |
elif async_astra_db: | |
self.async_astra_db = async_astra_db | |
self.astra_db = AstraDB( | |
token=self.async_astra_db.token, | |
api_endpoint=self.async_astra_db.base_url, | |
api_path=self.async_astra_db.api_path, | |
api_version=self.async_astra_db.api_version, | |
namespace=self.async_astra_db.namespace, | |
) | |
else: | |
raise ValueError( | |
"Must provide 'astra_db_client' or 'async_astra_db_client' or " | |
"'token' and 'api_endpoint'" | |
) | |
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
    """Client environment bound to one collection, with optional setup.

    On top of the clients resolved by the parent class, this exposes
    ``self.collection`` and ``self.async_collection`` handles for
    ``collection_name`` and, depending on ``setup_mode``, creates the
    collection either synchronously or through an asyncio task.
    """

    def __init__(
        self,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding_dimension: Union[int, Awaitable[int], None] = None,
        metric: Optional[str] = None,
    ) -> None:
        """Resolve clients, build collection handles and run the setup.

        Args:
            collection_name: Name of the collection to bind to.
            token: Forwarded to the parent constructor.
            api_endpoint: Forwarded to the parent constructor.
            astra_db_client: Forwarded to the parent constructor.
            async_astra_db_client: Forwarded to the parent constructor.
            namespace: Forwarded to the parent constructor.
            setup_mode: ``SYNC`` creates the collection with blocking calls,
                ``ASYNC`` schedules creation as an asyncio task stored in
                ``self.async_setup_db_task``, any other value skips setup.
            pre_delete_collection: Delete the collection before creating it.
            embedding_dimension: Passed as ``dimension`` when creating the
                collection. May be an awaitable only in ``ASYNC`` mode.
            metric: Passed as ``metric`` when creating the collection.

        Raises:
            ValueError: In ``SYNC`` mode when ``embedding_dimension`` is an
                awaitable. Raised before any collection is deleted.
        """
        from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

        super().__init__(
            token, api_endpoint, astra_db_client, async_astra_db_client, namespace
        )
        self.collection_name = collection_name
        self.collection = AstraDBCollection(
            collection_name=collection_name,
            astra_db=self.astra_db,
        )
        self.async_collection = AsyncAstraDBCollection(
            collection_name=collection_name,
            astra_db=self.async_astra_db,
        )

        # Stays None unless ASYNC mode schedules a setup task below.
        self.async_setup_db_task: Optional[Task] = None
        if setup_mode == SetupMode.ASYNC:
            async_astra_db = self.async_astra_db

            async def _setup_db() -> None:
                if pre_delete_collection:
                    await async_astra_db.delete_collection(collection_name)
                if inspect.isawaitable(embedding_dimension):
                    dimension = await embedding_dimension
                else:
                    dimension = embedding_dimension
                await async_astra_db.create_collection(
                    collection_name, dimension=dimension, metric=metric
                )

            # asyncio.create_task needs a running event loop, so ASYNC mode
            # has to be requested from within one.
            self.async_setup_db_task = asyncio.create_task(_setup_db())
        elif setup_mode == SetupMode.SYNC:
            # Validate before touching the database: a call that is bound to
            # fail must not have deleted the existing collection first.
            if inspect.isawaitable(embedding_dimension):
                # NOTE: the message mentions 'async_setup', which is not a
                # parameter of this constructor; it is kept unchanged for
                # compatibility and corresponds to setup_mode being SYNC.
                raise ValueError(
                    "Cannot use an awaitable embedding_dimension with async_setup "
                    "set to False"
                )
            if pre_delete_collection:
                self.astra_db.delete_collection(collection_name)
            self.astra_db.create_collection(
                collection_name,
                dimension=embedding_dimension,  # type: ignore[arg-type]
                metric=metric,
            )

    def ensure_db_setup(self) -> None:
        """Check, without waiting, that any async setup task has finished.

        Does nothing when no setup task was scheduled. Any exception raised
        by the task's ``result()`` other than ``InvalidStateError``
        propagates unchanged.

        Raises:
            ValueError: If the setup task exists but has not finished yet.
        """
        if self.async_setup_db_task:
            try:
                self.async_setup_db_task.result()
            except InvalidStateError as err:
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: AstraDB components sync methods shouldn't be called from the "
                    "event loop. Consider using their async equivalents."
                ) from err

    async def aensure_db_setup(self) -> None:
        """Await the async setup task, if one was scheduled."""
        if self.async_setup_db_task:
            await self.async_setup_db_task