# Spaces:
# Runtime error
# Runtime error
from enum import Enum | |
from typing import Any, Dict, List, Optional, Union | |
import numpy as np | |
from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
from langchain_core.documents import Document | |
from langchain_core.embeddings import Embeddings | |
from langchain_core.retrievers import BaseRetriever | |
from langchain_community.vectorstores.utils import maximal_marginal_relevance | |
class SearchType(str, Enum):
    """Search strategies a retriever can be configured with.

    Each member is also a plain ``str``, so it compares equal to its value.
    """

    # Plain similarity search.
    similarity = "similarity"
    # Maximal marginal relevance search.
    mmr = "mmr"
class DocArrayRetriever(BaseRetriever):
    """`DocArray Document Indices` retriever.

    Currently, it supports 5 backends:
    InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
    ElasticDocIndex, and WeaviateDocumentIndex.

    Args:
        index: One of the above-mentioned index instances
        embeddings: Embedding model to represent text as vectors
        search_field: Field to consider for searching in the documents.
            Should be an embedding/vector/tensor.
        content_field: Field that represents the main content in your document
            schema. Will be used as a `page_content`. The remaining scalar
            fields (str, int, float, bool) go into `metadata`.
        search_type: Type of search to perform (similarity / mmr)
        filters: Filters applied for document retrieval.
        top_k: Number of documents to return
    """

    index: Any
    embeddings: Embeddings
    search_field: str
    content_field: str
    search_type: SearchType = SearchType.similarity
    top_k: int = 1
    filters: Optional[Any] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for
            run_manager: Callback manager for this retriever run (not used here)

        Returns:
            List of relevant documents

        Raises:
            ValueError: If `search_type` is neither similarity nor mmr
        """
        query_emb = np.array(self.embeddings.embed_query(query))

        if self.search_type == SearchType.similarity:
            return self._similarity_search(query_emb)
        if self.search_type == SearchType.mmr:
            return self._mmr_search(query_emb)

        raise ValueError(
            f"Search type {self.search_type} does not exist. "
            "Choose either 'similarity' or 'mmr'."
        )

    @staticmethod
    def _field_value(doc: Union[Dict[str, Any], Any], name: str) -> Any:
        """Read field `name` from a document that is either a dict or an object."""
        return doc[name] if isinstance(doc, dict) else getattr(doc, name)

    def _index_is_backend(self, backend_name: str) -> bool:
        """Tell whether `self.index` is an instance of an optional docarray backend.

        Args:
            backend_name: Class name exported by `docarray.index`,
                e.g. "WeaviateDocumentIndex"

        Returns:
            True if the backend class can be loaded and `self.index` is an
            instance of it, False otherwise
        """
        # A missing docarray installation still raises here, as before.
        import docarray.index as docarray_index

        try:
            backend_cls = getattr(docarray_index, backend_name)
        except ImportError:
            # docarray loads optional backends on attribute access and raises
            # ImportError when the backend's client library is not installed.
            # An index of that backend cannot exist then, so this is a plain
            # "no" rather than a failure of the whole search.
            return False
        return isinstance(self.index, backend_cls)

    def _search(
        self, query_emb: np.ndarray, top_k: int
    ) -> List[Union[Dict[str, Any], Any]]:
        """
        Perform a search using the query embedding and return top_k documents.

        When `filters` is set, a combined vector + filter query is built and
        executed; otherwise a plain vector search is run.

        Args:
            query_emb: Query represented as an embedding
            top_k: Number of documents to return

        Returns:
            A list of top_k documents matching the query
        """
        search_field = self.search_field

        # Each backend names its filter keyword differently.
        if self._index_is_backend("WeaviateDocumentIndex"):
            filter_keyword = "where_filter"
            # NOTE(review): the Weaviate branch searches with an empty field
            # name - presumably that backend does not take one; confirm.
            search_field = ""
        elif self._index_is_backend("ElasticDocIndex"):
            filter_keyword = "query"
        else:
            filter_keyword = "filter_query"

        if not self.filters:
            return self.index.find(
                query=query_emb, search_field=search_field, limit=top_k
            ).documents

        query = (
            self.index.build_query()  # get empty query object
            .find(query=query_emb, search_field=search_field)  # vector search
            .filter(**{filter_keyword: self.filters})  # add filter search
            .build(limit=top_k)  # build the query
        )
        # execute the combined query and return the results
        docs = self.index.execute_query(query)
        # Some results wrap the hits in a `.documents` attribute; unwrap it.
        if hasattr(docs, "documents"):
            docs = docs.documents
        return docs[:top_k]

    def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a similarity search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of documents most similar to the query
        """
        docs = self._search(query_emb=query_emb, top_k=self.top_k)
        return [self._docarray_to_langchain_doc(doc) for doc in docs]

    def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a maximal marginal relevance (mmr) search.

        At least 20 candidates are fetched and then narrowed down to `top_k`
        documents by maximal marginal relevance.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of diverse documents related to the query
        """
        # Never fetch fewer candidates than the number of documents requested,
        # otherwise a top_k above the default pool size could not be satisfied.
        candidate_count = max(self.top_k, 20)
        docs = self._search(query_emb=query_emb, top_k=candidate_count)

        candidate_embeddings = [
            self._field_value(doc, self.search_field) for doc in docs
        ]
        selected = maximal_marginal_relevance(
            query_emb,
            candidate_embeddings,
            k=self.top_k,
        )
        return [self._docarray_to_langchain_doc(docs[idx]) for idx in selected]

    def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
        """
        Convert a DocArray document (which also might be a dict)
        to a langchain document format.

        DocArray document can contain arbitrary fields, so the mapping is done
        in the following way:

        page_content <-> content_field
        metadata <-> all other fields excluding
            tensors and embeddings (so float, int, string)

        Args:
            doc: DocArray document

        Returns:
            Document in langchain format

        Raises:
            ValueError: If the document doesn't contain the content field
        """
        fields = doc.keys() if isinstance(doc, dict) else doc.__fields__

        if self.content_field not in fields:
            raise ValueError(
                f"Document does not contain the content field - {self.content_field}."
            )

        lc_doc = Document(page_content=self._field_value(doc, self.content_field))

        for name in fields:
            if name == self.content_field:
                continue
            value = self._field_value(doc, name)
            # Only plain scalars are copied, which leaves out tensors/embeddings.
            if isinstance(value, (str, int, float, bool)):
                lc_doc.metadata[name] = value

        return lc_doc