from __future__ import annotations import json import logging import os import uuid from collections import defaultdict from itertools import islice from typing import ( # type: ignore[import-not-found] Any, Dict, Generator, Iterable, Iterator, List, Optional, Sequence, Tuple, Union, ) import numpy as np from langchain_core.documents import Document as LCDocument from langchain_qdrant import QdrantVectorStore, RetrievalMode from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span from qdrant_client import QdrantClient, models from qdrant_client.http.models import Record from .span_embeddings import SpanEmbeddings from .span_vectorstore import SpanVectorStore logger = logging.getLogger(__name__) class QdrantSpanVectorStore(SpanVectorStore, QdrantVectorStore): """An implementation of the SpanVectorStore interface that uses Qdrant as backend for storing and retrieving span embeddings.""" EMBEDDINGS_FILE = "embeddings.npy" PAYLOADS_FILE = "payloads.json" INDEX_FILE = "index.json" def __init__( self, client: QdrantClient, collection_name: str, embedding: SpanEmbeddings, vector_params: Optional[Dict[str, Any]] = None, **kwargs, ): if not client.collection_exists(collection_name): logger.info(f'Collection "{collection_name}" does not exist. Creating it now.') client.create_collection( collection_name=collection_name, vectors_config=models.VectorParams(size=embedding.embedding_dim, **vector_params), ) else: logger.info(f'Collection "{collection_name}" already exists.') super().__init__( client=client, collection_name=collection_name, embedding=embedding, **kwargs ) def __len__(self): return self.client.get_collection(collection_name=self.collection_name).points_count def get_by_ids_with_vectors(self, ids: Sequence[str | int], /) -> List[LCDocument]: results = self.client.retrieve( self.collection_name, ids, with_payload=True, with_vectors=True ) return [ self._document_from_point( result, self.collection_name, self.content_payload_key, self.metadata_payload_key, ) for result in results ] def construct_filter( self, query_span: Union[Span, MultiSpan], metadata_doc_id_key: str, doc_id_whitelist: Optional[Sequence[str]] = None, doc_id_blacklist: Optional[Sequence[str]] = None, ) -> Optional[models.Filter]: """Construct a filter for the retrieval. It should enforce that: - if the span is labeled, the retrieved span has the same label, or - if, in addition, a label mapping is provided, the retrieved span has a label that is in the mapping for the query span's label - if `doc_id_whitelist` is provided, the retrieved span is from a document in the whitelist - if `doc_id_blacklist` is provided, the retrieved span is not from a document in the blacklist Args: query_span: The query span. metadata_doc_id_key: The key in the metadata that holds the document id. doc_id_whitelist: A list of document ids to restrict the retrieval to. doc_id_blacklist: A list of document ids to exclude from the retrieval. Returns: A filter object. """ filter_kwargs = defaultdict(list) # if the span is labeled, enforce that the retrieved span has the same label if isinstance(query_span, (LabeledSpan, LabeledMultiSpan)): if self.label_mapping is not None: target_labels = self.label_mapping.get(query_span.label, []) else: target_labels = [query_span.label] filter_kwargs["must"].append( models.FieldCondition( key=f"metadata.{self.METADATA_SPAN_KEY}.label", match=models.MatchAny(any=target_labels), ) ) elif self.label_mapping is not None: raise TypeError("Label mapping is only supported for labeled spans") if doc_id_blacklist is not None and doc_id_whitelist is not None: overlap = set(doc_id_whitelist) & set(doc_id_blacklist) if len(overlap) > 0: raise ValueError( f"Overlap between doc_id_whitelist and doc_id_blacklist: {overlap}" ) if doc_id_whitelist is not None: filter_kwargs["must"].append( models.FieldCondition( key=f"metadata.{metadata_doc_id_key}", match=( models.MatchValue(value=doc_id_whitelist[0]) if len(doc_id_whitelist) == 1 else models.MatchAny(any=doc_id_whitelist) ), ) ) if doc_id_blacklist is not None: filter_kwargs["must_not"].append( models.FieldCondition( key=f"metadata.{metadata_doc_id_key}", match=( models.MatchValue(value=doc_id_blacklist[0]) if len(doc_id_blacklist) == 1 else models.MatchAny(any=doc_id_blacklist) ), ) ) if len(filter_kwargs) > 0: return models.Filter(**filter_kwargs) return None @classmethod def _document_from_point( cls, scored_point: Any, collection_name: str, content_payload_key: str, metadata_payload_key: str, ) -> LCDocument: metadata = scored_point.payload.get(metadata_payload_key) or {} metadata["_collection_name"] = collection_name if hasattr(scored_point, "score"): metadata[cls.RELEVANCE_SCORE_KEY] = scored_point.score if hasattr(scored_point, "vector"): metadata[cls.METADATA_VECTOR_KEY] = scored_point.vector return LCDocument( id=scored_point.id, page_content=scored_point.payload.get(content_payload_key, ""), metadata=metadata, ) def _build_vectors_with_metadata( self, texts: Iterable[str], metadatas: List[dict], ) -> List[models.VectorStruct]: starts = [metadata[self.METADATA_SPAN_KEY][self.SPAN_START_KEY] for metadata in metadatas] ends = [metadata[self.METADATA_SPAN_KEY][self.SPAN_END_KEY] for metadata in metadatas] if self.retrieval_mode == RetrievalMode.DENSE: batch_embeddings = self.embeddings.embed_document_spans(list(texts), starts, ends) return [ { self.vector_name: vector, } for vector in batch_embeddings ] elif self.retrieval_mode == RetrievalMode.SPARSE: raise ValueError("Sparse retrieval mode is not yet implemented.") elif self.retrieval_mode == RetrievalMode.HYBRID: raise NotImplementedError("Hybrid retrieval mode is not yet implemented.") else: raise ValueError(f"Unknown retrieval mode. {self.retrieval_mode} to build vectors.") def _build_payloads_from_metadata( self, metadatas: Iterable[dict], metadata_payload_key: str, ) -> List[dict]: payloads = [{metadata_payload_key: metadata} for metadata in metadatas] return payloads def _generate_batches( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, ids: Optional[Sequence[str | int]] = None, batch_size: int = 64, ) -> Generator[tuple[list[str | int], list[models.PointStruct]], Any, None]: """Generate batches of points to index. Same as in `QdrantVectorStore` but metadata is used to build vectors and payloads.""" texts_iterator = iter(texts) if metadatas is None: raise ValueError("Metadata must be provided to generate batches.") metadatas_iterator = iter(metadatas) ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) while batch_texts := list(islice(texts_iterator, batch_size)): batch_metadatas = list(islice(metadatas_iterator, batch_size)) batch_ids = list(islice(ids_iterator, batch_size)) points = [ models.PointStruct( id=point_id, vector=vector, payload=payload, ) for point_id, vector, payload in zip( batch_ids, self._build_vectors_with_metadata(batch_texts, metadatas=batch_metadatas), # we do not save the text in the payload because the text is the full # document which is usually already saved in the docstore self._build_payloads_from_metadata( metadatas=batch_metadatas, metadata_payload_key=self.metadata_payload_key, ), ) if vector[self.vector_name] is not None ] yield [point.id for point in points], points def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[models.Filter] = None, search_params: Optional[models.SearchParams] = None, offset: int = 0, score_threshold: Optional[float] = None, consistency: Optional[models.ReadConsistency] = None, **kwargs: Any, ) -> List[Tuple[LCDocument, float]]: """Return docs most similar to query vector. Returns: List of documents most similar to the query text and distance for each. """ query_options = { "collection_name": self.collection_name, "query_filter": filter, "search_params": search_params, "limit": k, "offset": offset, "with_payload": True, "with_vectors": False, "score_threshold": score_threshold, "consistency": consistency, **kwargs, } results = self.client.query_points( query=embedding, using=self.vector_name, **query_options, ).points return [ ( self._document_from_point( result, self.collection_name, self.content_payload_key, self.metadata_payload_key, ), result.score, ) for result in results ] def _as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[Any]]: data, _ = self.client.scroll( collection_name=self.collection_name, with_vectors=True, limit=len(self) ) vectors_np = np.array([point.vector for point in data]) payloads = [point.payload for point in data] emb_ids = [point.id for point in data] return emb_ids, vectors_np, payloads # TODO: or use create_snapshot and restore_snapshot? def _save_to_directory(self, path: str, **kwargs) -> None: indices, vectors, payloads = self._as_indices_vectors_payloads() np.save(os.path.join(path, self.EMBEDDINGS_FILE), vectors) with open(os.path.join(path, self.PAYLOADS_FILE), "w") as f: json.dump(payloads, f, indent=2) with open(os.path.join(path, self.INDEX_FILE), "w") as f: json.dump(indices, f) def _load_from_directory(self, path: str, **kwargs) -> None: with open(os.path.join(path, self.INDEX_FILE), "r") as f: index = json.load(f) embeddings_np: np.ndarray = np.load(os.path.join(path, self.EMBEDDINGS_FILE)) with open(os.path.join(path, self.PAYLOADS_FILE), "r") as f: payloads = json.load(f) points = models.Batch(ids=index, vectors=embeddings_np.tolist(), payloads=payloads) self.client.upsert( collection_name=self.collection_name, points=points, ) logger.info(f"Loaded {len(index)} points into collection {self.collection_name}.") def mget(self, keys: Sequence[str]) -> list[Optional[Record]]: return self.client.retrieve( self.collection_name, ids=keys, with_payload=True, with_vectors=True ) def mset(self, key_value_pairs: Sequence[tuple[str, Record]]) -> None: self.client.upsert( collection_name=self.collection_name, points=models.Batch( ids=[key for key, _ in key_value_pairs], vectors=[value.vector for _, value in key_value_pairs], payloads=[value.payload for _, value in key_value_pairs], ), ) def mdelete(self, keys: Sequence[str]) -> None: self.client.delete(collection_name=self.collection_name, points_selector=keys) def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: for point in self.client.scroll( collection_name=self.collection_name, with_vectors=False, with_payload=False, limit=len(self), )[0]: yield point.id