from typing import Optional import numpy from grpc.aio import AioRpcError from httpx import HTTPError from loguru import logger from qdrant_client import AsyncQdrantClient from qdrant_client.http import models from qdrant_client.models import RecommendStrategy from app.Models.api_models.search_api_model import SearchModelEnum, SearchBasisEnum from app.Models.img_data import ImageData from app.Models.query_params import FilterParams from app.Models.search_result import SearchResult from app.Services.lifespan_service import LifespanService from app.config import config, QdrantMode from app.util.retry_deco_async import wrap_object, retry_async class PointNotFoundError(ValueError): def __init__(self, point_id: str): self.point_id = point_id super().__init__(f"Point {point_id} not found.") class VectorDbContext(LifespanService): IMG_VECTOR = "image_vector" TEXT_VECTOR = "text_contain_vector" AVAILABLE_POINT_TYPES = models.Record | models.ScoredPoint | models.PointStruct def __init__(self): match config.qdrant.mode: case QdrantMode.SERVER: self._client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port, grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key, prefer_grpc=config.qdrant.prefer_grpc) wrap_object(self._client, retry_async((AioRpcError, HTTPError))) case QdrantMode.LOCAL: self._client = AsyncQdrantClient(path=config.qdrant.local_path) case QdrantMode.MEMORY: logger.warning("Using in-memory Qdrant client. Data will be lost after application restart. " "This should only be used for testing and debugging.") self._client = AsyncQdrantClient(":memory:") case _: raise ValueError("Invalid Qdrant mode.") self.collection_name = config.qdrant.coll async def on_load(self): if not await self.check_collection(): logger.warning("Collection not found. Initializing...") await self.initialize_collection() async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData: """ Retrieve an item from database by id. Will raise PointNotFoundError if the given ID doesn't exist. :param image_id: The ID to retrieve. :param with_vectors: Whether to retrieve vectors. :return: The retrieved item. """ logger.info("Retrieving item {} from database...", image_id) result = await self._client.retrieve(collection_name=self.collection_name, ids=[image_id], with_payload=True, with_vectors=with_vectors) if len(result) != 1: logger.error("Point not exist.") raise PointNotFoundError(image_id) return self._get_img_data_from_point(result[0]) async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list[ImageData]: """ Retrieve items from the database by IDs. An exception is thrown if there are items in the IDs that do not exist in the database. :param image_id: The list of IDs to retrieve. :param with_vectors: Whether to retrieve vectors. :return: The list of retrieved items. """ logger.info("Retrieving {} items from database...", len(image_id)) result = await self._client.retrieve(collection_name=self.collection_name, ids=image_id, with_payload=True, with_vectors=with_vectors) result_point_ids = {t.id for t in result} missing_point_ids = set(image_id) - result_point_ids if len(missing_point_ids) > 0: logger.error("{} points not exist.", len(missing_point_ids)) raise PointNotFoundError(str(missing_point_ids)) return self._get_img_data_from_points(result) async def validate_ids(self, image_id: list[str]) -> list[str]: """ Validate a list of IDs. Will return a list of valid IDs. :param image_id: The list of IDs to validate. :return: The list of valid IDs. """ logger.info("Validating {} items from database...", len(image_id)) result = await self._client.retrieve(collection_name=self.collection_name, ids=image_id, with_payload=False, with_vectors=False) return [t.id for t in result] async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR, top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]: logger.info("Querying Qdrant... top_k = {}", top_k) result = await self._client.search(collection_name=self.collection_name, query_vector=(query_vector_name, query_vector), query_filter=self._get_filters_by_filter_param(filter_param), limit=top_k, offset=skip, with_payload=True) logger.success("Query completed!") return [self._get_search_result_from_scored_point(t) for t in result] async def querySimilar(self, query_vector_name: str = IMG_VECTOR, search_id: Optional[str] = None, positive_vectors: Optional[list[numpy.ndarray]] = None, negative_vectors: Optional[list[numpy.ndarray]] = None, mode: Optional[SearchModelEnum] = None, with_vectors: bool = False, filter_param: FilterParams | None = None, top_k: int = 10, skip: int = 0) -> list[SearchResult]: _positive_vectors = [t.tolist() for t in positive_vectors] if positive_vectors is not None else [search_id] _negative_vectors = [t.tolist() for t in negative_vectors] if negative_vectors is not None else None _strategy = None if mode is None else (RecommendStrategy.AVERAGE_VECTOR if mode == SearchModelEnum.average else RecommendStrategy.BEST_SCORE) # since only combined_search need return vectors, We can define _combined_search_need_vectors like below _combined_search_need_vectors = [ self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.TEXT_VECTOR] if with_vectors else None logger.info("Querying Qdrant... top_k = {}", top_k) result = await self._client.recommend(collection_name=self.collection_name, using=query_vector_name, positive=_positive_vectors, negative=_negative_vectors, strategy=_strategy, with_vectors=_combined_search_need_vectors, query_filter=self._get_filters_by_filter_param(filter_param), limit=top_k, offset=skip, with_payload=True) logger.success("Query completed!") return [self._get_search_result_from_scored_point(t) for t in result] async def insertItems(self, items: list[ImageData]): logger.info("Inserting {} items into Qdrant...", len(items)) points = [self._get_point_from_img_data(t) for t in items] response = await self._client.upsert(collection_name=self.collection_name, wait=True, points=points) logger.success("Insert completed! Status: {}", response.status) async def deleteItems(self, ids: list[str]): logger.info("Deleting {} items from Qdrant...", len(ids)) response = await self._client.delete(collection_name=self.collection_name, points_selector=models.PointIdsList( points=ids ), ) logger.success("Delete completed! Status: {}", response.status) async def updatePayload(self, new_data: ImageData): """ Update the payload of an existing item in the database. Warning: This method will not update the vector of the item. :param new_data: The new data to update. """ response = await self._client.set_payload(collection_name=self.collection_name, payload=new_data.payload, points=[str(new_data.id)], wait=True) logger.success("Update completed! Status: {}", response.status) async def updateVectors(self, new_points: list[ImageData]): resp = await self._client.update_vectors(collection_name=self.collection_name, points=[self._get_vector_from_img_data(t) for t in new_points], ) logger.success("Update vectors completed! Status: {}", resp.status) async def scroll_points(self, from_id: str | None = None, count=50, with_vectors=False, filter_param: FilterParams | None = None, ) -> tuple[list[ImageData], str]: resp, next_id = await self._client.scroll(collection_name=self.collection_name, limit=count, offset=from_id, with_vectors=with_vectors, scroll_filter=self._get_filters_by_filter_param(filter_param) ) return [self._get_img_data_from_point(t) for t in resp], next_id async def get_counts(self, exact: bool) -> int: resp = await self._client.count(collection_name=self.collection_name, exact=exact) return resp.count async def check_collection(self) -> bool: resp = await self._client.get_collections() resp = [t.name for t in resp.collections] return self.collection_name in resp async def initialize_collection(self): if await self.check_collection(): logger.warning("Collection already exists. Skip initialization.") return logger.info("Initializing database, collection name: {}", self.collection_name) vectors_config = { self.IMG_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE), self.TEXT_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE) } await self._client.create_collection(collection_name=self.collection_name, vectors_config=vectors_config) logger.success("Collection created!") @classmethod def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors: vector = {} if img_data.image_vector is not None: vector[cls.IMG_VECTOR] = img_data.image_vector.tolist() if img_data.text_contain_vector is not None: vector[cls.TEXT_VECTOR] = img_data.text_contain_vector.tolist() return models.PointVectors( id=str(img_data.id), vector=vector ) @classmethod def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct: return models.PointStruct( id=str(img_data.id), payload=img_data.payload, vector=cls._get_vector_from_img_data(img_data).vector ) def _get_img_data_from_point(self, point: AVAILABLE_POINT_TYPES) -> ImageData: return (ImageData .from_payload(point.id, point.payload, image_vector=numpy.array(point.vector[self.IMG_VECTOR], dtype=numpy.float32) if point.vector and self.IMG_VECTOR in point.vector else None, text_contain_vector=numpy.array(point.vector[self.TEXT_VECTOR], dtype=numpy.float32) if point.vector and self.TEXT_VECTOR in point.vector else None )) def _get_img_data_from_points(self, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]: return [self._get_img_data_from_point(t) for t in points] def _get_search_result_from_scored_point(self, point: models.ScoredPoint) -> SearchResult: return SearchResult(img=self._get_img_data_from_point(point), score=point.score) @classmethod def vector_name_for_basis(cls, basis: SearchBasisEnum) -> str: match basis: case SearchBasisEnum.vision: return cls.IMG_VECTOR case SearchBasisEnum.ocr: return cls.TEXT_VECTOR case _: raise ValueError("Invalid basis") @staticmethod def _get_filters_by_filter_param(filter_param: FilterParams | None) -> models.Filter | None: if filter_param is None: return None filters = [] neg_filter = [] if filter_param.min_width is not None and filter_param.min_width > 0: filters.append(models.FieldCondition( key="width", range=models.Range( gte=filter_param.min_width ) )) if filter_param.min_height is not None and filter_param.min_height > 0: filters.append(models.FieldCondition( key="height", range=models.Range( gte=filter_param.min_height ) )) if filter_param.min_ratio is not None: filters.append(models.FieldCondition( key="aspect_ratio", range=models.Range( gte=filter_param.min_ratio, lte=filter_param.max_ratio ) )) if filter_param.starred is not None: filters.append(models.FieldCondition( key="starred", match=models.MatchValue( value=filter_param.starred ) )) if filter_param.ocr_text is not None: filters.append(models.FieldCondition( key="ocr_text_lower", match=models.MatchText( text=filter_param.ocr_text.lower() ) )) if filter_param.categories is not None: filters.append(models.FieldCondition( key="categories", match=models.MatchAny( any=filter_param.categories ) )) if filter_param.categories_negative is not None: neg_filter.append(models.FieldCondition( key="categories", match=models.MatchAny(any=filter_param.categories_negative), )) if not filters and not neg_filter: return None return models.Filter( must=filters, must_not=neg_filter )