|
"""Module for retrievers that fetch documents from various sources.""" |
|
import logging
from importlib import metadata
from typing import List, Optional
# NOTE(review): `venv.logger` looks like an accidental auto-import; it is kept
# here only because other parts of the file may still reference it.
from venv import logger

from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever

from models.chroma import vectorstore
|
|
|
class DocRetriever(BaseRetriever):
    """Retrieve documents by similarity search over the Chroma vector store.

    Wraps the retriever built from the module-level ``vectorstore`` imported
    from ``models.chroma`` and copies each hit's ``page_content`` into its
    metadata under the ``"content"`` key.

    Attributes:
        retriever: Underlying vector-store retriever; ``None`` until
            ``__init__`` has built it.
        k: Number of documents requested per query. Defaults to 10.
    """

    retriever: Optional[VectorStoreRetriever] = None
    k: int = 10

    def __init__(self, req, k: int = 10) -> None:
        """Build the similarity retriever.

        Args:
            req: Currently unused; retained so the constructor signature
                stays unchanged for existing callers.
            k: Number of documents to request per query.
        """
        super().__init__()
        # Keep the attribute in sync with the value handed to the vector
        # store; previously ``self.k`` always stayed at the class default.
        self.k = k
        self.retriever = vectorstore.as_retriever(
            search_type='similarity',
            search_kwargs={"k": k},
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager
    ) -> List[Document]:
        """Return the documents most similar to ``query``.

        Each returned document's ``page_content`` is also stored in its
        metadata under ``"content"``.

        Args:
            query: Text to search for.
            run_manager: Supplied by the caller; not used by this method.

        Returns:
            The retrieved documents, or an empty list when the underlying
            retriever raises ``RuntimeError``.
        """
        # Only the retriever call is guarded, so unrelated failures are not
        # silently turned into an empty result.
        try:
            retrieved_docs = self.retriever.invoke(query)
        except RuntimeError as e:
            # Log under this module's logger name rather than the ``venv``
            # logger the file imports.
            logging.getLogger(__name__).error(
                "Error retrieving documents: %s", e
            )
            return []

        for doc in retrieved_docs:
            doc.metadata['content'] = doc.page_content
        return retrieved_docs
|
|