Spaces:
Running
Running
import gradio as gr | |
import os | |
from functools import partial | |
api_token = os.getenv("HF_TOKEN") | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain.memory import ConversationBufferMemory | |
from langchain_community.llms import HuggingFaceEndpoint | |
# Hugging Face repo IDs of the LLMs the user can choose from.
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/DeepSeek-R1",
]
# Display names for the UI: the last path component of each repo ID.
list_llm_simple = list(map(os.path.basename, list_llm))
# Load and split PDF document | |
def load_doc(list_file_path):
    """Load the given PDF files and split their pages into chunks.

    Every path in ``list_file_path`` is read with ``PyPDFLoader``; the pages
    of all files are pooled and split by a ``RecursiveCharacterTextSplitter``
    configured with ``chunk_size=1024`` and ``chunk_overlap=64``.

    Returns the list of document chunks produced by the splitter.
    """
    all_pages = [
        page
        for pdf_path in list_file_path
        for page in PyPDFLoader(pdf_path).load()
    ]
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return splitter.split_documents(all_pages)
# Create vector database | |
def create_db(splits):
    """Embed ``splits`` with ``HuggingFaceEmbeddings`` and return a FAISS store built from them."""
    embedding_model = HuggingFaceEmbeddings()
    return FAISS.from_documents(splits, embedding_model)
# Initialize database | |
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Build the vector database from the files uploaded in the UI.

    Args:
        list_file_obj: Uploaded file objects (each exposing its path as
            ``.name``), or ``None`` when nothing has been uploaded.
        progress: Gradio progress tracker; accepted for the event signature
            but not used in this function.

    Returns:
        A ``(vector_db, status_message)`` pair. ``vector_db`` is ``None`` when
        no file was supplied or no text chunks could be produced.
    """
    # Clicking the button without an upload passes None; treat that (and a
    # list holding only None entries) as "nothing to index" instead of
    # crashing while iterating.
    list_file_path = [x.name for x in (list_file_obj or []) if x is not None]
    if not list_file_path:
        return None, "Please upload at least one PDF document first."

    doc_splits = load_doc(list_file_path)
    # An empty chunk list (e.g. PDFs without extractable text) cannot be
    # indexed, so report it rather than letting the store construction fail.
    if not doc_splits:
        return None, "No text could be extracted from the uploaded documents."

    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully! ✅"
# Initialize langchain LLM chain | |
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Create the conversational retrieval chain for the selected model.

    Args:
        llm_model: Hugging Face repo ID passed to ``HuggingFaceEndpoint``.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: Forwarded to the endpoint as ``max_new_tokens``.
        top_k: Forwarded to the endpoint as ``top_k``.
        vector_db: Vector store whose ``as_retriever()`` supplies the context.
        progress: Gradio progress tracker; accepted but not used here.

    Returns:
        A ``ConversationalRetrievalChain`` with buffer memory that also
        returns its source documents.
    """
    # The original code branched on the model name but built the endpoint with
    # identical arguments in both branches, so a single call is equivalent.
    # NOTE(review): confirm the installed HuggingFaceEndpoint accepts max_retries.
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        timeout=120,
        max_retries=3,
    )
    # output_key tells the memory which chain output to store, since the chain
    # returns both "answer" and "source_documents".
    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
