Spaces:
Sleeping
Sleeping
from typing import Any | |
from qdrant_client import QdrantClient, models | |
from uuid import uuid4 | |
from src.config import QDRANT_COLLECTION_NAME | |
class QdrantStore: | |
def __init__(self, client: QdrantClient, collection_config=None): | |
self.client = client | |
self.collection_names = set([i.name for i in client.get_collections().collections]) | |
if collection_config is not None: | |
self.create_collection(collection_config) | |
def create_collection(self, collection_config: dict): | |
collection_name = collection_config["collection_name"] | |
if not self.client.collection_exists(collection_name): | |
self.client.create_collection(**collection_config) | |
self.collection_names.add(collection_name) | |
def _check_collection_name(self, collection_name): | |
if collection_name not in self.collection_names: | |
raise ValueError(f"Collection: {collection_name} does not exist.") | |
def upsert_points(self, | |
vectors: Any | list[Any], | |
payloads: dict | list[dict], | |
collection_name: str): | |
self._check_collection_name(collection_name) | |
ids = [str(uuid4()) for _ in payloads] | |
self.client.upsert( | |
collection_name=collection_name, | |
points=models.Batch( | |
ids=ids, | |
payloads=payloads, | |
vectors=vectors | |
) | |
) | |
def delete_points(self, | |
filters: dict[str, list[models.FieldCondition]], | |
collection_name: str): | |
self._check_collection_name(collection_name) | |
self.client.delete( | |
collection_name=collection_name, | |
points_selector=models.Filter(**filters) | |
) | |
def delete_points_by_match(self, | |
key_value: tuple[str, list[str] | str], | |
collection_name: str): | |
key, values = key_value | |
if isinstance(values, str): | |
values = [values] | |
filter = {"must": [models.FieldCondition(key=key, match=models.MatchAny(any=values))]} | |
self.delete_points(filter, collection_name) | |