File size: 2,185 Bytes
c48903e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import Any
from qdrant_client import QdrantClient, models
from uuid import uuid4
from src.config import QDRANT_COLLECTION_NAME
class QdrantStore:
    """Thin wrapper around a ``QdrantClient`` that guards collection names.

    The names of collections known to exist are cached in
    ``collection_names``. Every point operation validates its target
    collection against that cache (asking the server on a cache miss)
    before delegating to the client.
    """

    def __init__(self, client: QdrantClient, collection_config: dict | None = None):
        """Snapshot the existing collection names and optionally create one.

        Args:
            client: Client used for every request.
            collection_config: Keyword arguments for
                ``client.create_collection``; must contain
                ``"collection_name"``. Skipped when ``None``.
        """
        self.client = client
        self.collection_names = {
            collection.name for collection in client.get_collections().collections
        }
        if collection_config is not None:
            self.create_collection(collection_config)

    def create_collection(self, collection_config: dict):
        """Create the collection described by *collection_config* if absent.

        Args:
            collection_config: Keyword arguments forwarded verbatim to
                ``client.create_collection``.

        Raises:
            KeyError: If ``"collection_name"`` is missing from the config.
        """
        collection_name = collection_config["collection_name"]
        if not self.client.collection_exists(collection_name):
            self.client.create_collection(**collection_config)
        # Cache the name whether it was just created or already present.
        self.collection_names.add(collection_name)

    def _check_collection_name(self, collection_name):
        """Raise ``ValueError`` unless *collection_name* exists.

        The local cache is consulted first. On a miss the server is asked,
        so collections created after this store was built (for example by
        another client) are accepted and then cached.
        """
        if collection_name in self.collection_names:
            return
        if self.client.collection_exists(collection_name):
            self.collection_names.add(collection_name)
            return
        raise ValueError(f"Collection: {collection_name} does not exist.")

    @staticmethod
    def _as_batch(vectors, payloads):
        """Return ``(vectors, payloads)`` in batch form.

        A single point (``payloads`` given as one ``dict``) is wrapped into
        one-element batches; batch input is returned unchanged.
        """
        if not isinstance(payloads, dict):
            return vectors, payloads
        if isinstance(vectors, dict):
            # Named vectors: wrap the single vector under each name.
            return {name: [vector] for name, vector in vectors.items()}, [payloads]
        return [vectors], [payloads]

    def upsert_points(self,
                      vectors: Any | list[Any],
                      payloads: dict | list[dict],
                      collection_name: str) -> list[str]:
        """Upsert points under freshly generated UUID ids.

        Args:
            vectors: One vector, or a batch of vectors parallel to
                *payloads*.
            payloads: One payload dict, or a list of payload dicts.
            collection_name: Target collection.

        Returns:
            The generated point ids, in the same order as *payloads*.

        Raises:
            ValueError: If the collection does not exist.
        """
        self._check_collection_name(collection_name)
        # Without this, a single dict payload would yield one id per *key*.
        vectors, payloads = self._as_batch(vectors, payloads)
        ids = [str(uuid4()) for _ in payloads]
        self.client.upsert(
            collection_name=collection_name,
            points=models.Batch(
                ids=ids,
                payloads=payloads,
                vectors=vectors,
            ),
        )
        return ids

    def delete_points(self,
                      filters: dict[str, list[models.FieldCondition]],
                      collection_name: str):
        """Delete every point matching *filters*.

        Args:
            filters: Keyword arguments for ``models.Filter`` (e.g.
                ``{"must": [...]}``).
            collection_name: Target collection.

        Raises:
            ValueError: If the collection does not exist.
        """
        self._check_collection_name(collection_name)
        self.client.delete(
            collection_name=collection_name,
            points_selector=models.Filter(**filters),
        )

    def delete_points_by_match(self,
                               key_value: tuple[str, list[str] | str],
                               collection_name: str):
        """Delete points whose payload field matches any of the given values.

        Args:
            key_value: ``(key, values)`` pair; ``values`` may be a single
                string or a list of strings.
            collection_name: Target collection.

        Raises:
            ValueError: If the collection does not exist.
        """
        key, values = key_value
        if isinstance(values, str):
            values = [values]
        condition = models.FieldCondition(key=key, match=models.MatchAny(any=values))
        self.delete_points({"must": [condition]}, collection_name)
|