import gradio as gr import os import torch from langchain_community.vectorstores import FAISS, Chroma from langchain_community.document_loaders import PyPDFLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.chains import ConversationalRetrievalChain from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.llms import HuggingFaceEndpoint from langchain.memory import ConversationBufferMemory from langchain_community.retrievers import BM25Retriever, EnsembleRetriever # Corrected import #from langchain.chains.query_constructor.base import AttributeInfo # Removed deprecated code #from langchain.chains import create_query_chain # Removed deprecated code #from langchain.retrievers.self_query.base import SelfQueryRetriever # Removed deprecated code #from langchain.chains.query_constructor.schema import FieldInfo # Removed deprecated code from langchain.retrievers.multi_query import MultiQueryRetriever api_token = os.getenv("FirstToken") # Available LLM models list_llm = [ "meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2", "deepseek-ai/deepseek-llm-7b-chat" ] list_llm_simple = [os.path.basename(llm) for llm in list_llm] # ----------------------------------------------------------------------------- # Document Loading and Splitting # ----------------------------------------------------------------------------- def load_doc(list_file_path): """Load and split PDF documents into chunks.""" loaders = [PyPDFLoader(x) for x in list_file_path] pages = [] for loader in loaders: pages.extend(loader.load()) text_splitter = RecursiveCharacterTextSplitter( chunk_size=1024, chunk_overlap=64 ) doc_splits = text_splitter.split_documents(pages) return doc_splits # ----------------------------------------------------------------------------- # Vector Database Creation (ChromaDB and FAISS) # ----------------------------------------------------------------------------- def create_chromadb(splits, persist_directory="chroma_db"): """Create ChromaDB vector database from document splits.""" embeddings = HuggingFaceEmbeddings() chromadb = Chroma.from_documents( documents=splits, embedding=embeddings, persist_directory=persist_directory ) chromadb.persist() # Ensure data is written to disk return chromadb def create_faissdb(splits): """Create FAISS vector database from document splits.""" embeddings = HuggingFaceEmbeddings() faissdb = FAISS.from_documents(splits, embeddings) return faissdb # ----------------------------------------------------------------------------- # BM25 Retriever # ----------------------------------------------------------------------------- def create_bm25_retriever(splits): """Create BM25 retriever from document splits.""" bm25_retriever = BM25Retriever.from_documents(splits) bm25_retriever.k = 3 # Number of documents to retrieve return bm25_retriever # ----------------------------------------------------------------------------- # MultiQueryRetriever # ----------------------------------------------------------------------------- def create_multi_query_retriever(llm, vector_db, num_queries=3): """ Create a MultiQueryRetriever. Args: llm: The language model to use for query generation. vector_db: The vector database to retrieve from. num_queries: The number of diverse queries to generate. Returns: A MultiQueryRetriever instance. """ retriever = MultiQueryRetriever.from_llm( llm=llm, retriever=vector_db.as_retriever(), output_key="answer", memory_key="chat_history", return_messages=True, verbose=False ) return retriever # ----------------------------------------------------------------------------- # Ensemble Retriever (Combine VectorDB and BM25) # ----------------------------------------------------------------------------- def create_ensemble_retriever(vector_db, bm25_retriever): """Create an ensemble retriever combining ChromaDB and BM25.""" ensemble_retriever = EnsembleRetriever( retrievers=[vector_db.as_retriever(), bm25_retriever], weights=[0.7, 0.3] # Adjust weights as needed ) return ensemble_retriever # ----------------------------------------------------------------------------- # Initialize Database # ----------------------------------------------------------------------------- def initialize_database(list_file_obj, progress=gr.Progress()): """Initialize the document database.""" list_file_path = [x.name for x in list_file_obj if x is not None] doc_splits = load_doc(list_file_path) # Create vector databases and retrievers chromadb = create_chromadb(doc_splits) bm25_retriever = create_bm25_retriever(doc_splits) # Create ensemble retriever ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever) return ensemble_retriever, "Database created successfully!" # ----------------------------------------------------------------------------- # Initialize LLM Chain # ----------------------------------------------------------------------------- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever, progress=gr.Progress()): """Initialize the language model chain.""" llm = HuggingFaceEndpoint( repo_id=llm_model, huggingfacehub_api_token=api_token, temperature=temperature, max_new_tokens=max_tokens, top_k=top_k, task="text-generation" ) memory = ConversationBufferMemory( memory_key="chat_history", output_key='answer', return_messages=True ) qa_chain = ConversationalRetrievalChain.from_llm( llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True, verbose=False, ) return qa_chain # ----------------------------------------------------------------------------- # Initialize LLM # ----------------------------------------------------------------------------- def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, retriever, progress=gr.Progress()): """Initialize the Language Model.""" llm_name = list_llm[llm_option] print("Selected LLM model:", llm_name) qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, retriever, progress) return qa_chain, "Analysis Assistant initialized and ready!" # ----------------------------------------------------------------------------- # Chat History Formatting # ----------------------------------------------------------------------------- def format_chat_history(message, chat_history): """Format chat history for the model.""" formatted_chat_history = [] for user_message, bot_message in chat_history: formatted_chat_history.append(f"User: {user_message}") formatted_chat_history.append(f"Assistant: {bot_message}") return formatted_chat_history # ----------------------------------------------------------------------------- # Conversation Function # ----------------------------------------------------------------------------- def conversation(qa_chain, message, history, lang): """Handle conversation and document analysis.""" # Add language instruction to the message if lang == "pt": message += " (Responda em Português)" else: message += " (Respond in English)" formatted_chat_history = format_chat_history(message, history) response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history}) response_answer = response["answer"] # Remove the language instruction from the chat history if "(Respond" in message: message = message.split(" (Respond")[0] if response_answer.find("Helpful Answer:") != -1: response_answer = response_answer.split("Helpful Answer:")[-1] response_sources = response["source_documents"] response_source1 = response_sources[0].page_content.strip() response_source2 = response_sources[1].page_content.strip() response_source3 = response_sources[2].page_content.strip() response_source1_page = response_sources[0].metadata["page"] + 1 response_source2_page = response_sources[1].metadata["page"] + 1 response_source3_page = response_sources[2].metadata["page"] + 1 new_history = history + [(message, response_answer)] return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page # ----------------------------------------------------------------------------- # Gradio Demo # ----------------------------------------------------------------------------- def demo(): """Main demo application with enhanced layout.""" theme = gr.themes.Default( primary_hue="indigo", secondary_hue="blue", neutral_hue="slate", ) # Custom CSS for advanced layout custom_css = """ .container {background: #ffffff; padding: 1rem; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1);} .header {text-align: center; margin-bottom: 2rem;} .header h1 {color: #1a365d; font-size: 2.5rem; margin-bottom: 0.5rem;} .header p {color: #4a5568; font-size: 1.2rem;} .section {margin-bottom: 1.5rem; padding: 1rem; background: #f8fafc; border-radius: 8px;} .control-panel {margin-bottom: 1rem;} .chat-area {background: white; padding: 1rem; border-radius: 8px;} """ with gr.Blocks(theme=theme, css=custom_css) as demo: retriever = gr.State() qa_chain = gr.State() language = gr.State(value="en") # State for language control # Header gr.HTML( """

