Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
from typing import Any, List | |
from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
from langchain_core.documents import Document | |
from langchain_core.retrievers import BaseRetriever | |
class KayAiRetriever(BaseRetriever):
    """Retriever for Kay.ai datasets.

    To work properly, expects you to have the ``KAY_API_KEY`` env variable set.
    You can get one for free at https://kay.ai/.
    """

    # Client object from the optional ``kay`` package (a
    # ``kay.rag.retrievers.KayRetriever`` when built via ``create``). Typed as
    # ``Any`` because ``kay`` is only imported lazily inside ``create``.
    client: Any
    # Number of documents requested from Kay on each query.
    num_contexts: int

    @classmethod
    def create(
        cls,
        dataset_id: str,
        data_types: List[str],
        num_contexts: int = 6,
    ) -> KayAiRetriever:
        """Create a KayAiRetriever given a Kay dataset id and a list of datasources.

        Args:
            dataset_id: A dataset id category in Kay, like "company".
            data_types: A list of datasources present within a dataset. For
                "company" the corresponding datasources could be
                ["10-K", "10-Q", "8-K", "PressRelease"].
            num_contexts: The number of documents to retrieve on each query.
                Defaults to 6.

        Returns:
            A retriever wrapping a ``kay`` ``KayRetriever`` client.

        Raises:
            ImportError: If the ``kay`` package is not installed.
        """
        # Import lazily so the ``kay`` package is only required when a
        # retriever is actually constructed through this factory.
        try:
            from kay.rag.retrievers import KayRetriever
        except ImportError as e:
            raise ImportError(
                "Could not import kay python package. Please install it with "
                "`pip install kay`.",
            ) from e

        client = KayRetriever(dataset_id, data_types)
        return cls(client=client, num_contexts=num_contexts)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query Kay and convert each returned context into a ``Document``.

        Args:
            query: The search string passed to the Kay client.
            run_manager: Callback manager for this retriever run (unused here).

        Returns:
            One ``Document`` per context that has a ``"chunk_embed_text"``
            value. That text becomes ``page_content``, and every other key of
            the context is copied into ``metadata``. Contexts without the key,
            or whose value is ``None``, are skipped.
        """
        ctxs = self.client.query(query=query, num_context=self.num_contexts)
        docs: List[Document] = []
        for ctx in ctxs:
            page_content = ctx.get("chunk_embed_text")
            if page_content is None:
                continue
            # Build metadata from a filtered copy rather than popping the key,
            # so the dicts returned by the client are left unmodified.
            metadata = {k: v for k, v in ctx.items() if k != "chunk_embed_text"}
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs