File size: 5,540 Bytes
30eced7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import List

from langchain_community.llms import Replicate    # importing from langchain depricated; use langchain_community for several modules here
from langchain_community.document_loaders import OnlinePDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import CohereEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough


class ElevatedRagChain:
    '''
    Retrieval-augmented generation (RAG) pipeline built from LangChain components.

    PDF documents are downloaded, split into overlapping chunks, and indexed twice:
    in a BM25 keyword retriever and in a FAISS vector store (Cohere embeddings).
    The two retrievers are combined into a weighted ensemble (BM25 retriever + FAISS
    retriever). Their candidates are reranked with Cohere Rerank, and the surviving
    chunks are passed as context to a Llama 2 70B chat model served through Replicate.

    Typical use:
        rag = ElevatedRagChain()
        rag.add_pdfs_to_vectore_store(["https://example.com/paper.pdf"])
        answer = rag.elevated_rag_chain.invoke("What is ...?")

    NOTE(review): the Cohere and Replicate clients presumably read their API keys
    from the environment (e.g. COHERE_API_KEY / REPLICATE_API_TOKEN) -- confirm.
    '''

    # Prompt for the LLM. It is built from explicit lines, not an indented
    # triple-quoted string, so the model does not receive the source code's
    # indentation on every line.
    RAG_PROMPT_TEMPLATE = (
        "Use the following context to provide a detailed technical answer to the user's question.\n"
        'Do not use an introduction similar to "Based on the provided documents, ...", just answer the question.\n'
        "If you don't know the answer, please respond with \"I don't know\".\n"
        "\n"
        "Context:\n"
        "{context}\n"
        "\n"
        "User's question:\n"
        "{question}\n"
    )

    def __init__(self) -> None:
        '''
        Initialize the model reference, embedding function, ensemble weights and top_k value.
        '''
        # Replicate model reference, pinned to a specific version hash
        self.llama2_70b   = 'meta/llama-2-70b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48'
        self.embed_func   = CohereEmbeddings(model="embed-english-light-v3.0")
        # relative weights of the keyword (BM25) and semantic (FAISS) retrievers in the ensemble
        self.bm25_weight  = 0.6
        self.faiss_weight = 0.4
        # chunks returned by each base retriever, and chunks kept after reranking
        self.top_k        = 5


    def add_pdfs_to_vectore_store(
            self,
            pdf_links: List[str],
            chunk_size: int = 1500,
            chunk_overlap: int = 100,
            ) -> None:
        '''
        Load PDFs, chunk them, index the chunks, and (re)build the RAG chain.

        The chunks are indexed in a BM25 retriever and in a FAISS vector store.
        The RAG chain is then rebuilt via build_elevated_rag_system().

        Args:
            pdf_links (List[str]): URLs pointing to the PDF documents to be processed
            chunk_size (int, optional): size of text chunks to split the documents into, defaults to 1500
            chunk_overlap (int, optional): overlap between consecutive chunks, defaults to 100

        Raises:
            ValueError: if ``pdf_links`` is empty or no text could be extracted from the PDFs.
        '''
        if not pdf_links:
            raise ValueError("pdf_links must contain at least one PDF URL.")

        # load pdfs; keep every Document a loader returns instead of only the first one
        # (this also avoids an IndexError when a loader returns nothing)
        self.raw_data = [
            document
            for link in pdf_links
            for document in OnlinePDFLoader(link).load()
        ]

        # chunk text
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.split_data    = self.text_splitter.split_documents(self.raw_data)
        if not self.split_data:
            # BM25 / FAISS construction would otherwise fail on an empty corpus with a less helpful error
            raise ValueError("No text could be extracted from the provided PDFs.")

        # add chunks to BM25 retriever
        self.bm25_retriever   = BM25Retriever.from_documents(self.split_data)
        self.bm25_retriever.k = self.top_k

        # embed and add chunks to vector store
        self.vector_store     = FAISS.from_documents(self.split_data, self.embed_func)
        self.faiss_retriever  = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
        print("All PDFs processed and added to vectore store.")

        # build advanced RAG system
        self.build_elevated_rag_system()
        print("RAG system is built successfully.")

    # Correctly spelled alias; the original name is kept for backward compatibility.
    add_pdfs_to_vector_store = add_pdfs_to_vectore_store


    def build_elevated_rag_system(self) -> None:
        '''
        Build the RAG chain from the retrievers created by add_pdfs_to_vectore_store().

        Components:
        * BM25 + FAISS ensemble retriever
        * Cohere reranker keeping the ``top_k`` most relevant chunks
        * prompt template and Llama 2 model (via Replicate)

        The resulting runnable is stored in ``self.elevated_rag_chain``; invoke it with
        the question string.
        '''
        # combine BM25 and FAISS retrievers into an ensemble retriever
        self.ensemble_retriever = EnsembleRetriever(
            retrievers=[self.bm25_retriever, self.faiss_retriever],
            weights=[self.bm25_weight, self.faiss_weight],
        )

        # rerank the fused candidates, keeping as many chunks as each base retriever returns
        self.reranker = CohereRerank(top_n=self.top_k)
        self.rerank_retriever = ContextualCompressionRetriever(
            base_retriever=self.ensemble_retriever,
            base_compressor=self.reranker,
        )

        # prompt template for the language model
        self.rag_prompt = ChatPromptTemplate.from_template(self.RAG_PROMPT_TEMPLATE)
        # Kept as an attribute for callers that use it. It is not part of the chain:
        # Replicate is a completion-style LLM whose output is already a plain string.
        self.str_output_parser = StrOutputParser()

        # parallel execution: retrieve reranked Documents and pass the question through unchanged
        self.entry_point_and_elevated_retriever = RunnableParallel(
            {
                "context": self.rerank_retriever,
                "question": RunnablePassthrough(),
            }
        )

        # initialize Llama 2 model with specific parameters
        self.llm = Replicate(
            model=self.llama2_70b,
            model_kwargs={"temperature": 0.5, "top_p": 1, "max_new_tokens": 1000},
        )

        # Chain components with LangChain Expression Language (LCEL). The retrieved
        # Documents are joined into plain text before prompting; otherwise the prompt
        # would receive str() of a list of Document objects, metadata included.
        self.elevated_rag_chain = (
            self.entry_point_and_elevated_retriever
            | self._format_context
            | self.rag_prompt
            | self.llm
        )

    @staticmethod
    def _format_context(inputs: dict) -> dict:
        '''
        Return a copy of ``inputs`` with its "context" Documents joined into one string.

        Chunk texts (``page_content``) are separated by a blank line so the model can
        tell them apart. All other keys (e.g. "question") are passed through unchanged.
        '''
        joined = "\n\n".join(document.page_content for document in inputs["context"])
        return {**inputs, "context": joined}