# RAG-PDF-AI / app.py
# DHEIVER's picture
# Update app.py
# 5435b87 verified
import gradio as gr
import os
from functools import partial
api_token = os.getenv("HF_TOKEN")
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
# Hugging Face repo ids of the chat models selectable in the UI.
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/DeepSeek-R1",
]
# Display names for the model picker: each repo id with its "org/" prefix removed.
list_llm_simple = list(map(os.path.basename, list_llm))
# Load and split PDF document
def load_doc(list_file_path):
    """Read the PDFs at ``list_file_path`` and cut them into text chunks.

    Each path is wrapped in a ``PyPDFLoader``; whatever the loaders return
    is pooled, in order, and handed to a ``RecursiveCharacterTextSplitter``
    configured with ``chunk_size=1024`` and ``chunk_overlap=64``.

    Returns:
        The result of ``split_documents`` over the pooled pages.
    """
    # All loaders are constructed before any of them is asked to load.
    pdf_loaders = list(map(PyPDFLoader, list_file_path))
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return splitter.split_documents(all_pages)
# Create vector database
def create_db(splits):
    """Index ``splits`` in a FAISS vector store.

    The embedding model is a ``HuggingFaceEmbeddings`` instance created
    with its default arguments.

    Returns:
        The vector store built by ``FAISS.from_documents``.
    """
    return FAISS.from_documents(splits, HuggingFaceEmbeddings())
