import os

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint

api_token = os.getenv("HF_TOKEN")

list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/DeepSeek-R1",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


# Load and split PDF documents
def load_doc(list_file_path):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits


# Create vector database
def create_db(splits):
    embeddings = HuggingFaceEmbeddings()
    vectordb = FAISS.from_documents(splits, embeddings)
    return vectordb


# Initialize database
def initialize_database(list_file_obj, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully! ✅"


# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        timeout=120,
        max_retries=3,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
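# Note: top_k above is a token-sampling parameter on the LLM endpoint, not the
# number of chunks the retriever returns. A minimal sketch of also pinning the
# retrieval depth (an assumption; the code above keeps the retriever default):
#
#   retriever = vector_db.as_retriever(search_kwargs={"k": 3})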
# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "QA chain initialized. Chatbot is ready! 🚀"


def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history


def conversation(qa_chain, message, history, language):
    formatted_chat_history = format_chat_history(message, history)
    if language == "Português":
        prompt = f"Responda em português: {message}"
    else:
        prompt = f"Answer in English: {message}"
    response = None
    try:
        response = qa_chain.invoke({"question": prompt, "chat_history": formatted_chat_history})
        response_answer = response["answer"]
        if "Helpful Answer:" in response_answer:
            response_answer = response_answer.split("Helpful Answer:")[-1]
    except Exception as e:
        if language == "Português":
            response_answer = f"Erro: Não foi possível obter resposta do modelo devido a problemas no servidor. Tente novamente mais tarde. ({str(e)})"
        else:
            response_answer = f"Error: Could not get a response from the model due to server issues. Please try again later. ({str(e)})"
    try:
        response_sources = response["source_documents"]
        response_source1 = response_sources[0].page_content.strip()
        response_source1_page = response_sources[0].metadata["page"] + 1
        response_source2 = response_sources[1].page_content.strip()
        response_source2_page = response_sources[1].metadata["page"] + 1
        response_source3 = response_sources[2].page_content.strip()
        response_source3_page = response_sources[2].metadata["page"] + 1
    except (TypeError, KeyError, IndexError):
        response_source1 = response_source2 = response_source3 = "N/A"
        response_source1_page = response_source2_page = response_source3_page = 0
    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history,
            response_source1, response_source1_page,
            response_source2, response_source2_page,
            response_source3, response_source3_page)
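# A minimal headless sanity check (hypothetical; not wired into the app).
# Assumes a local "sample.pdf" and a valid HF_TOKEN in the environment:
#
#   splits = load_doc(["sample.pdf"])
#   db = create_db(splits)
#   chain, status = initialize_LLM(0, 0.5, 4096, 3, db)
#   _, _, hist, *refs = conversation(chain, "What is this document about?", [], "English")
#   print(hist[-1][1])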
# Main demo with enhanced UI
def demo():
    # Custom CSS
    custom_css = """
    /* Global styles */
    body {
        font-family: 'Inter', sans-serif;
        color: #333333; /* Dark Gray Text */
        background-color: #f7f7f7; /* Light Gray Background */
    }
    .container { max-width: 1200px; margin: 0 auto; }

    /* Header styles */
    .header {
        text-align: center;
        padding: 20px 0;
        margin-bottom: 20px;
        background: linear-gradient(90deg, #3171c7, #24599b); /* Primary & Secondary Blue */
        color: white;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    .header h1 { font-size: 2.5rem; margin: 0; padding: 0; }
    .header p { font-size: 1.1rem; margin: 10px 0 0; opacity: 0.9; }

    /* Card styles */
    .card {
        background-color: white;
        border-radius: 10px;
        padding: 20px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
        margin-bottom: 20px;
    }

    /* Section titles */
    .section-title {
        font-size: 1.25rem;
        font-weight: 600;
        margin-bottom: 15px;
        color: #3171c7; /* Primary Blue */
        display: flex;
        align-items: center;
    }
    .section-title svg { margin-right: 8px; }

    /* Buttons */
    .primary-button {
        background: linear-gradient(90deg, #3171c7, #24599b); /* Primary & Secondary Blue */
        color: white !important;
        border: none !important;
        padding: 10px 20px !important;
        border-radius: 8px !important;
        font-weight: 500 !important;
        cursor: pointer;
        transition: all 0.2s ease;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
    }
    .primary-button:hover {
        background: linear-gradient(90deg, #24599b, #1d4a83); /* Darker Blue on Hover */
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15) !important;
        transform: translateY(-1px);
    }

    /* Status indicators */
    .status {
        padding: 8px 12px !important;
        border-radius: 6px !important;
        font-size: 0.9rem !important;
        font-weight: 500 !important;
    }
    .status-success { background-color: #d1fae5 !important; color: #065f46 !important; /* Teal */ }
    .status-waiting { background-color: #fef3c7 !important; color: #92400e !important; }
    .status-error { background-color: #fee2e2 !important; color: #d9534f !important; /* Red */ }

    /* Chat container */
    .chat-container {
        border-radius: 10px !important;
        border: 1px solid #e0e0e0 !important; /* Medium Gray Border */
        overflow: hidden !important;
    }

    /* Document upload area */
    .upload-area {
        border: 2px dashed #d1d5db !important;
        border-radius: 8px !important;
        padding: 20px !important;
        text-align: center !important;
        background-color: #f9fafb !important;
        transition: all 0.2s ease;
    }
    .upload-area:hover {
        border-color: #3171c7 !important; /* Primary Blue on Hover */
        background-color: #eff6ff !important;
    }

    /* Parameter sliders */
    .parameter-slider { margin-bottom: 15px !important; }

    /* Reference boxes */
    .reference-box {
        background-color: #f3f4f6 !important;
        border-left: 4px solid #3171c7 !important; /* Primary Blue */
        padding: 10px 15px !important;
        margin-bottom: 10px !important;
        border-radius: 4px !important;
    }
    .reference-box-title {
        font-weight: 600 !important;
        color: #3171c7 !important; /* Primary Blue */
        margin-bottom: 5px !important;
        display: flex !important;
        justify-content: space-between !important;
    }
    .page-number {
        background-color: #dbeafe !important;
        color: #3171c7 !important; /* Primary Blue */
        padding: 2px 8px !important;
        border-radius: 12px !important;
        font-size: 0.8rem !important;
    }

    /* Responsive adjustments */
    @media (max-width: 768px) {
        .header h1 { font-size: 1.8rem; }
    }
    """

    # HTML Components
    header_html = """
    <div class="header">
        <h1>📚 RAG PDF Chatbot</h1>
        <p>Query your documents with AI-powered search and generation</p>
    </div>
    """
    upload_html = """
    <div class="section-title">Upload your PDF documents</div>
    <p>Select one or more PDF files to analyze and chat with.</p>
    """
    model_html = """
    <div class="section-title">Select AI Model</div>
    <p>Choose the language model that will process your questions.</p>
    """
    chat_html = """
    <div class="section-title">Chat with your Documents</div>
    <p>Ask questions about your uploaded documents to get AI-powered answers.</p>
    """
    reference_html = """
    <div class="section-title">Document References</div>
    <p>These are the relevant sections from your documents that the AI used to generate its response.</p>
    """

    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate"), css=custom_css) as demo:
        # State variables
        vector_db = gr.State()
        qa_chain = gr.State()

        # Header
        gr.HTML(header_html)

        with gr.Row():
            # Left column - Setup
            with gr.Column(scale=1):
                with gr.Group(elem_classes="card"):
                    gr.HTML(upload_html)
                    document = gr.Files(height=200, file_count="multiple", file_types=["pdf"], interactive=True)
                    db_btn = gr.Button("Create Vector Database", elem_classes="primary-button")
                    db_progress = gr.Textbox(value="Not initialized", show_label=False, elem_classes="status status-waiting")

                with gr.Group(elem_classes="card"):
                    gr.HTML(model_html)
                    llm_btn = gr.Radio(list_llm_simple, label="", value=list_llm_simple[0], type="index")
                    with gr.Accordion("Advanced Parameters", open=False):
                        slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", interactive=True, elem_classes="parameter-slider")
                        slider_maxtokens = gr.Slider(minimum=128, maximum=8192, value=4096, step=128, label="Max Tokens", interactive=True, elem_classes="parameter-slider")
                        slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-K", interactive=True, elem_classes="parameter-slider")
                    qachain_btn = gr.Button("Initialize Chatbot", elem_classes="primary-button")
                    llm_progress = gr.Textbox(value="Not initialized", show_label=False, elem_classes="status status-waiting")

                with gr.Group(elem_classes="card"):
                    gr.Markdown("### Usage Instructions")
                    gr.Markdown("""
                    1. Upload one or more PDF documents
                    2. Click "Create Vector Database"
                    3. Select your preferred AI model
                    4. Click "Initialize Chatbot"
                    5. Start asking questions about your documents

                    **Note:** The system will analyze your documents and use AI to answer questions based on their content.
                    """)
""") # Right column - Chat with gr.Column(scale=1.5): with gr.Group(elem_classes="card"): gr.HTML(chat_html) language_selector = gr.Radio(["English", "Português"], label="Response Language", value="English") chatbot = gr.Chatbot(height=400, elem_classes="chat-container") with gr.Row(): with gr.Column(scale=4): msg = gr.Textbox(placeholder="Ask a question about your documents...", show_label=False) with gr.Column(scale=1): submit_btn = gr.Button("Send", elem_classes="primary-button") with gr.Row(): clear_btn = gr.Button("Clear Chat", scale=1) with gr.Group(elem_classes="card"): gr.HTML(reference_html) with gr.Accordion("Document References", open=True): # Reference 1 gr.Markdown("**Reference 1**", elem_classes="reference-box-title") with gr.Row(): doc_source1 = gr.Textbox(show_label=False, lines=2, elem_classes="reference-box") source1_page = gr.Number(label="Page", show_label=True, elem_classes="page-number") # Reference 2 gr.Markdown("**Reference 2**", elem_classes="reference-box-title") with gr.Row(): doc_source2 = gr.Textbox(show_label=False, lines=2, elem_classes="reference-box") source2_page = gr.Number(label="Page", show_label=True, elem_classes="page-number") # Reference 3 gr.Markdown("**Reference 3**", elem_classes="reference-box-title") with gr.Row(): doc_source3 = gr.Textbox(show_label=False, lines=2, elem_classes="reference-box") source3_page = gr.Number(label="Page", show_label=True, elem_classes="page-number") # Preprocessing events db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress]) qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then( lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False ) # Chatbot events msg.submit(conversation, inputs=[qa_chain, msg, chatbot, language_selector], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False) submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, language_selector], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False) clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False) demo.queue().launch(debug=True) if __name__ == "__main__": demo()