Andrew committed on
Commit 30eced7 · 1 Parent(s): e8fc33c

Initial commit

Files changed (6)
  1. .gitignore +1 -0
  2. README.md +42 -13
  3. advanced_rag.py +124 -0
  4. app.py +92 -0
  5. packages.txt +2 -0
  6. requirements.txt +20 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ **/.DS_Store
README.md CHANGED
@@ -1,13 +1,42 @@
- ---
- title: Rag Demo With Gradio
- emoji: 👀
- colorFrom: yellow
- colorTo: green
- sdk: gradio
- sdk_version: 4.19.2
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # Advanced RAG System
+
+ This repository contains the code for a Gradio web app that demos a Retrieval-Augmented Generation (RAG) system. The app lets users load multiple documents of their choice into a vector database, submit queries, and receive answers generated by a RAG system that combines hybrid retrieval, reranking, and a large language model.
+
+ ## Features
+
+ #### 1. Dynamic Processing
+ - Users can load multiple source documents of their choice into a vector store in real time.
+ - Users can submit queries, which are processed in real time for retrieval and generation.
+
+ #### 2. PDF Integration
+ - Multiple PDF documents can be loaded into a vector store, enabling the RAG system to retrieve information from a large corpus.
+
+ #### 3. Advanced RAG System
+ Integrates various components, including:
+ - **UI**: Lets users input document URLs and user queries; displays the LLM response.
+ - **Document Loader**: Loads documents from URLs.
+ - **Text Splitter**: Chunks the loaded documents.
+ - **Vector Store**: Embeds the text chunks and adds them to a FAISS vector store; also embeds user queries.
+ - **Retrievers**: Uses an ensemble of BM25 and FAISS retrievers, followed by a Cohere reranker, to retrieve the document chunks most relevant to the user query.
+ - **Language Model**: Uses a Llama 2 large language model to generate responses based on the user query and the retrieved context.
+
+ #### 4. PDF and Query Error Handling
+ - Validates PDF URLs and queries to ensure they are neither empty nor invalid.
+ - Displays error messages for empty queries or issues with the RAG system.
+
+ #### 5. Refresh Mechanism
+ - Instructs users to refresh the page to clear/reset the RAG system.
+
+ ## Installation
+
+ To run this application, you need Python and Gradio installed. Follow these steps:
+
+ 1. Clone this repository to your local machine.
+ 2. Create and activate a virtual environment of your choice (venv, conda, etc.).
+ 3. Install the dependencies by running `pip install -r requirements.txt`.
+ 4. Set the environment variables REPLICATE_API_TOKEN (for the Llama 2 model hosted on replicate.com) and COHERE_API_KEY (for the embedding and reranking services on cohere.com).
+ 5. Start the Gradio app by running `python app.py`.
+
+ ## License
+ MIT license
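
Installation step 4 above assumes REPLICATE_API_TOKEN and COHERE_API_KEY are present in the environment. As a minimal illustrative sketch (not part of this commit), a pre-flight check like the one below fails fast with a clear message if either variable is missing:

```python
# Illustrative pre-flight check; the variable names come from the README,
# everything else here is an assumption, not committed code.
import os

REQUIRED_VARS = ("REPLICATE_API_TOKEN", "COHERE_API_KEY")
missing = [name for name in REQUIRED_VARS if not os.environ.get(name)]
if missing:
    raise RuntimeError(
        f"Missing environment variable(s): {', '.join(missing)}. "
        "Export them before running `python app.py`."
    )
```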
advanced_rag.py ADDED
@@ -0,0 +1,124 @@
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ from typing import List
+
+ from langchain_community.llms import Replicate  # importing from langchain is deprecated; use langchain_community for several modules here
+ from langchain_community.document_loaders import OnlinePDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import CohereEmbeddings
+ from langchain_community.retrievers import BM25Retriever
+ from langchain.retrievers import EnsembleRetriever
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.retrievers.document_compressors import CohereRerank
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.schema import StrOutputParser
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+
+
+ class ElevatedRagChain:
+     '''
+     ElevatedRagChain integrates components from the LangChain library to build
+     an advanced retrieval-augmented generation (RAG) system. It processes documents
+     by loading, chunking, and embedding them, and adds the chunk embeddings to a FAISS
+     vector store for efficient retrieval. The embeddings are used to retrieve relevant
+     document chunks in response to user queries.
+     Chunks are retrieved by an ensemble retriever (BM25 retriever + FAISS retriever)
+     and passed through a Cohere reranker before being used as context
+     for generating answers with a Llama 2 large language model (LLM).
+     '''
+     def __init__(self) -> None:
+         '''
+         Initialize the class with a predefined model, embedding function, retriever weights, and top_k value.
+         '''
+         self.llama2_70b = 'meta/llama-2-70b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48'
+         self.embed_func = CohereEmbeddings(model="embed-english-light-v3.0")
+         self.bm25_weight = 0.6
+         self.faiss_weight = 0.4
+         self.top_k = 5
+
+
+     def add_pdfs_to_vectore_store(
+             self,
+             pdf_links: List,
+             chunk_size: int = 1500,
+     ) -> None:
+         '''
+         Processes PDF documents by loading, chunking, embedding, and adding them to a FAISS vector store,
+         then builds the advanced RAG system.
+         Args:
+             pdf_links (List): list of URLs pointing to the PDF documents to be processed
+             chunk_size (int, optional): size of the text chunks to split the documents into, defaults to 1500
+         '''
+         # load PDFs
+         self.raw_data = [OnlinePDFLoader(doc).load()[0] for doc in pdf_links]
+
+         # chunk text
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=100)
+         self.split_data = self.text_splitter.split_documents(self.raw_data)
+
+         # add chunks to the BM25 retriever
+         self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
+         self.bm25_retriever.k = self.top_k
+
+         # embed and add chunks to the vector store
+         self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
+         self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
+         print("All PDFs processed and added to the vector store.")
+
+         # build the advanced RAG system
+         self.build_elevated_rag_system()
+         print("RAG system built successfully.")
+
+
+     def build_elevated_rag_system(self) -> None:
+         '''
+         Build an advanced RAG system from different components:
+         * BM25 retriever
+         * FAISS vector store retriever
+         * Llama 2 model
+         '''
+         # combine the BM25 and FAISS retrievers into an ensemble retriever
+         self.ensemble_retriever = EnsembleRetriever(
+             retrievers=[self.bm25_retriever, self.faiss_retriever],
+             weights=[self.bm25_weight, self.faiss_weight]
+         )
+
+         # use a reranker to improve retrieval quality
+         self.reranker = CohereRerank(top_n=5)
+         self.rerank_retriever = ContextualCompressionRetriever(  # combine the ensemble retriever and the reranker
+             base_retriever=self.ensemble_retriever,
+             base_compressor=self.reranker,
+         )
+
+         # define the prompt template for the language model
+         RAG_PROMPT_TEMPLATE = """\
+ Use the following context to provide a detailed technical answer to the user's question.
+ Do not use an introduction similar to "Based on the provided documents, ...", just answer the question.
+ If you don't know the answer, please respond with "I don't know".
+
+ Context:
+ {context}
+
+ User's question:
+ {question}
+ """
+         self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
+         self.str_output_parser = StrOutputParser()
+
+         # parallel execution of context retrieval and question passing
+         self.entry_point_and_elevated_retriever = RunnableParallel(
+             {
+                 "context": self.rerank_retriever,
+                 "question": RunnablePassthrough()
+             }
+         )
+
+         # initialize the Llama 2 model with specific parameters
+         self.llm = Replicate(
+             model=self.llama2_70b,
+             model_kwargs={"temperature": 0.5, "top_p": 1, "max_new_tokens": 1000}
+         )
+
+         # chain components to form the final elevated RAG system using the LangChain Expression Language (LCEL)
+         self.elevated_rag_chain = self.entry_point_and_elevated_retriever | self.rag_prompt | self.llm  # | self.str_output_parser
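
For orientation, here is a minimal usage sketch of ElevatedRagChain outside the Gradio app (not part of this commit). The class, method, and attribute names come from advanced_rag.py above; the PDF URL is a placeholder, and the Replicate and Cohere API keys are assumed to be set in the environment:

```python
# Hypothetical driver script for ElevatedRagChain; only the names
# (ElevatedRagChain, add_pdfs_to_vectore_store, elevated_rag_chain)
# come from the committed file. The URL is a placeholder.
from advanced_rag import ElevatedRagChain

chain = ElevatedRagChain()
chain.add_pdfs_to_vectore_store(["https://example.com/sample.pdf"])  # load, chunk, embed, index, build chain
answer = chain.elevated_rag_chain.invoke("What problem does the paper address?")
print(answer)
```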
app.py ADDED
@@ -0,0 +1,92 @@
+ import gradio as gr
+ from advanced_rag import ElevatedRagChain
+
+
+ rag_chain = ElevatedRagChain()
+
+
+ def load_pdfs(pdf_links):
+     if not pdf_links:
+         gr.Warning("Please enter non-empty URLs")
+         return "Please enter non-empty URLs"
+     try:
+         pdf_links = pdf_links.split("\n")  # get the individual PDF links
+         rag_chain.add_pdfs_to_vectore_store(pdf_links)
+         gr.Info("PDFs loaded successfully into a new vector store. If you had an old one, it was overwritten.")
+         return "PDFs loaded successfully into a new vector store. If you had an old one, it was overwritten."
+     except Exception as e:
+         gr.Warning("Could not load PDFs. Are the URLs valid?")
+         print(e)
+         return "Could not load PDFs. Are the URLs valid?"
+
+
+ def submit_query(query):
+     if not query:
+         gr.Warning("Please enter a non-empty query")
+         return "Please enter a non-empty query"
+     if hasattr(rag_chain, 'elevated_rag_chain'):
+         try:
+             response = rag_chain.elevated_rag_chain.invoke(query)
+             return response
+         except Exception as e:
+             gr.Warning("LLM error. Please re-submit your query")
+             print(e)
+             return "LLM error. Please re-submit your query"
+
+     else:
+         gr.Warning("Please load PDFs before submitting a query")
+         return "Please load PDFs before submitting a query"
+
+
+ def reset_app():
+     global rag_chain
+     rag_chain = ElevatedRagChain()  # re-initialize the ElevatedRagChain object
+     gr.Info("App reset successfully. You can now load new PDFs")
+     return "App reset successfully. You can now load new PDFs"
+
+
+ # custom CSS for the app's page elements
+ custom_css = """
+ /* customize buttons */
+ button {
+     background-color: grey !important;
+     font-family: Arial !important;
+     font-weight: bold !important;
+     color: blue !important;
+ }
+
+
+
+ /* customize the background color and use it as "app = gr.Blocks(css=custom_css)" */
+ /* .gradio-container {background-color: #E0F7FA} */
+ """
+
+ # define the Gradio app using Blocks for a flexible layout
+ app = gr.Blocks(css=custom_css)  # theme=gr.themes.Base(), Soft(), Default(), Glass(), Monochrome(): https://www.gradio.app/guides/theming-guide
+
+ with app:
+     gr.Markdown('''# Query your own data
+ ## Llama 2 RAG
+ - Type in one or more URLs of PDF files (one per line) and click Load PDF. Wait until the RAG system is built.
+ - Type your query and click Submit. Once the LLM sends back a response, it is displayed in the Response field.
+ - The system "remembers" the source documents, but has no memory of past user queries.
+ - Click Reset App to clear/reset the RAG system.
+ ''')
+     with gr.Row():
+         with gr.Column():
+             pdf_input = gr.Textbox(label="Enter your PDF URLs (one per line)", placeholder="Enter one URL per line", lines=4)
+             load_button = gr.Button("Load PDF")
+         with gr.Column():
+             query_input = gr.Textbox(label="Enter your query here", placeholder="Type your query", lines=4)
+             submit_button = gr.Button("Submit")
+
+     response_output = gr.Textbox(label="Response", placeholder="Response will appear here", lines=4)
+     reset_button = gr.Button("Reset App")
+
+     load_button.click(load_pdfs, inputs=pdf_input, outputs=response_output)
+     submit_button.click(submit_query, inputs=query_input, outputs=response_output)
+     reset_button.click(reset_app, inputs=None, outputs=response_output)
+
+
+ # run the app
+ app.launch()
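
Once `python app.py` is running, the same handlers can in principle be reached over Gradio's API. The sketch below is not part of this commit: it uses the gradio_client package and assumes the default local URL and Gradio's default behavior of naming API endpoints after the handler functions; both are assumptions, not something the committed code configures explicitly.

```python
# Hypothetical client-side script; the URL, endpoint names, and inputs are
# assumptions based on the app above, not committed code.
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
print(client.predict("https://example.com/sample.pdf", api_name="/load_pdfs"))
print(client.predict("Summarize the document in two sentences.", api_name="/submit_query"))
```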
packages.txt ADDED
@@ -0,0 +1,2 @@
+ libgl1
+ poppler-utils
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ langchain==0.1.6
+ langchain-community==0.0.19
+ langchain_core==0.1.22
+ langchain-openai==0.0.5
+ faiss-cpu==1.7.3
+ huggingface-hub==0.20.1
+ google-generativeai==0.3.2
+ cohere==4.46
+ openai==1.11.1
+ opencv-python==4.9.0.80
+ pdf2image==1.17.0
+ pdfminer-six==20221105
+ pikepdf==8.12.0
+ pypdf==4.0.1
+ rank-bm25==0.2.2
+ replicate==0.23.1
+ tiktoken==0.5.2
+ unstructured==0.12.3
+ unstructured-pytesseract==0.3.12
+ unstructured-inference==0.7.23