Spaces:
Sleeping
Sleeping
Andrew
commited on
Commit
·
30eced7
1
Parent(s):
e8fc33c
Initial commit
Browse files- .gitignore +1 -0
- README.md +42 -13
- advanced_rag.py +124 -0
- app.py +92 -0
- packages.txt +2 -0
- requirements.txt +20 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
**/.DS_Store
|
README.md
CHANGED
@@ -1,13 +1,42 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Advanced RAG System
|
3 |
+
|
4 |
+
This repository contains the code for a Gradio web app that demoes a Retrieval-Augmented Generation (RAG) system. This app is designed to allow users to load multiple documents of their choice into a vector database, submit queries, and receive answers generated by a sophisticated RAG system that leverages the latest advancements in natural language processing and information retrieval technologies.
|
5 |
+
|
6 |
+
## Features
|
7 |
+
|
8 |
+
#### 1. Dynamic Processing
|
9 |
+
- Users can load multiple source documents of their choice into a vector store in real-time.
|
10 |
+
- Users can submit queries which are processed in real-time for enhanced retrieval and generation.
|
11 |
+
|
12 |
+
#### 2. PDF Integration
|
13 |
+
- The system allows for the loading of multiple PDF documents into a vector store, enabling the RAG system to retrieve information from a vast corpus.
|
14 |
+
|
15 |
+
#### 3. Advanced RAG System
|
16 |
+
Integrates various components, including:
|
17 |
+
- **UI**: Allows users to input URLs for documents and then input user queries; displays the LLM response.
|
18 |
+
- **Document Loader**: Loads documents from URLs.
|
19 |
+
- **Text Splitter**: Chunks loaded documents.
|
20 |
+
- **Vector Store**: Embeds text chunks and adds them to a FAISS vector store; embeds user queries.
|
21 |
+
- **Retrievers**: Uses an ensemble of BM25 and FAISS retrievers, along with a Cohere reranker, to retrieve relevant document chunks based on user queries.
|
22 |
+
- **Language Model**: Utilizes a Llama 2 large language model for generating responses based on the user query and retrieved context.
|
23 |
+
|
24 |
+
#### 4. PDF and Query Error Handling
|
25 |
+
- Validates PDF URLs and queries to ensure that they are not empty and that they are valid.
|
26 |
+
- Displays error messages for empty queries or issues with the RAG system.
|
27 |
+
|
28 |
+
#### 5. Refresh Mechanism
|
29 |
+
- Instructs users to refresh the page to clear / reset the RAG system.
|
30 |
+
|
31 |
+
## Installation
|
32 |
+
|
33 |
+
To run this application, you need to have Python and Gradio installed. Follow these steps:
|
34 |
+
|
35 |
+
1. Clone this repository to your local machine.
|
36 |
+
2. Create and activate a virtual environment of your choice (venv, conda, etc.).
|
37 |
+
3. Install dependencies from the requirements.txt file by running `pip install -r requirements.txt`.
|
38 |
+
4. Set up environment variables REPLICATE_API_TOKEN (for a Llama 2 model hosted on replicate.com) and COHERE_API_KEY (for embeddings and reranking service on cohere.com)
|
39 |
+
4. Start the Gradio app by running `python rag_gradio_app.py`.
|
40 |
+
|
41 |
+
## Licence
|
42 |
+
MIT license
|
advanced_rag.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from langchain_community.llms import Replicate # importing from langchain depricated; use langchain_community for several modules here
|
6 |
+
from langchain_community.document_loaders import OnlinePDFLoader
|
7 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
8 |
+
from langchain_community.vectorstores import FAISS
|
9 |
+
from langchain_community.embeddings import CohereEmbeddings
|
10 |
+
from langchain_community.retrievers import BM25Retriever
|
11 |
+
from langchain.retrievers import EnsembleRetriever
|
12 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
13 |
+
from langchain.retrievers.document_compressors import CohereRerank
|
14 |
+
from langchain.prompts import ChatPromptTemplate
|
15 |
+
from langchain.schema import StrOutputParser
|
16 |
+
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
|
17 |
+
|
18 |
+
|
19 |
+
class ElevatedRagChain:
|
20 |
+
'''
|
21 |
+
Class ElevatedRagChain integrates various components from the langchain library to build
|
22 |
+
an advanced retrieval-augmented generation (RAG) system designed to process documents
|
23 |
+
by reading in, chunking, embedding, and adding their chunk embeddings to FAISS vector store
|
24 |
+
for efficient retrieval. It uses the embeddings to retrieve relevant document chunks
|
25 |
+
in response to user queries.
|
26 |
+
The chunks are retrieved using an ensemble retriever (BM25 retriever + FAISS retriver)
|
27 |
+
and passed through a Cohere reranker before being used as context
|
28 |
+
for generating answers using a Llama 2 large language model (LLM).
|
29 |
+
'''
|
30 |
+
def __init__(self) -> None:
|
31 |
+
'''
|
32 |
+
Initialize the class with predefined model, embedding function, weights, and top_k value
|
33 |
+
'''
|
34 |
+
self.llama2_70b = 'meta/llama-2-70b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48'
|
35 |
+
self.embed_func = CohereEmbeddings(model="embed-english-light-v3.0")
|
36 |
+
self.bm25_weight = 0.6
|
37 |
+
self.faiss_weight = 0.4
|
38 |
+
self.top_k = 5
|
39 |
+
|
40 |
+
|
41 |
+
def add_pdfs_to_vectore_store(
|
42 |
+
self,
|
43 |
+
pdf_links: List,
|
44 |
+
chunk_size: int=1500,
|
45 |
+
) -> None:
|
46 |
+
'''
|
47 |
+
Processes PDF documents by loading, chunking, embedding, and adding them to a FAISS vector store.
|
48 |
+
Build an advanced RAG system
|
49 |
+
Args:
|
50 |
+
pdf_links (List): list of URLs pointing to the PDF documents to be processed
|
51 |
+
chunk_size (int, optional): size of text chunks to split the documents into, defaults to 1500
|
52 |
+
'''
|
53 |
+
# load pdfs
|
54 |
+
self.raw_data = [ OnlinePDFLoader(doc).load()[0] for doc in pdf_links ]
|
55 |
+
|
56 |
+
# chunk text
|
57 |
+
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=100)
|
58 |
+
self.split_data = self.text_splitter.split_documents(self.raw_data)
|
59 |
+
|
60 |
+
# add chunks to BM25 retriever
|
61 |
+
self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
|
62 |
+
self.bm25_retriever.k = self.top_k
|
63 |
+
|
64 |
+
# embed and add chunks to vectore store
|
65 |
+
self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
|
66 |
+
self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
|
67 |
+
print("All PDFs processed and added to vectore store.")
|
68 |
+
|
69 |
+
# build advanced RAG system
|
70 |
+
self.build_elevated_rag_system()
|
71 |
+
print("RAG system is built successfully.")
|
72 |
+
|
73 |
+
|
74 |
+
def build_elevated_rag_system(self) -> None:
|
75 |
+
'''
|
76 |
+
Build an advanced RAG system from different components:
|
77 |
+
* BM25 retriever
|
78 |
+
* FAISS vector store retriever
|
79 |
+
* Llama 2 model
|
80 |
+
'''
|
81 |
+
# combine BM25 and FAISS retrievers into an ensemble retriever
|
82 |
+
self.ensemble_retriever = EnsembleRetriever(
|
83 |
+
retrievers=[self.bm25_retriever, self.faiss_retriever],
|
84 |
+
weights=[self.bm25_weight, self.faiss_weight]
|
85 |
+
)
|
86 |
+
|
87 |
+
# use reranker to improve retrieval quality
|
88 |
+
self.reranker = CohereRerank(top_n=5)
|
89 |
+
self.rerank_retriever = ContextualCompressionRetriever( # combine ensemble retriever and reranker
|
90 |
+
base_retriever=self.ensemble_retriever,
|
91 |
+
base_compressor=self.reranker,
|
92 |
+
)
|
93 |
+
|
94 |
+
# define prompt template for the language model
|
95 |
+
RAG_PROMPT_TEMPLATE = """\
|
96 |
+
Use the following context to provide a detailed technical answer the user's question.
|
97 |
+
Do not use an introduction similar to "Based on the provided documents, ...", just answer the question.
|
98 |
+
If you don't know the answer, please respond with "I don't know".
|
99 |
+
|
100 |
+
Context:
|
101 |
+
{context}
|
102 |
+
|
103 |
+
User's question:
|
104 |
+
{question}
|
105 |
+
"""
|
106 |
+
self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
|
107 |
+
self.str_output_parser = StrOutputParser()
|
108 |
+
|
109 |
+
# parallel execution of context retrieval and question passing
|
110 |
+
self.entry_point_and_elevated_retriever = RunnableParallel(
|
111 |
+
{
|
112 |
+
"context" : self.rerank_retriever,
|
113 |
+
"question" : RunnablePassthrough()
|
114 |
+
}
|
115 |
+
)
|
116 |
+
|
117 |
+
# initialize Llama 2 model with specific parameters
|
118 |
+
self.llm = Replicate(
|
119 |
+
model=self.llama2_70b,
|
120 |
+
model_kwargs={"temperature": 0.5,"top_p": 1, "max_new_tokens":1000}
|
121 |
+
)
|
122 |
+
|
123 |
+
# chain components to form final elevated RAG system using LangChain Expression Language (LCEL)
|
124 |
+
self.elevated_rag_chain = self.entry_point_and_elevated_retriever | self.rag_prompt | self.llm #| self.str_output_parser
|
app.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from advanced_rag import ElevatedRagChain
|
3 |
+
|
4 |
+
|
5 |
+
rag_chain = ElevatedRagChain()
|
6 |
+
|
7 |
+
|
8 |
+
def load_pdfs(pdf_links):
|
9 |
+
if not pdf_links:
|
10 |
+
gr.Warning("Please enter non-empty URLs")
|
11 |
+
return "Please enter non-empty URLs"
|
12 |
+
try:
|
13 |
+
pdf_links = pdf_links.split("\n") # get individual PDF links
|
14 |
+
rag_chain.add_pdfs_to_vectore_store(pdf_links)
|
15 |
+
gr.Info("PDFs loaded successfully into a new vector store. If you had an old one, it was overwritten.")
|
16 |
+
return "PDFs loaded successfully into a new vector store. If you had an old one, it was overwritten."
|
17 |
+
except Exception as e:
|
18 |
+
gr.Warning("Could not load PDFs. Are URLs valid?")
|
19 |
+
print(e)
|
20 |
+
return "Could not load PDFs. Are URLs valid?"
|
21 |
+
|
22 |
+
|
23 |
+
def submit_query(query):
|
24 |
+
if not query:
|
25 |
+
gr.Warning("Please enter a non-empty query")
|
26 |
+
return "Please enter a non-empty query"
|
27 |
+
if hasattr(rag_chain, 'elevated_rag_chain'):
|
28 |
+
try:
|
29 |
+
response = rag_chain.elevated_rag_chain.invoke(query)
|
30 |
+
return response
|
31 |
+
except Exception as e:
|
32 |
+
gr.Warning("LLM error. Please re-submit your query")
|
33 |
+
print(e)
|
34 |
+
return "LLM error. Please re-submit your query"
|
35 |
+
|
36 |
+
else:
|
37 |
+
gr.Warning("Please load PDFs before submitting a query")
|
38 |
+
return "Please load PDFs before submitting a query"
|
39 |
+
|
40 |
+
|
41 |
+
def reset_app():
|
42 |
+
global rag_chain
|
43 |
+
rag_chain = ElevatedRagChain() # Re-initialize the ElevatedRagChain object
|
44 |
+
gr.Info("App reset successfully. You can now load new PDFs")
|
45 |
+
return "App reset successfully. You can now load new PDFs"
|
46 |
+
|
47 |
+
|
48 |
+
# custom css for different age elements
|
49 |
+
custom_css = """
|
50 |
+
// customize button
|
51 |
+
button {
|
52 |
+
background-color: grey !important;
|
53 |
+
font-family: Arial !important;
|
54 |
+
font-weight: bold !important;
|
55 |
+
color: blue !important;
|
56 |
+
}
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
// customize background color and use it as "app = gr.Blocks(css=custom_css)"
|
61 |
+
//.gradio-container {background-color: #E0F7FA}
|
62 |
+
"""
|
63 |
+
|
64 |
+
# Define the Gradio app using Blocks for a flexible layout
|
65 |
+
app = gr.Blocks(css=custom_css) # theme=gr.themes.Base(), Soft(), Default(), Glass(), Monochrome(): https://www.gradio.app/guides/theming-guide
|
66 |
+
|
67 |
+
with app:
|
68 |
+
gr.Markdown('''# Query your own data
|
69 |
+
## Llama 2 RAG
|
70 |
+
- Type in one or more URLs for PDF files - one per line and click on Load PDFs. Wait until the RAG system is built.
|
71 |
+
- Type your query and click on Submit Query. Once the LLM sends back a reponse, it will be displayed in the Reponse field.
|
72 |
+
- The system "remembers" the source documents, but has no memory of past user queries.
|
73 |
+
- Click on Reset App to clear / reset the RAG system
|
74 |
+
''')
|
75 |
+
with gr.Row():
|
76 |
+
with gr.Column():
|
77 |
+
pdf_input = gr.Textbox(label="Enter your PDF URLs (one per line)", placeholder="Enter one URL per line", lines=4)
|
78 |
+
load_button = gr.Button("Load PDF")
|
79 |
+
with gr.Column():
|
80 |
+
query_input = gr.Textbox(label="Enter your query here", placeholder="Type your query", lines=4)
|
81 |
+
submit_button = gr.Button("Submit")
|
82 |
+
|
83 |
+
response_output = gr.Textbox(label="Response", placeholder="Response will appear here", lines=4)
|
84 |
+
reset_button = gr.Button("Reset App")
|
85 |
+
|
86 |
+
load_button.click(load_pdfs, inputs=pdf_input, outputs=response_output)
|
87 |
+
submit_button.click(submit_query, inputs=query_input, outputs=response_output)
|
88 |
+
reset_button.click(reset_app, inputs=None, outputs=response_output)
|
89 |
+
|
90 |
+
|
91 |
+
# Run the app
|
92 |
+
app.launch()
|
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
libgl1
|
2 |
+
poppler-utils
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain==0.1.6
|
2 |
+
langchain-community==0.0.19
|
3 |
+
langchain_core==0.1.22
|
4 |
+
langchain-openai==0.0.5
|
5 |
+
faiss-cpu==1.7.3
|
6 |
+
huggingface-hub==0.20.1
|
7 |
+
google-generativeai==0.3.2
|
8 |
+
cohere==4.46
|
9 |
+
openai==1.11.1
|
10 |
+
opencv-python==4.9.0.80
|
11 |
+
pdf2image==1.17.0
|
12 |
+
pdfminer-six==20221105
|
13 |
+
pikepdf==8.12.0
|
14 |
+
pypdf==4.0.1
|
15 |
+
rank-bm25==0.2.2
|
16 |
+
replicate==0.23.1
|
17 |
+
tiktoken==0.5.2
|
18 |
+
unstructured==0.12.3
|
19 |
+
unstructured-pytesseract==0.3.12
|
20 |
+
unstructured-inference==0.7.23
|