# Spaces: Sleeping  (page-status residue from the original capture; not part of the module)
from typing import Any | |
from qdrant_client import QdrantClient, models | |
from uuid import uuid4 | |
from transformers import PreTrainedModel | |
from src.config import QDRANT_COLLECTION_NAME, QDRANT_URL, QDRANT_API_KEY, EMBEDDING_MODEL | |
from src.embeddings import TextEmbedder | |
class QdrantStore:
    """Convenience wrapper around a ``QdrantClient``.

    The store keeps a local cache of collection names (``collection_names``)
    so that write and delete helpers can reject unknown collections with a
    ``ValueError`` before talking to the server.
    """

    def __init__(self, client: QdrantClient, collection_config: dict | None = None):
        """Wrap *client* and optionally make sure one collection exists.

        Args:
            client: Client used for every operation of this store.
            collection_config: Keyword arguments forwarded to
                ``client.create_collection``; must contain the key
                ``"collection_name"``. When given, ``create_collection`` is
                called with it during construction.
        """
        self.client = client
        # Seed the cache with the collections the server reports right now.
        self.collection_names = {
            collection.name for collection in client.get_collections().collections
        }
        if collection_config is not None:
            self.create_collection(collection_config)

    def create_collection(self, collection_config: dict) -> None:
        """Create the collection described by *collection_config* if missing.

        Args:
            collection_config: Keyword arguments forwarded unchanged to
                ``client.create_collection``; must contain
                ``"collection_name"``.

        Raises:
            KeyError: If ``"collection_name"`` is absent from the config.
        """
        collection_name = collection_config["collection_name"]
        if not self.client.collection_exists(collection_name):
            self.client.create_collection(**collection_config)
        # Whether it was just created or already existed, the collection is
        # present at this point, so the cache must know about it.
        self.collection_names.add(collection_name)

    def _check_collection_name(self, collection_name: str) -> None:
        """Raise ``ValueError`` unless *collection_name* is a known collection.

        The local cache is consulted first. On a cache miss the server is
        asked once, so collections created outside this instance after it
        was constructed are still accepted (and then cached).
        """
        if collection_name in self.collection_names:
            return
        if self.client.collection_exists(collection_name):
            self.collection_names.add(collection_name)
            return
        raise ValueError(f"Collection: {collection_name} does not exist.")

    def upsert_points(self,
                      vectors: Any | list[Any],
                      payloads: dict | list[dict],
                      collection_name: str) -> None:
        """Upsert points with freshly generated UUID ids.

        Args:
            vectors: One vector per payload, or a single vector when
                *payloads* is a single dict.
            payloads: A list of payload dicts, or one payload dict for a
                single point.
            collection_name: Target collection.

        Raises:
            ValueError: If the collection does not exist.
        """
        self._check_collection_name(collection_name)
        # A lone payload dict describes exactly one point. Iterating it
        # directly would yield its keys (one id per key), so wrap both the
        # payload and its vector into one-element batches first.
        if isinstance(payloads, dict):
            payloads = [payloads]
            vectors = [vectors]
        ids = [str(uuid4()) for _ in payloads]
        self.client.upsert(
            collection_name=collection_name,
            points=models.Batch(
                ids=ids,
                payloads=payloads,
                vectors=vectors,
            ),
        )

    def delete_points(self,
                      filters: dict[str, list[models.FieldCondition]],
                      collection_name: str) -> None:
        """Delete every point selected by *filters*.

        Args:
            filters: Keyword arguments for ``models.Filter``, e.g.
                ``{"must": [...]}``.
            collection_name: Collection to delete from.

        Raises:
            ValueError: If the collection does not exist.
        """
        self._check_collection_name(collection_name)
        self.client.delete(
            collection_name=collection_name,
            points_selector=models.Filter(**filters),
        )

    def delete_points_by_match(self,
                               key_value: tuple[str, list[str] | str],
                               collection_name: str) -> None:
        """Delete points whose payload field matches any of the given values.

        Args:
            key_value: ``(key, values)`` pair; *values* may be one string or
                a list of strings.
            collection_name: Collection to delete from.

        Raises:
            ValueError: If the collection does not exist.
        """
        key, values = key_value
        if isinstance(values, str):
            values = [values]
        # Named ``match_filter`` rather than ``filter`` to avoid shadowing
        # the builtin.
        match_filter = {
            "must": [
                models.FieldCondition(key=key, match=models.MatchAny(any=values)),
            ],
        }
        self.delete_points(match_filter, collection_name)

    def get_topk_points_single(self,
                               query: Any | str,
                               collection_name: str,
                               k: int = 5) -> list:
        """Return the ``"text"`` payload field of the top *k* query results.

        Args:
            query: Query forwarded unchanged to ``client.query_points``.
            collection_name: Collection to search.
            k: Maximum number of points requested.

        Returns:
            ``payload["text"]`` for each returned point, in result order.

        Raises:
            KeyError: If a returned payload has no ``"text"`` key.
        """
        response = self.client.query_points(
            collection_name=collection_name,
            query=query,
            limit=k,
        )
        return [point.payload["text"] for point in response.points]
if __name__ == "__main__":
    # Manual smoke test: embed one question and print the closest stored texts.
    store = QdrantStore(QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY))
    embedder = TextEmbedder(modelname=EMBEDDING_MODEL)

    question = "How to filter a dataframe"
    # NOTE(review): presumably embed_text returns one embedding per input and
    # [0] picks the one for this question — confirm against TextEmbedder.
    question_vector = embedder.embed_text(question)[0]

    top_texts = store.get_topk_points_single(
        question_vector,
        collection_name=QDRANT_COLLECTION_NAME,
    )
    print(top_texts)