File size: 1,486 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import root_validator
from langchain_core.retrievers import BaseRetriever


class MetalRetriever(BaseRetriever):
    """`Metal API` retriever.

    Sends the query text to ``client.search`` and converts each entry of the
    response's ``"data"`` list into a :class:`Document`.
    """

    client: Any
    """The Metal client to use."""
    params: Optional[dict] = None
    """The parameters to pass to the Metal client."""

    @root_validator(pre=True)
    def validate_client(cls, values: dict) -> dict:
        """Validate the client type and normalise ``params`` to a dict.

        Args:
            values: Raw constructor keyword arguments (this validator runs
                with ``pre=True``, i.e. before field validation).

        Returns:
            ``values`` with ``params`` guaranteed not to be ``None``.

        Raises:
            ValueError: If ``client`` is given but is not an instance of
                ``metal_sdk.metal.Metal``.
        """
        # Imported lazily so the module can be imported without metal_sdk.
        from metal_sdk.metal import Metal

        if "client" in values:
            client = values["client"]
            if not isinstance(client, Metal):
                raise ValueError(
                    "Got unexpected client, should be of type metal_sdk.metal.Metal. "
                    f"Instead, got {type(client)}"
                )

        # ``params`` may be missing *or* explicitly passed as None (the field
        # is Optional). Both cases must become {} because the search call
        # unpacks it with ``**``, which rejects None.
        if values.get("params") is None:
            values["params"] = {}

        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search Metal for ``query`` and wrap each hit as a Document.

        Args:
            query: Text to search for; sent as ``{"text": query}``.
            run_manager: Callback manager for this retriever run (not used
                here).

        Returns:
            One Document per entry in the response's ``"data"`` list. The
            entry's ``"text"`` value becomes the page content and every other
            key/value pair becomes metadata.
        """
        # Defensive: ``params`` could still be None if it was reassigned
        # after validation, and ``**None`` would raise TypeError.
        search_kwargs = self.params or {}
        results = self.client.search({"text": query}, **search_kwargs)
        return [
            Document(
                page_content=hit["text"],
                metadata={k: v for k, v in hit.items() if k != "text"},
            )
            for hit in results["data"]
        ]