File size: 2,970 Bytes
b5deaf1
1031c5b
e83b975
b5deaf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1031c5b
b5deaf1
1031c5b
b5deaf1
 
 
 
 
 
 
e83b975
b5deaf1
 
 
e83b975
b5deaf1
 
 
 
e83b975
 
1031c5b
e83b975
 
1031c5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e83b975
1031c5b
e83b975
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Module for retrievers that fetch documents from various sources."""
from importlib import metadata
from venv import logger
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents import Document
from models.chroma import vectorstore

class DocRetriever(BaseRetriever):
    """Retrieve documents from the shared ``vectorstore`` by similarity search.

    Wraps a ``VectorStoreRetriever`` created from ``models.chroma.vectorstore``
    and, for every returned document, copies ``page_content`` into its
    metadata under the ``"content"`` key before handing the documents back.

    Attributes:
        retriever: The underlying similarity retriever; built in ``__init__``
            (``None`` only until construction finishes).
        k: Number of documents requested per query. Default is 10.
    """

    # Populated in __init__; the None default only exists so the field can be
    # declared before the vector-store retriever is built.
    retriever: VectorStoreRetriever = None
    k: int = 10

    def __init__(self, req, k: int = 10) -> None:
        """Build the underlying similarity retriever.

        Args:
            req: Request object. Currently NOT read by this class (the
                site/id metadata filtering that used it is disabled); it is
                kept so existing callers that pass it keep working.
            k: Number of documents to return per query.
        """
        # Pass ``k`` to the model so the ``k`` field matches the value actually
        # used for searching (previously the field always stayed at 10).
        super().__init__(k=k)
        # TODO: re-enable metadata filtering on ``req`` (site / id) if needed.
        self.retriever = vectorstore.as_retriever(
            search_type='similarity',
            search_kwargs={"k": k},
        )

    def _get_relevant_documents(self, query: str, *, run_manager) -> list[Document]:
        """Return documents relevant to ``query``.

        Args:
            query: Text to search the vector store with.
            run_manager: Callback manager for this retriever run; its child
                callbacks are forwarded to the inner retriever when present.

        Returns:
            The retrieved documents, each with ``metadata["content"]`` set to
            its ``page_content``. Returns an empty list if retrieval raises
            ``RuntimeError`` (the error is logged with its traceback).
        """
        # Forward child callbacks so the inner retrieval is traced as part of
        # this run rather than as a detached, top-level run.
        config = (
            {"callbacks": run_manager.get_child()}
            if run_manager is not None
            else None
        )
        try:
            retrieved_docs = self.retriever.invoke(query, config=config)
        except RuntimeError as e:
            import logging

            # Log under this module's logger (not the stdlib ``venv`` logger)
            # and keep the traceback.
            logging.getLogger(__name__).exception(
                "Error retrieving documents: %s", e
            )
            return []

        for doc in retrieved_docs:
            # Expose the text in metadata as well, for consumers that only
            # read metadata.
            doc.metadata['content'] = doc.page_content
        return retrieved_docs