# Initialize database
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Build the vector database from the files uploaded in the UI.

    Args:
        list_file_obj: Entries coming from the file-upload component. May be
            ``None`` (nothing uploaded) and may contain ``None`` entries.
            Each entry is either an object exposing the file path as
            ``.name`` or the path itself.
        progress: Gradio progress tracker; part of the event-handler
            signature, not used in the body.

    Returns:
        A ``(vector_db, status_message)`` tuple.

    Raises:
        gr.Error: If no file was uploaded, or if the uploaded documents
            produced no text chunks to index.
    """
    uploaded = [x for x in (list_file_obj or []) if x is not None]
    if not uploaded:
        # Fail with a readable UI message instead of a TypeError on None.
        raise gr.Error("Please upload at least one PDF document first.")

    list_file_path = [getattr(x, "name", x) for x in uploaded]
    doc_splits = load_doc(list_file_path)
    if not doc_splits:
        # Building an index from zero chunks cannot produce a usable retriever.
        raise gr.Error("No text could be extracted from the uploaded documents.")

    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully! ✅"
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Create the conversational retrieval chain for ``llm_model``.

    Args:
        llm_model: Hugging Face repo id passed to ``HuggingFaceEndpoint``.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: Forwarded to the endpoint as ``max_new_tokens``.
        top_k: Forwarded to the endpoint as ``top_k``.
        vector_db: Vector store; its ``as_retriever()`` supplies the retriever.
        progress: Gradio progress tracker; not used in the body.

    Returns:
        The ``ConversationalRetrievalChain`` built by ``from_llm``.
    """
    # Every model uses the same endpoint configuration; the former
    # per-model branches were identical apart from keyword order.
    # NOTE(review): confirm HuggingFaceEndpoint accepts `max_retries` as-is.
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        timeout=120,
        max_retries=3,
    )
    # The chain is asked to return source documents as well, so the memory
    # is pointed explicitly at the 'answer' output.
    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Event handler that builds the QA chain for the selected model.

    Args:
        llm_option: Index into ``list_llm`` of the chosen model.
        llm_temperature: Temperature forwarded to ``initialize_llmchain``.
        max_tokens: Max-token setting forwarded to ``initialize_llmchain``.
        top_k: Top-k setting forwarded to ``initialize_llmchain``.
        vector_db: Vector store held in UI state; ``None`` until a database
            has been created.
        progress: Gradio progress tracker, passed through unchanged.

    Returns:
        A ``(qa_chain, status_message)`` tuple.

    Raises:
        gr.Error: If no vector database exists yet or no model is selected.
    """
    if vector_db is None:
        # Without this guard the failure surfaces later as an AttributeError
        # on None inside the chain setup.
        raise gr.Error("Please create the vector database before initializing the chatbot.")
    if llm_option is None:
        raise gr.Error("Please select a model before initializing the chatbot.")

    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "QA chain initialized. Chatbot is ready! 🚀"
def format_chat_history(message, chat_history):
    """Flatten chat turns into labelled transcript lines.

    Args:
        message: Current user message. Unused; kept so existing callers
            keep working.
        chat_history: Iterable of two-item ``(user_message, bot_message)``
            turns, or ``None`` (treated as an empty history).

    Returns:
        A list of strings alternating ``"User: ..."`` and
        ``"Assistant: ..."``, in turn order.
    """
    formatted_chat_history = []
    # The UI's clear/reset handlers store None in the chatbot, so a missing
    # history must not crash the loop.
    for user_message, bot_message in chat_history or []:
        formatted_chat_history.extend(
            (f"User: {user_message}", f"Assistant: {bot_message}")
        )
    return formatted_chat_history
def _source_slots(response, slots=3):
    """Return exactly ``slots`` ``(text, page)`` pairs taken from ``response``.

    Each available source document yields its stripped ``page_content`` and
    its ``metadata["page"] + 1``. A slot with no document, or whose document
    lacks that data, is filled with ``("N/A", 0)``.
    """
    documents = []
    if isinstance(response, dict):
        documents = response.get("source_documents") or []

    filled = []
    for document in list(documents)[:slots]:
        try:
            filled.append((document.page_content.strip(), document.metadata["page"] + 1))
        except (AttributeError, KeyError, TypeError):
            # Best effort: one unusable source must not blank the others.
            filled.append(("N/A", 0))
    filled.extend([("N/A", 0)] * (slots - len(filled)))
    return filled


def conversation(qa_chain, message, history, language):
    """Answer ``message`` with the QA chain and report up to three sources.

    Args:
        qa_chain: The chain held in UI state, or ``None`` if the chatbot has
            not been initialized.
        message: The user's question.
        history: Previous ``(user, assistant)`` turns shown in the chatbot;
            ``None`` is treated as empty.
        language: ``"Português"`` selects Portuguese prompts and error
            messages; any other value selects English.

    Returns:
        A 9-tuple: ``qa_chain``, an update that empties the message box, the
        history extended with the new turn, then text and page number for
        each of the three reference slots (``"N/A"`` / ``0`` when absent).
    """
    history = list(history or [])
    in_portuguese = language == "Português"
    response = None

    if qa_chain is None:
        # Say what is actually wrong instead of blaming the model server.
        if in_portuguese:
            response_answer = "Erro: O chatbot ainda não foi inicializado. Crie a base de dados e inicialize o chatbot primeiro."
        else:
            response_answer = "Error: The chatbot is not initialized yet. Create the vector database and initialize the chatbot first."
    else:
        formatted_chat_history = format_chat_history(message, history)
        if in_portuguese:
            prompt = f"Responda em português: {message}"
        else:
            prompt = f"Answer in English: {message}"
        try:
            response = qa_chain.invoke({"question": prompt, "chat_history": formatted_chat_history})
            response_answer = response["answer"]
            # Keep only the text after the last "Helpful Answer:" marker.
            if "Helpful Answer:" in response_answer:
                response_answer = response_answer.split("Helpful Answer:")[-1]
        except Exception as e:
            # UI boundary: show the failure in the chat rather than raising.
            if in_portuguese:
                response_answer = f"Erro: Não foi possível obter resposta do modelo devido a problemas no servidor. Tente novamente mais tarde. ({str(e)})"
            else:
                response_answer = f"Error: Could not get a response from the model due to server issues. Please try again later. ({str(e)})"

    source_fields = [field for slot in _source_slots(response) for field in slot]
    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history, *source_fields)
# Main demo with enhanced UI
def demo():
    """Build the Gradio interface, wire its events, and launch the app.

    The page has a setup column (PDF upload, model choice, generation
    sliders, usage notes) and a chat column (language choice, chatbot,
    message box, three reference slots). ``vector_db`` and ``qa_chain`` are
    kept in ``gr.State`` components. The call ends by launching the queued
    app with ``debug=True``.
    """
    # Custom CSS
    custom_css = """
