"""Module for retrievers that fetch documents from various sources."""
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents import Document
from models.chroma import vectorstore

class DocRetriever(BaseRetriever):
    """
    Retrieve documents from the vector store through a VectorStoreRetriever.

    Attributes:
        retriever (VectorStoreRetriever): The underlying retriever used to fetch documents.
        k (int): The maximum number of documents to retrieve. Default is 10.

    Methods:
        __init__(req, k: int = 10) -> None:
            Builds the underlying retriever with a similarity score threshold of 0.1.
            Args:
                req: The incoming request. Currently unused, because the site/id
                    filter that would read it is commented out.
                k (int): The maximum number of documents to retrieve.
        _get_relevant_documents(query: str, *, run_manager) -> list:
            Retrieves relevant documents based on the given query.
            Args:
                query (str): The query string to search for relevant documents.
                run_manager: An object to manage the run (not used in the method).
            Returns:
                list: A list of Document objects whose metadata holds only a
                    "content" key that mirrors page_content.
    """
    retriever: VectorStoreRetriever = None
    k: int = 10

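    def __init__(self, req, k: int = 10) -> None:
        super().__init__()
        # Keep the attribute in sync with the value passed to the retriever.
        self.k = k
        # `req` is only needed by the metadata filter below, which is disabled.
        # _filter={}
        # if req.site != []:
        #     _filter.update({"site": {"$in": req.site}})
        # if req.id != []:
        #     _filter.update({"id": {"$in": req.id}})
        self.retriever = vectorstore.as_retriever(
            search_type='similarity_score_threshold',
            search_kwargs={
                "k": k,
                # "filter": _filter,
                # Drop results whose relevance score is below 0.1.
                "score_threshold": 0.1
            }
        )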

    def _get_relevant_documents(self, query: str, *, run_manager) -> list[Document]:
        # run_manager is part of the BaseRetriever interface but is not used here.
        retrieved_docs = self.retriever.invoke(query)
        doc_lst = []
        for doc in retrieved_docs:
            # date = str(doc.metadata['publishDate'])
            # Re-wrap each result so that only the selected metadata is exposed.
            doc_lst.append(Document(
                page_content=doc.page_content,
                metadata={
                    "content": doc.page_content,
                    # "id": doc.metadata['id'],
                    # "title": doc.metadata['title'],
                    # "site": doc.metadata['site'],
                    # "link": doc.metadata['link'],
                    # "publishDate": doc.metadata['publishDate'].strftime('%Y-%m-%d'),
                    # 'web': False,
                    # "source": "Finfast"
                }
            ))
        return doc_lst
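
# The commented-out `_filter` logic in DocRetriever.__init__ can be sketched as a
# small standalone helper. This is only an illustration under stated assumptions:
# `build_filter` is a hypothetical name that does not exist in this module, and
# the `site` and `ids` arguments stand in for `req.site` and `req.id`. Some
# vector stores may require multiple conditions to be combined explicitly; that
# case is not handled here.

```python
from typing import Any, Sequence


def build_filter(site: Sequence[str] = (), ids: Sequence[str] = ()) -> dict[str, Any]:
    """Build a metadata filter mirroring the disabled `_filter` code (hypothetical helper)."""
    _filter: dict[str, Any] = {}
    # Only add a condition when the corresponding list is non-empty,
    # matching the `!= []` checks in the commented-out code.
    if site:
        _filter["site"] = {"$in": list(site)}
    if ids:
        _filter["id"] = {"$in": list(ids)}
    return _filter


if __name__ == "__main__":
    # An empty request yields an empty filter, i.e. no restriction.
    print(build_filter())
    # A request limited to one site yields a single "$in" condition.
    print(build_filter(site=["example.com"]))
```

# If the filter were re-enabled, the resulting dict would be passed as the
# "filter" entry of `search_kwargs`, as the commented-out line suggests.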