Spaces:
Runtime error
Runtime error
File size: 1,985 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from __future__ import annotations
from typing import Any, List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class KayAiRetriever(BaseRetriever):
    """
    Retriever for Kay.ai datasets.

    To work properly, expects you to have KAY_API_KEY env variable set.
    You can get one for free at https://kay.ai/.
    """

    # Client used to run queries; it is called as
    # ``client.query(query=..., num_context=...)``.
    client: Any
    # Number of contexts requested from the client on every query.
    num_contexts: int

    @classmethod
    def create(
        cls,
        dataset_id: str,
        data_types: List[str],
        num_contexts: int = 6,
    ) -> KayAiRetriever:
        """
        Create a KayAiRetriever given a Kay dataset id and a list of datasources.

        Args:
            dataset_id: A dataset id category in Kay, like "company"
            data_types: A list of datasources present within a dataset. For
                "company" the corresponding datasources could be
                ["10-K", "10-Q", "8-K", "PressRelease"].
            num_contexts: The number of documents to retrieve on each query.
                Defaults to 6.

        Returns:
            A retriever wrapping a ``KayRetriever`` client built from
            ``dataset_id`` and ``data_types``.

        Raises:
            ImportError: If the ``kay`` package cannot be imported.
        """
        # Imported lazily so the module can be loaded without ``kay`` installed.
        try:
            from kay.rag.retrievers import KayRetriever
        except ImportError as err:
            # Chain explicitly so the underlying import failure stays visible
            # as ``__cause__`` instead of an incidental "during handling" trace.
            raise ImportError(
                "Could not import kay python package. Please install it with "
                "`pip install kay`.",
            ) from err

        client = KayRetriever(dataset_id, data_types)
        return cls(client=client, num_contexts=num_contexts)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Query the Kay client and convert the returned contexts to documents.

        Args:
            query: The query string forwarded to the client.
            run_manager: Callback manager for this retriever run; not used
                by this implementation.

        Returns:
            One ``Document`` per context that has a non-None
            "chunk_embed_text" value. That value becomes the page content
            and the remaining keys of the context become the metadata.
        """
        # NOTE(review): each context is assumed to be a dict-like mapping with
        # string keys -- confirm against the kay client.
        ctxs = self.client.query(query=query, num_context=self.num_contexts)
        docs = []
        for ctx in ctxs:
            # Work on a shallow copy so the objects returned by the client
            # are never mutated by the pop below.
            metadata = {**ctx}
            page_content = metadata.pop("chunk_embed_text", None)
            if page_content is None:
                # Contexts without text cannot form a document; skip them.
                continue
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs
|