/* Global styles */
body {
font-family: 'Inter', sans-serif;
color: #333333; /* Dark Gray Text */
background-color: #f7f7f7; /* Light Gray Background */
}
.container {
max-width: 1200px;
margin: 0 auto;
}
/* Header styles */
.header {
text-align: center;
padding: 20px 0;
margin-bottom: 20px;
background: linear-gradient(90deg, #3171c7, #24599b); /* Primary & Secondary Blue */
color: white;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header h1 {
font-size: 2.5rem;
margin: 0;
padding: 0;
}
.header p {
font-size: 1.1rem;
margin: 10px 0 0;
opacity: 0.9;
}
/* Card styles */
.card {
background-color: white;
border-radius: 10px;
padding: 20px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
margin-bottom: 20px;
}
/* Section titles */
.section-title {
font-size: 1.25rem;
font-weight: 600;
margin-bottom: 15px;
color: #3171c7; /* Primary Blue */
display: flex;
align-items: center;
}
.section-title svg {
margin-right: 8px;
}
/* Buttons */
.primary-button {
background: linear-gradient(90deg, #3171c7, #24599b); /* Primary & Secondary Blue */
color: white !important;
border: none !important;
padding: 10px 20px !important;
border-radius: 8px !important;
font-weight: 500 !important;
cursor: pointer;
transition: all 0.2s ease;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
}
.primary-button:hover {
background: linear-gradient(90deg, #24599b, #1d4a83); /* Darker Blue on Hover */
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15) !important;
transform: translateY(-1px);
}
/* Status indicators */
.status {
padding: 8px 12px !important;
border-radius: 6px !important;
font-size: 0.9rem !important;
font-weight: 500 !important;
}
.status-success {
background-color: #d1fae5 !important;
color: #065f46 !important; /* Teal */
}
.status-waiting {
background-color: #fef3c7 !important;
color: #92400e !important;
}
.status-error {
background-color: #fee2e2 !important;
color: #d9534f !important; /* Red */
}
/* Chat container */
.chat-container {
border-radius: 10px !important;
border: 1px solid #e0e0e0 !important; /* Medium Gray Border */
overflow: hidden !important;
}
/* Document upload area */
.upload-area {
border: 2px dashed #d1d5db !important;
border-radius: 8px !important;
padding: 20px !important;
text-align: center !important;
background-color: #f9fafb !important;
transition: all 0.2s ease;
}
.upload-area:hover {
border-color: #3171c7 !important; /* Primary Blue on Hover */
background-color: #eff6ff !important;
}
/* Parameter sliders */
.parameter-slider {
margin-bottom: 15px !important;
}
/* Reference boxes */
.reference-box {
background-color: #f3f4f6 !important;
border-left: 4px solid #3171c7 !important; /* Primary Blue */
padding: 10px 15px !important;
margin-bottom: 10px !important;
border-radius: 4px !important;
}
.reference-box-title {
font-weight: 600 !important;
color: #3171c7 !important; /* Primary Blue */
margin-bottom: 5px !important;
display: flex !important;
justify-content: space-between !important;
}
.page-number {
background-color: #dbeafe !important;
color: #3171c7 !important; /* Primary Blue */
padding: 2px 8px !important;
border-radius: 12px !important;
font-size: 0.8rem !important;
}
/* Responsive adjustments */
@media (max-width: 768px) {
.header h1 {
font-size: 1.8rem;
}
}
"""
    # HTML Components
    header_html = """
<div class="header">
<h1>📚 RAG PDF Chatbot</h1>
<p>Query your documents with AI-powered search and generation</p>
</div>
"""
    upload_html = """
<div class="section-title">
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"></path>
<polyline points="17 8 12 3 7 8"></polyline>
<line x1="12" y1="3" x2="12" y2="15"></line>
</svg>
Upload your PDF documents
</div>
<p>Select one or more PDF files to analyze and chat with.</p>
"""
    model_html = """
<div class="section-title">
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M12 2L2 7l10 5 10-5-10-5z"></path>
<path d="M2 17l10 5 10-5"></path>
<path d="M2 12l10 5 10-5"></path>
</svg>
Select AI Model
</div>
<p>Choose the language model that will process your questions.</p>
"""
    chat_html = """
<div class="section-title">
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"></path>
</svg>
Chat with your Documents
</div>
<p>Ask questions about your uploaded documents to get AI-powered answers.</p>
"""
    reference_html = """
<div class="section-title">
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M2 3h6a4 4 0 0 1 4 4v14a3 3 0 0 0-3-3H2z"></path>
<path d="M22 3h-6a4 4 0 0 0-4 4v14a3 3 0 0 1 3-3h7z"></path>
</svg>
Document References
</div>
<p>These are the relevant sections from your documents that the AI used to generate its response.</p>
"""

    def reset_chat_and_sources():
        """Values that empty the chatbot and the three reference slots."""
        return [None, "", 0, "", 0, "", 0]

    theme = gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate")
    # Named `app` so the Blocks object does not shadow this function's name.
    with gr.Blocks(theme=theme, css=custom_css) as app:
        # State variables
        vector_db = gr.State()
        qa_chain = gr.State()
        # Header
        gr.HTML(header_html)
        with gr.Row():
            # Left column - Setup. Column scales are integers; 2:3 keeps the
            # intended 1:1.5 width ratio between setup and chat.
            with gr.Column(scale=2):
                with gr.Group(elem_classes="card"):
                    gr.HTML(upload_html)
                    # File extensions are written with a leading dot.
                    document = gr.Files(height=200, file_count="multiple", file_types=[".pdf"], interactive=True)
                    db_btn = gr.Button("Create Vector Database", elem_classes="primary-button")
                    db_progress = gr.Textbox(value="Not initialized", show_label=False, elem_classes="status status-waiting")
                with gr.Group(elem_classes="card"):
                    gr.HTML(model_html)
                    # type="index" makes the handler receive a position in list_llm.
                    llm_btn = gr.Radio(list_llm_simple, label="", value=list_llm_simple[0], type="index")
                    with gr.Accordion("Advanced Parameters", open=False):
                        slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", interactive=True, elem_classes="parameter-slider")
                        # NOTE(review): maximum=9192 is not a multiple of step=128 — possibly meant 8192; confirm.
                        slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max Tokens", interactive=True, elem_classes="parameter-slider")
                        slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-K", interactive=True, elem_classes="parameter-slider")
                    qachain_btn = gr.Button("Initialize Chatbot", elem_classes="primary-button")
                    llm_progress = gr.Textbox(value="Not initialized", show_label=False, elem_classes="status status-waiting")
                with gr.Group(elem_classes="card"):
                    gr.Markdown("### Usage Instructions")
                    gr.Markdown("""
1. Upload one or more PDF documents
2. Click "Create Vector Database"
3. Select your preferred AI model
4. Click "Initialize Chatbot"
5. Start asking questions about your documents
**Note:** The system will analyze your documents and use AI to answer questions based on their content.
""")
            # Right column - Chat
            with gr.Column(scale=3):
                with gr.Group(elem_classes="card"):
                    gr.HTML(chat_html)
                    language_selector = gr.Radio(["English", "Português"], label="Response Language", value="English")
                    chatbot = gr.Chatbot(height=400, elem_classes="chat-container")
                    with gr.Row():
                        with gr.Column(scale=4):
                            msg = gr.Textbox(placeholder="Ask a question about your documents...", show_label=False)
                        with gr.Column(scale=1):
                            submit_btn = gr.Button("Send", elem_classes="primary-button")
                    with gr.Row():
                        clear_btn = gr.Button("Clear Chat", scale=1)
                with gr.Group(elem_classes="card"):
                    gr.HTML(reference_html)
                    with gr.Accordion("Document References", open=True):
                        # Three identical slots, collected as
                        # [text1, page1, text2, page2, text3, page3].
                        source_outputs = []
                        for slot_number in (1, 2, 3):
                            gr.Markdown(f"**Reference {slot_number}**", elem_classes="reference-box-title")
                            with gr.Row():
                                source_outputs.append(gr.Textbox(show_label=False, lines=2, elem_classes="reference-box"))
                                source_outputs.append(gr.Number(label="Page", show_label=True, elem_classes="page-number"))

        reset_outputs = [chatbot, *source_outputs]
        chat_inputs = [qa_chain, msg, chatbot, language_selector]
        chat_outputs = [qa_chain, msg, chatbot, *source_outputs]

        # Preprocessing events
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress],
        ).then(reset_chat_and_sources, inputs=None, outputs=reset_outputs, queue=False)

        # Chatbot events: Enter in the textbox and the Send button do the same thing.
        for trigger in (msg.submit, submit_btn.click):
            trigger(conversation, inputs=chat_inputs, outputs=chat_outputs, queue=False)
        clear_btn.click(reset_chat_and_sources, inputs=None, outputs=reset_outputs, queue=False)

    app.queue().launch(debug=True)
# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    demo()