# Initialize LLM | |
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Gradio callback: build the QA chain for the model picked in the UI.

    Args:
        llm_option: Index into ``list_llm`` of the selected model.
        llm_temperature: Sampling temperature from the slider.
        max_tokens: Maximum number of new tokens from the slider.
        top_k: Top-k value from the slider.
        vector_db: Vector database held in the UI state, or ``None`` if it
            has not been created yet.
        progress: Gradio progress tracker, forwarded to ``initialize_llmchain``.

    Returns:
        A ``(qa_chain, status_message)`` pair; ``qa_chain`` is ``None`` when
        the vector database is missing.
    """
    # Without a database there is nothing to retrieve from; report it in the
    # status box instead of failing on ``None.as_retriever()``.
    if vector_db is None:
        return None, "Please create the vector database before initializing the chatbot."

    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "QA chain initialized. Chatbot is ready! 🚀"
def format_chat_history(message, chat_history):
    """Flatten chat turns into alternating ``User:`` / ``Assistant:`` lines.

    Args:
        message: The current user message; accepted for signature
            compatibility but not used.
        chat_history: Iterable of ``(user_message, bot_message)`` pairs, or
            ``None``/empty when there is no history yet.

    Returns:
        A list with two strings per turn, in conversation order.
    """
    formatted_chat_history = []
    # ``or []`` lets a missing (None) history behave like an empty one.
    for user_message, bot_message in chat_history or []:
        formatted_chat_history.extend(
            (f"User: {user_message}", f"Assistant: {bot_message}")
        )
    return formatted_chat_history
def _extract_sources(response, count=3):
    """Return ``count`` ``(text, page)`` pairs from a chain response, padded with ``("N/A", 0)``."""
    documents = (response or {}).get("source_documents") or []
    sources = []
    for document in documents[:count]:
        page = document.metadata.get("page")
        # The stored page index is shown plus one; a missing page is shown as 0.
        page_number = page + 1 if isinstance(page, int) else 0
        sources.append((document.page_content.strip(), page_number))
    # Pad so the UI always receives values for all reference slots.
    sources.extend([("N/A", 0)] * (count - len(sources)))
    return sources


def conversation(qa_chain, message, history, language):
    """Gradio callback: answer ``message`` with the QA chain and update the UI.

    Args:
        qa_chain: The initialized chain from the UI state, or ``None`` if the
            chatbot has not been initialized.
        message: The user's question.
        history: Existing chat turns as ``(user, bot)`` pairs (may be ``None``).
        language: ``"Português"`` for Portuguese answers; anything else
            selects English.

    Returns:
        A 9-tuple: the chain, an update clearing the message box, the new
        history, and three ``text, page`` pairs for the reference widgets.
        Unavailable references are ``"N/A"`` with page ``0``.
    """
    history = history or []
    portuguese = language == "Português"
    # Stays None when the chain is missing or the call fails, so source
    # extraction below falls back to placeholders without relying on a
    # NameError being swallowed.
    response = None

    if qa_chain is None:
        if portuguese:
            response_answer = "Erro: o chatbot ainda não foi inicializado. Crie a base de dados e inicialize o chatbot primeiro."
        else:
            response_answer = "Error: The chatbot is not initialized yet. Create the vector database and initialize the chatbot first."
    else:
        formatted_chat_history = format_chat_history(message, history)
        if portuguese:
            prompt = f"Responda em português: {message}"
        else:
            prompt = f"Answer in English: {message}"
        try:
            response = qa_chain.invoke({"question": prompt, "chat_history": formatted_chat_history})
            response_answer = response["answer"]
            # Keep only the text after the last "Helpful Answer:" marker when
            # the model output contains it.
            if "Helpful Answer:" in response_answer:
                response_answer = response_answer.split("Helpful Answer:")[-1]
        except Exception as e:
            # UI boundary: surface the failure in the chat instead of raising.
            if portuguese:
                response_answer = f"Erro: Não foi possível obter resposta do modelo devido a problemas no servidor. Tente novamente mais tarde. ({str(e)})"
            else:
                response_answer = f"Error: Could not get a response from the model due to server issues. Please try again later. ({str(e)})"

    (
        (response_source1, response_source1_page),
        (response_source2, response_source2_page),
        (response_source3, response_source3_page),
    ) = _extract_sources(response)

    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
# Main demo with enhanced UI | |
def demo():
    """Build the Gradio Blocks UI, wire its events, and launch the app.

    The layout has a setup column (PDF upload, model selection, usage notes)
    and a chat column (language selector, chatbot, document references).
    The function launches the app and returns nothing.
    """
    # Custom CSS
    custom_css = """
    /* Global styles */
    body {
        font-family: 'Inter', sans-serif;
        color: #333333; /* Dark Gray Text */
        background-color: #f7f7f7; /* Light Gray Background */
    }
    .container {
        max-width: 1200px;
        margin: 0 auto;
    }
    /* Header styles */
    .header {
        text-align: center;
        padding: 20px 0;
        margin-bottom: 20px;
        background: linear-gradient(90deg, #3171c7, #24599b); /* Primary & Secondary Blue */
        color: white;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    .header h1 {
        font-size: 2.5rem;
        margin: 0;
        padding: 0;
    }
    .header p {
        font-size: 1.1rem;
        margin: 10px 0 0;
        opacity: 0.9;
    }
    /* Card styles */
    .card {
        background-color: white;
        border-radius: 10px;
        padding: 20px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
        margin-bottom: 20px;
    }
    /* Section titles */
    .section-title {
        font-size: 1.25rem;
        font-weight: 600;
        margin-bottom: 15px;
        color: #3171c7; /* Primary Blue */
        display: flex;
        align-items: center;
    }
    .section-title svg {
        margin-right: 8px;
    }
    /* Buttons */
    .primary-button {
        background: linear-gradient(90deg, #3171c7, #24599b); /* Primary & Secondary Blue */
        color: white !important;
        border: none !important;
        padding: 10px 20px !important;
        border-radius: 8px !important;
        font-weight: 500 !important;
        cursor: pointer;
        transition: all 0.2s ease;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
    }
    .primary-button:hover {
        background: linear-gradient(90deg, #24599b, #1d4a83); /* Darker Blue on Hover */
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15) !important;
        transform: translateY(-1px);
    }
    /* Status indicators */
    .status {
        padding: 8px 12px !important;
        border-radius: 6px !important;
        font-size: 0.9rem !important;
        font-weight: 500 !important;
    }
    .status-success {
        background-color: #d1fae5 !important;
        color: #065f46 !important; /* Teal */
    }
    .status-waiting {
        background-color: #fef3c7 !important;
        color: #92400e !important;
    }
    .status-error {
        background-color: #fee2e2 !important;
        color: #d9534f !important; /* Red */
    }
    /* Chat container */
    .chat-container {
        border-radius: 10px !important;
        border: 1px solid #e0e0e0 !important; /* Medium Gray Border */
        overflow: hidden !important;
    }
    /* Document upload area */
    .upload-area {
        border: 2px dashed #d1d5db !important;
        border-radius: 8px !important;
        padding: 20px !important;
        text-align: center !important;
        background-color: #f9fafb !important;
        transition: all 0.2s ease;
    }
    .upload-area:hover {
        border-color: #3171c7 !important; /* Primary Blue on Hover */
        background-color: #eff6ff !important;
    }
    /* Parameter sliders */
    .parameter-slider {
        margin-bottom: 15px !important;
    }
    /* Reference boxes */
    .reference-box {
        background-color: #f3f4f6 !important;
        border-left: 4px solid #3171c7 !important; /* Primary Blue */
        padding: 10px 15px !important;
        margin-bottom: 10px !important;
        border-radius: 4px !important;
    }
    .reference-box-title {
        font-weight: 600 !important;
        color: #3171c7 !important; /* Primary Blue */
        margin-bottom: 5px !important;
        display: flex !important;
        justify-content: space-between !important;
    }
    .page-number {
        background-color: #dbeafe !important;
        color: #3171c7 !important; /* Primary Blue */
        padding: 2px 8px !important;
        border-radius: 12px !important;
        font-size: 0.8rem !important;
    }
    /* Responsive adjustments */
    @media (max-width: 768px) {
        .header h1 {
            font-size: 1.8rem;
        }
    }
    """
    # HTML Components
    header_html = """
    <div class="header">
        <h1>📚 RAG PDF Chatbot</h1>
        <p>Query your documents with AI-powered search and generation</p>
    </div>
    """
    upload_html = """
    <div class="section-title">
        <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
            <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"></path>
            <polyline points="17 8 12 3 7 8"></polyline>
            <line x1="12" y1="3" x2="12" y2="15"></line>
        </svg>
        Upload your PDF documents
    </div>
    <p>Select one or more PDF files to analyze and chat with.</p>
    """
    model_html = """
    <div class="section-title">
        <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
            <path d="M12 2L2 7l10 5 10-5-10-5z"></path>
            <path d="M2 17l10 5 10-5"></path>
            <path d="M2 12l10 5 10-5"></path>
        </svg>
        Select AI Model
    </div>
    <p>Choose the language model that will process your questions.</p>
    """
    chat_html = """
    <div class="section-title">
        <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
            <path d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"></path>
        </svg>
        Chat with your Documents
    </div>
    <p>Ask questions about your uploaded documents to get AI-powered answers.</p>
    """
    reference_html = """
    <div class="section-title">
        <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
            <path d="M2 3h6a4 4 0 0 1 4 4v14a3 3 0 0 0-3-3H2z"></path>
            <path d="M22 3h-6a4 4 0 0 0-4 4v14a3 3 0 0 1 3-3h7z"></path>
        </svg>
        Document References
    </div>
    <p>These are the relevant sections from your documents that the AI used to generate its response.</p>
    """

    def reset_chat():
        # Values for: chatbot, then (text, page) for each of the three references.
        return [None, "", 0, "", 0, "", 0]

    # The Blocks object is named ``app`` so it does not shadow this function.
    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate"), css=custom_css) as app:
        # State variables
        vector_db = gr.State()
        qa_chain = gr.State()
        # Header
        gr.HTML(header_html)
        with gr.Row():
            # Left column - Setup (integer scales 2:3 keep the original 1:1.5 ratio)
            with gr.Column(scale=2):
                with gr.Group(elem_classes="card"):
                    gr.HTML(upload_html)
                    # File types are given in extension form (".pdf").
                    document = gr.Files(height=200, file_count="multiple", file_types=[".pdf"], interactive=True)
                    db_btn = gr.Button("Create Vector Database", elem_classes="primary-button")
                    db_progress = gr.Textbox(value="Not initialized", show_label=False, elem_classes="status status-waiting")
                with gr.Group(elem_classes="card"):
                    gr.HTML(model_html)
                    # type="index" makes the callback receive the position in list_llm.
                    llm_btn = gr.Radio(list_llm_simple, label="", value=list_llm_simple[0], type="index")
                    with gr.Accordion("Advanced Parameters", open=False):
                        slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", interactive=True, elem_classes="parameter-slider")
                        # NOTE(review): 9192 is not reachable in steps of 128 from 128 - possibly meant 8192; confirm.
                        slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max Tokens", interactive=True, elem_classes="parameter-slider")
                        slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-K", interactive=True, elem_classes="parameter-slider")
                    qachain_btn = gr.Button("Initialize Chatbot", elem_classes="primary-button")
                    llm_progress = gr.Textbox(value="Not initialized", show_label=False, elem_classes="status status-waiting")
                with gr.Group(elem_classes="card"):
                    gr.Markdown("### Usage Instructions")
                    gr.Markdown("""
                    1. Upload one or more PDF documents
                    2. Click "Create Vector Database"
                    3. Select your preferred AI model
                    4. Click "Initialize Chatbot"
                    5. Start asking questions about your documents

                    **Note:** The system will analyze your documents and use AI to answer questions based on their content.
                    """)
            # Right column - Chat
            with gr.Column(scale=3):
                with gr.Group(elem_classes="card"):
                    gr.HTML(chat_html)
                    language_selector = gr.Radio(["English", "Português"], label="Response Language", value="English")
                    chatbot = gr.Chatbot(height=400, elem_classes="chat-container")
                    with gr.Row():
                        with gr.Column(scale=4):
                            msg = gr.Textbox(placeholder="Ask a question about your documents...", show_label=False)
                        with gr.Column(scale=1):
                            submit_btn = gr.Button("Send", elem_classes="primary-button")
                    with gr.Row():
                        clear_btn = gr.Button("Clear Chat", scale=1)
                with gr.Group(elem_classes="card"):
                    gr.HTML(reference_html)
                    with gr.Accordion("Document References", open=True):
                        # Three identical reference slots, collected in the
                        # order text1, page1, text2, page2, text3, page3.
                        reference_outputs = []
                        for index in (1, 2, 3):
                            gr.Markdown(f"**Reference {index}**", elem_classes="reference-box-title")
                            with gr.Row():
                                reference_outputs.append(gr.Textbox(show_label=False, lines=2, elem_classes="reference-box"))
                                reference_outputs.append(gr.Number(label="Page", show_label=True, elem_classes="page-number"))

        # Shared component lists so every event targets the same widgets.
        chat_inputs = [qa_chain, msg, chatbot, language_selector]
        chat_outputs = [qa_chain, msg, chatbot, *reference_outputs]
        reset_outputs = [chatbot, *reference_outputs]

        # Preprocessing events
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
        # After (re)initializing the chain, clear the chat and the references.
        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress],
        ).then(reset_chat, inputs=None, outputs=reset_outputs, queue=False)
        # Chatbot events
        msg.submit(conversation, inputs=chat_inputs, outputs=chat_outputs, queue=False)
        submit_btn.click(conversation, inputs=chat_inputs, outputs=chat_outputs, queue=False)
        clear_btn.click(reset_chat, inputs=None, outputs=reset_outputs, queue=False)
    app.queue().launch(debug=True)
# Build and launch the UI only when this file is run as a script.
if __name__ == "__main__":
    demo()