File size: 4,555 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

import json
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class VespaRetriever(BaseRetriever):
    """`Vespa` retriever.

    Sends a query body to a Vespa application (via pyvespa) and converts each
    returned hit into a ``Document``. Usually built with :meth:`from_params`.
    """

    app: Any
    """Vespa application to query."""
    body: Dict
    """Body of the query."""
    content_field: str
    """Name of the content field."""
    metadata_fields: Sequence[str]
    """Names of the metadata fields."""

    def _query(self, body: Dict) -> List[Document]:
        """Send ``body`` to the Vespa app and convert the hits into Documents.

        Args:
            body: Query body passed unchanged to ``self.app.query``.

        Returns:
            One Document per hit in ``response.hits``.

        Raises:
            RuntimeError: If the HTTP status code is not 2xx, or if the
                response JSON reports ``errors`` under ``root``.
        """
        response = self.app.query(body)

        if not str(response.status_code).startswith("2"):
            raise RuntimeError(
                "Could not retrieve data from Vespa. Error code: {}".format(
                    response.status_code
                )
            )

        root = response.json.get("root", {})
        if "errors" in root:
            raise RuntimeError(json.dumps(root["errors"]))

        return [self._hit_to_document(hit) for hit in response.hits]

    def _hit_to_document(self, hit: Dict) -> Document:
        """Convert one Vespa hit into a Document without mutating the hit.

        ``page_content`` is the hit's ``self.content_field`` value (``""`` if
        absent). Metadata holds every remaining field when ``metadata_fields``
        is ``"*"``, otherwise just the listed fields (``None`` when missing),
        plus the hit ``id`` under the ``"id"`` key.
        """
        # Copy so the response object's hit dicts are left untouched; tolerate
        # hits that carry no "fields" entry.
        fields = dict(hit.get("fields") or {})
        page_content = fields.pop(self.content_field, "")
        if self.metadata_fields == "*":
            metadata = fields
        else:
            metadata = {mf: fields.get(mf) for mf in self.metadata_fields}
        metadata["id"] = hit["id"]
        return Document(page_content=page_content, metadata=metadata)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the stored query body with ``query`` as the user query text."""
        # Shallow copy is enough: only a top-level key is overwritten.
        body = self.body.copy()
        body["query"] = query
        return self._query(body)

    def get_relevant_documents_with_filter(
        self, query: str, *, _filter: Optional[str] = None
    ) -> List[Document]:
        """Run ``query`` with an extra YQL condition ANDed onto the stored YQL.

        Args:
            query: User query text, sent as the ``query`` body parameter.
            _filter: Optional YQL condition. It is appended as
                `` and (<_filter>)``, so the stored ``yql`` must end in its
                ``where`` clause (as the YQL built by ``from_params`` does).
                The parentheses keep a filter containing ``or`` from escaping
                the ``userQuery()`` conjunction.

        Returns:
            The matching Documents.

        Raises:
            KeyError: If the stored body has no ``"yql"`` entry.
        """
        body = self.body.copy()
        extra = f" and ({_filter})" if _filter else ""
        body["yql"] = body["yql"] + extra
        body["query"] = query
        return self._query(body)

    @classmethod
    def from_params(
        cls,
        url: str,
        content_field: str,
        *,
        k: Optional[int] = None,
        metadata_fields: Union[Sequence[str], Literal["*"]] = (),
        sources: Union[Sequence[str], Literal["*"], None] = None,
        _filter: Optional[str] = None,
        yql: Optional[str] = None,
        **kwargs: Any,
    ) -> VespaRetriever:
        """Instantiate retriever from params.

        Args:
            url (str): Vespa app URL.
            content_field (str): Field in results to return as Document page_content.
            k (Optional[int]): Number of Documents to return. Defaults to None.
            metadata_fields(Sequence[str] or "*"): Fields in results to include in
                document metadata. A single field name string is accepted too.
                Defaults to empty tuple ().
            sources (Sequence[str] or "*" or None): Sources to retrieve
                from. A single source name string is accepted too; None, "*"
                or an empty collection means all sources. Defaults to None.
            _filter (Optional[str]): Document filter condition expressed in YQL.
                It is ANDed, parenthesized, onto ``userQuery()``.
                Defaults to None.
            yql (Optional[str]): Full YQL query to be used as-is. Should not be
                specified if _filter or sources are specified. Defaults to None.
            kwargs (Any): Keyword arguments added to query body.

        Returns:
            VespaRetriever: Instantiated VespaRetriever.

        Raises:
            ImportError: If pyvespa is not installed.
            ValueError: If ``yql`` is combined with ``sources`` or ``_filter``.
        """
        try:
            from vespa.application import Vespa
        except ImportError as e:
            raise ImportError(
                "pyvespa is not installed, please install with `pip install pyvespa`"
            ) from e
        if yql and (sources or _filter):
            raise ValueError(
                "yql should only be specified if both sources and _filter are not "
                "specified."
            )
        # A lone field name would otherwise be iterated character by character
        # both here and in _hit_to_document.
        if isinstance(metadata_fields, str) and metadata_fields != "*":
            metadata_fields = (metadata_fields,)

        app = Vespa(url)
        body = kwargs.copy()
        if metadata_fields == "*":
            # NOTE(review): "short" is presumably a document-summary class in
            # the target application — confirm.
            body["summary"] = "short"
        if not yql:
            # Only build a default query when the caller did not supply one.
            if metadata_fields == "*":
                _fields = "*"
            else:
                _fields = ", ".join([content_field, *(metadata_fields or ())])
            if not sources or sources == "*":
                _sources = "*"
            elif isinstance(sources, str):
                _sources = sources
            else:
                _sources = ", ".join(sources)
            _where = f" and ({_filter})" if _filter else ""
            yql = f"select {_fields} from sources {_sources} where userQuery(){_where}"
        body["yql"] = yql
        if k:
            body["hits"] = k
        return cls(
            app=app,
            body=body,
            content_field=content_field,
            metadata_fields=metadata_fields,
        )