MetroAssist AI

Expert System for Metrology Report Analysis

""" ) with gr.Row(): # Left Column - Controls with gr.Column(scale=1): gr.Markdown("## Document Processing") # File Upload Section with gr.Column(elem_classes="section"): gr.Markdown("### 📄 Upload Documents") document = gr.Files( label="Metrology Reports (PDF)", file_count="multiple", file_types=["pdf"] ) db_btn = gr.Button("Process Documents") db_progress = gr.Textbox( value="Ready for documents", label="Processing Status" ) # Model Selection Section with gr.Column(elem_classes="section"): gr.Markdown("### 🤖 Model Configuration") llm_btn = gr.Radio( choices=list_llm_simple, label="Select AI Model", value=list_llm_simple[0], type="index" ) # Language selection button language_btn = gr.Radio( choices=["English", "Português"], label="Response Language", value="English", type="value" ) with gr.Accordion("Advanced Settings", open=False): slider_temperature = gr.Slider( minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Analysis Precision" ) slider_maxtokens = gr.Slider( minimum=128, maximum=9192, value=4096, step=128, label="Response Length" ) slider_topk = gr.Slider( minimum=1, maximum=10, value=3, step=1, label="Analysis Diversity" ) qachain_btn = gr.Button("Initialize Assistant") llm_progress = gr.Textbox( value="Not initialized", label="Assistant Status" ) # Right Column - Chat Interface with gr.Column(scale=2): gr.Markdown("## Interactive Analysis") # Features Section with gr.Row(): with gr.Column(): gr.Markdown( """ ### 📊 Capabilities - Calibration Analysis - Standards Compliance - Uncertainty Evaluation """ ) with gr.Column(): gr.Markdown( """ ### 💡 Best Practices - Ask specific questions - Include measurement context - Specify standards """ ) # Chat Interface with gr.Column(elem_classes="chat-area"): chatbot = gr.Chatbot( height=400, label="Analysis Conversation" ) with gr.Row(): msg = gr.Textbox( placeholder="Ask about your metrology report...", label="Query" ) submit_btn = gr.Button("Send") clear_btn = gr.ClearButton( [msg, chatbot], value="Clear" ) # References Section with gr.Accordion("Document References", open=False): with gr.Row(): with gr.Column(): doc_source1 = gr.Textbox(label="Reference 1", lines=2) source1_page = gr.Number(label="Page") with gr.Column(): doc_source2 = gr.Textbox(label="Reference 2", lines=2) source2_page = gr.Number(label="Page") with gr.Column(): doc_source3 = gr.Textbox(label="Reference 3", lines=2) source3_page = gr.Number(label="Page") # Footer gr.Markdown( """ --- ### About MetroAssist AI A specialized tool for metrology professionals, providing advanced analysis of calibration certificates, measurement data, and technical standards compliance. **Version 1.0** | © 2024 MetroAssist AI """ ) # Event Handlers language_btn.change( lambda x: "en" if x == "English" else "pt", inputs=language_btn, outputs=language ) db_btn.click( initialize_database, inputs=[document], outputs=[retriever, db_progress] ) qachain_btn.click( initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, retriever], outputs=[qa_chain, llm_progress] ).then( lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False ) msg.submit( conversation, inputs=[qa_chain, msg, chatbot, language], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False ) submit_btn.click( conversation, inputs=[qa_chain, msg, chatbot, language], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False ) clear_btn.click( lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False ) demo.queue().launch(debug=True) if __name__ == "__main__": demo()