Spaces:
Runtime error
Runtime error
File size: 2,789 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
"""Wrapper around Dria Retriever."""
from typing import Any, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities import DriaAPIWrapper
class DriaRetriever(BaseRetriever):
    """`Dria` retriever using the DriaAPIWrapper.

    Every operation is delegated to the wrapped ``DriaAPIWrapper``; this class
    only adapts arguments and results to the retriever interface.
    """

    # Wrapper instance that all Dria calls are forwarded to.
    api_wrapper: DriaAPIWrapper

    def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
        """Initialize the DriaRetriever with a DriaAPIWrapper instance.

        Args:
            api_key: The API key for Dria.
            contract_id: The contract ID of the knowledge base to interact with.
            **kwargs: Extra keyword arguments passed through to the parent
                class initializer.
        """
        api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
        super().__init__(api_wrapper=api_wrapper, **kwargs)  # type: ignore[call-arg]

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str = "Unspecified",
        embedding: str = "jina",
    ) -> str:
        """Create a new knowledge base in Dria.

        Args:
            name: The name of the knowledge base.
            description: The description of the knowledge base.
            category: The category of the knowledge base.
            embedding: The embedding model to use for the knowledge base.

        Returns:
            The value returned by the API wrapper, i.e. the ID of the created
            knowledge base.
        """
        return self.api_wrapper.create_knowledge_base(
            name, description, category, embedding
        )

    def add_texts(
        self,
        texts: List,
    ) -> None:
        """Add texts to the Dria knowledge base.

        Args:
            texts: An iterable of mappings, each providing a ``"text"`` entry
                and a ``"metadata"`` entry. Any other keys are ignored.

        Raises:
            KeyError: If a dict entry lacks ``"text"`` or ``"metadata"``.
        """
        # Build the payload first so one-shot iterables are consumed only once
        # and the emptiness check below works for any iterable.
        data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts]
        if not data:
            # Nothing to insert: skip the call to the API wrapper entirely.
            return
        self.api_wrapper.insert_data(data)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents from Dria based on a query.

        Args:
            query: The query string to search for in the knowledge base.
            run_manager: Callback manager for the retriever run (not used by
                this implementation).

        Returns:
            One Document per search result, in the order returned by the API
            wrapper. Each document's metadata holds the result's ``"id"`` and
            ``"score"``.
        """
        results = self.api_wrapper.search(query)
        # NOTE(review): the result's "metadata" field is used as page_content;
        # presumably Dria returns the stored text there -- confirm against the
        # API wrapper's search output.
        return [
            Document(
                page_content=result["metadata"],
                metadata={"id": result["id"], "score": result["score"]},
            )
            for result in results
        ]
|