""" Retrievers for text chunks. """

import os

from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SpacyTextSplitter,
)
from rerank import BgeRerank
from langchain.retrievers import ContextualCompressionRetriever


_DOCSTORE_KINDS = ("in_memory", "local_storage", "sql")


def _build_docstore(kind, docstore_file):
    """Create the parent-chunk docstore selected by the string ``kind``.

    Raises:
        ValueError: if ``kind`` is not one of ``_DOCSTORE_KINDS``.
    """
    if kind == "in_memory":
        from langchain.storage import InMemoryStore

        return InMemoryStore()
    if kind == "local_storage":
        # Ref: https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
        from langchain.storage.file_system import LocalFileStore
        from langchain.storage._lc_store import create_kv_docstore

        # NOTE(review): docstore_file is used as given (not joined with
        # save_path_root) -- confirm this is intended.
        return create_kv_docstore(LocalFileStore(docstore_file))
    if kind == "sql":
        from langchain_rag.storage import SQLStore

        # TODO: WIP -- namespace and database URL are hard-coded.
        return SQLStore(namespace="test", db_url="sqlite:///parent_retrieval_db.db")
    raise ValueError(
        f"Unknown docstore kind {kind!r}; expected one of {_DOCSTORE_KINDS} "
        "or a docstore instance."
    )


def _save_to_pickle(obj, filename):
    """Serialize ``obj`` to ``filename`` with the highest pickle protocol.

    Only unpickle the resulting file from a trusted location: loading a
    pickle can execute arbitrary code.
    """
    import pickle

    with open(filename, "wb") as file:
        pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)


def get_parent_doc_retriever(
    documents,
    vectorstore,
    add_documents=True,
    docstore="in_memory",
    save_path_root="./",
    docstore_file="store_location",
    save_vectorstore=False,
    save_docstore=False,
    k=10,
    parent_chunk_size=512,
    parent_chunk_overlap=128,
    child_chunk_size=256,
    child_chunk_overlap=64,
):
    """Build a parent-document (small-to-big) retriever.

    Small child chunks are indexed in ``vectorstore``, while the larger
    parent chunks are kept in a separate docstore (see LangChain's
    ``ParentDocumentRetriever``).
    Ref: https://clusteredbytes.pages.dev/posts/2023/langchain-parent-document-retriever/

    Args:
        documents: Documents passed to ``retriever.add_documents`` when
            ``add_documents`` is true.
        vectorstore: Vector store holding the child chunks.
        add_documents: Whether to split and index ``documents`` now.
        docstore: ``"in_memory"``, ``"local_storage"``, ``"sql"`` (WIP), or
            an already-constructed docstore object used as-is.
        save_path_root: Directory for ``faiss_index`` and ``docstore.pkl``.
        docstore_file: Root path of the ``LocalFileStore`` when
            ``docstore == "local_storage"``.
        save_vectorstore: If true, call ``vectorstore.save_local`` (this
            assumes a store exposing ``save_local``, e.g. FAISS).
        save_docstore: If true, pickle the docstore to
            ``<save_path_root>/docstore.pkl``.
        k: Number of child chunks requested from the vector search.
        parent_chunk_size, parent_chunk_overlap: Token sizes for parent chunks.
        child_chunk_size, child_chunk_overlap: Token sizes for child chunks.

    Returns:
        The configured ``ParentDocumentRetriever``.

    Raises:
        ValueError: if ``docstore`` is a string naming an unknown kind.
    """
    from langchain.retrievers import ParentDocumentRetriever

    # A string selects a built-in docstore; any other object is used directly.
    if isinstance(docstore, str):
        docstore = _build_docstore(docstore, docstore_file)

    # Alternative splitters that were tried:
    # RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256)
    # RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=64)
    parent_splitter = SpacyTextSplitter.from_tiktoken_encoder(
        chunk_size=parent_chunk_size,
        chunk_overlap=parent_chunk_overlap,
    )
    child_splitter = SpacyTextSplitter.from_tiktoken_encoder(
        chunk_size=child_chunk_size,
        chunk_overlap=child_chunk_overlap,
    )

    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=docstore,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": k},
    )

    if add_documents:
        retriever.add_documents(documents)

    if save_vectorstore:
        vectorstore.save_local(os.path.join(save_path_root, "faiss_index"))

    if save_docstore:
        _save_to_pickle(docstore, os.path.join(save_path_root, "docstore.pkl"))

    return retriever


def get_rerank_retriever(base_retriever, reranker=None):
    """Wrap ``base_retriever`` in a reranking ``ContextualCompressionRetriever``.

    Args:
        base_retriever: Retriever whose results are reranked.
        reranker: Document compressor to use; defaults to ``BgeRerank()``.

    Returns:
        A ``ContextualCompressionRetriever`` combining both.
    """
    # TODO: validate a user-supplied reranker.
    compressor = BgeRerank() if reranker is None else reranker
    return ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=base_retriever
    )