from enum import Enum
from typing import Any, Dict, List, Optional, Union

import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever

from langchain_community.vectorstores.utils import maximal_marginal_relevance


class SearchType(str, Enum):
    """Enumerator of the types of search to perform."""

    similarity = "similarity"
    mmr = "mmr"


class DocArrayRetriever(BaseRetriever):
    """`DocArray Document Indices` retriever.

    Currently, it supports 5 backends:
    InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
    ElasticDocIndex, and WeaviateDocumentIndex.

    Args:
        index: One of the above-mentioned index instances
        embeddings: Embedding model to represent text as vectors
        search_field: Field to consider for searching in the documents.
            Should be an embedding/vector/tensor.
        content_field: Field that represents the main content in your document schema.
            Will be used as the `page_content`. All other scalar fields go
            into `metadata`.
        search_type: Type of search to perform (similarity / mmr)
        filters: Filters applied for document retrieval.
        top_k: Number of documents to return
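
    Example (a minimal sketch using the InMemoryExactNNIndex backend; the
        document schema, field names, and embedding size below are
        illustrative assumptions, not requirements):

        .. code-block:: python

            from docarray import BaseDoc
            from docarray.index import InMemoryExactNNIndex
            from docarray.typing import NdArray

            from langchain_community.embeddings import FakeEmbeddings
            from langchain_community.retrievers import DocArrayRetriever

            embeddings = FakeEmbeddings(size=32)

            # Illustrative schema; any DocArray schema with a vector field works.
            class MyDoc(BaseDoc):
                title: str
                title_embedding: NdArray[32]

            index = InMemoryExactNNIndex[MyDoc]()
            index.index(
                [
                    MyDoc(title=text, title_embedding=embeddings.embed_query(text))
                    for text in ["foo", "bar", "baz"]
                ]
            )

            retriever = DocArrayRetriever(
                index=index,
                embeddings=embeddings,
                search_field="title_embedding",
                content_field="title",
            )
            docs = retriever.invoke("foo")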
    """

    index: Any
    embeddings: Embeddings
    search_field: str
    content_field: str
    search_type: SearchType = SearchType.similarity
    top_k: int = 1
    filters: Optional[Any] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        query_emb = np.array(self.embeddings.embed_query(query))

        if self.search_type == SearchType.similarity:
            results = self._similarity_search(query_emb)
        elif self.search_type == SearchType.mmr:
            results = self._mmr_search(query_emb)
        else:
            raise ValueError(
                f"Search type {self.search_type} does not exist. "
                f"Choose either 'similarity' or 'mmr'."
            )

        return results

    def _search(
        self, query_emb: np.ndarray, top_k: int
    ) -> List[Union[Dict[str, Any], Any]]:
        """
        Perform a search using the query embedding and return top_k documents.

        Args:
            query_emb: Query represented as an embedding
            top_k: Number of documents to return

        Returns:
            A list of top_k documents matching the query
        """

        from docarray.index import ElasticDocIndex, WeaviateDocumentIndex

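        # Each supported backend expects the filter under a different keyword
        # argument; Weaviate additionally expects an empty search_field.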
        filter_args = {}
        search_field = self.search_field
        if isinstance(self.index, WeaviateDocumentIndex):
            filter_args["where_filter"] = self.filters
            search_field = ""
        elif isinstance(self.index, ElasticDocIndex):
            filter_args["query"] = self.filters
        else:
            filter_args["filter_query"] = self.filters

        if self.filters:
            query = (
                self.index.build_query()  # get empty query object
                .find(
                    query=query_emb, search_field=search_field
                )  # add vector similarity search
                .filter(**filter_args)  # add filter search
                .build(limit=top_k)  # build the query
            )
            # execute the combined query and return the results
            docs = self.index.execute_query(query)
            if hasattr(docs, "documents"):
                docs = docs.documents
            docs = docs[:top_k]
        else:
            docs = self.index.find(
                query=query_emb, search_field=search_field, limit=top_k
            ).documents
        return docs

    def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a similarity search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of documents most similar to the query
        """
        docs = self._search(query_emb=query_emb, top_k=self.top_k)
        results = [self._docarray_to_langchain_doc(doc) for doc in docs]
        return results

    def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a maximal marginal relevance (mmr) search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of diverse documents related to the query
        """
        docs = self._search(query_emb=query_emb, top_k=20)

        mmr_selected = maximal_marginal_relevance(
            query_emb,
            [
                doc[self.search_field]
                if isinstance(doc, dict)
                else getattr(doc, self.search_field)
                for doc in docs
            ],
            k=self.top_k,
        )
        results = [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]
        return results

    def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
        """
        Convert a DocArray document (which may also be a dict)
        to the LangChain Document format.

        A DocArray document can contain arbitrary fields, so the mapping is
        done as follows:

        page_content <-> content_field
        metadata <-> all other scalar fields (str, int, float, bool),
            excluding tensors and embeddings

        Args:
            doc: DocArray document

        Returns:
            Document in langchain format

        Raises:
            ValueError: If the document doesn't contain the content field
        """

        fields = doc.keys() if isinstance(doc, dict) else doc.__fields__

        if self.content_field not in fields:
            raise ValueError(
                f"Document does not contain the content field - {self.content_field}."
            )
        lc_doc = Document(
            page_content=doc[self.content_field]
            if isinstance(doc, dict)
            else getattr(doc, self.content_field)
        )

        for name in fields:
            value = doc[name] if isinstance(doc, dict) else getattr(doc, name)
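            # Only scalar fields (str, int, float, bool) become metadata;
            # tensors and embeddings are skipped.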
            if (
                isinstance(value, (str, int, float, bool))
                and name != self.content_field
            ):
                lc_doc.metadata[name] = value

        return lc_doc