import gradio as gr
import os
import shutil
import logging
from datetime import datetime
from io import BytesIO

from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_core.documents import Document
from pptx import Presentation

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment setup for Hugging Face token
os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN", "default-token")
if os.environ["HUGGINGFACEHUB_API_TOKEN"] == "default-token":
    logger.warning("HUGGINGFACEHUB_API_TOKEN not set. Some models may not work.")

# Model and embedding options
LLM_MODELS = {
    "Lightweight (Gemma-2B)": "google/gemma-2b-it",
    "Balanced (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "High Accuracy (Llama-3-8B)": "meta-llama/Llama-3-8b-hf"
}

EMBEDDING_MODELS = {
    "Lightweight (MiniLM-L6)": "sentence-transformers/all-MiniLM-L6-v2",
    "Balanced (MPNet-Base)": "sentence-transformers/all-mpnet-base-v2",
    "High Accuracy (BGE-Large)": "BAAI/bge-large-en-v1.5"
}

# Global state
vector_store = None
qa_chain = None
chat_history = []
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
PERSIST_DIRECTORY = "./chroma_db"


# Custom PPTX loader
class PPTXLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        """Extract the text of each slide as a Document with slide metadata."""
        docs = []
        try:
            with open(self.file_path, "rb") as f:
                prs = Presentation(BytesIO(f.read()))
            for slide_num, slide in enumerate(prs.slides, 1):
                text = ""
                for shape in slide.shapes:
                    if hasattr(shape, "text") and shape.text:
                        text += shape.text + "\n"
                if text.strip():
                    # Return Document objects so downstream splitting works like the other loaders
                    docs.append(Document(page_content=text, metadata={"source": self.file_path, "slide": slide_num}))
        except Exception as e:
            logger.error(f"Error loading PPTX {self.file_path}: {str(e)}")
            return []
        return docs


# Function to load documents
def load_documents(files):
    documents = []
    for file in files:
        file_path = file.name
        try:
            logger.info(f"Loading file: {file_path}")
            if file_path.endswith(".pdf"):
                loader = PyPDFLoader(file_path)
                documents.extend(loader.load())
            elif file_path.endswith(".txt"):
                loader = TextLoader(file_path)
                documents.extend(loader.load())
            elif file_path.endswith(".docx"):
                loader = Docx2txtLoader(file_path)
                documents.extend(loader.load())
            elif file_path.endswith(".pptx"):
                loader = PPTXLoader(file_path)
                documents.extend(loader.load())
        except Exception as e:
            logger.error(f"Error loading file {file_path}: {str(e)}")
            continue
    return documents
# Function to process documents and create vector store
def process_documents(files, chunk_size, chunk_overlap, embedding_model):
    global vector_store
    if not files:
        return "Please upload at least one document.", None

    # Clear existing vector store to avoid dimensionality mismatch
    if os.path.exists(PERSIST_DIRECTORY):
        try:
            shutil.rmtree(PERSIST_DIRECTORY)
            logger.info("Cleared existing ChromaDB directory.")
        except Exception as e:
            logger.error(f"Error clearing ChromaDB directory: {str(e)}")
            return f"Error clearing vector store: {str(e)}", None

    # Load documents
    documents = load_documents(files)
    if not documents:
        return "No valid documents loaded. Check file formats or content.", None

    # Split documents
    try:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=int(chunk_size),
            chunk_overlap=int(chunk_overlap),
            length_function=len
        )
        doc_splits = text_splitter.split_documents(documents)
        logger.info(f"Split {len(documents)} documents into {len(doc_splits)} chunks.")
    except Exception as e:
        logger.error(f"Error splitting documents: {str(e)}")
        return f"Error splitting documents: {str(e)}", None

    # Create embeddings
    try:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODELS[embedding_model])
    except Exception as e:
        logger.error(f"Error initializing embeddings for {embedding_model}: {str(e)}")
        return f"Error initializing embeddings: {str(e)}", None

    # Create vector store
    try:
        vector_store = Chroma.from_documents(doc_splits, embeddings, persist_directory=PERSIST_DIRECTORY)
        return f"Processed {len(documents)} documents into {len(doc_splits)} chunks.", None
    except Exception as e:
        logger.error(f"Error creating vector store: {str(e)}")
        return f"Error creating vector store: {str(e)}", None


# Function to initialize QA chain
def initialize_qa_chain(llm_model, temperature):
    global qa_chain
    if not vector_store:
        return "Please process documents first.", None
    try:
        llm = HuggingFaceEndpoint(
            repo_id=LLM_MODELS[llm_model],
            task="text-generation",
            temperature=float(temperature),
            max_new_tokens=512,
            huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"]
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
            memory=memory
        )
        logger.info(f"Initialized QA chain with {llm_model}.")
        return "QA chain initialized successfully.", None
    except Exception as e:
        logger.error(f"Error initializing QA chain for {llm_model}: {str(e)}")
        return f"Error initializing QA chain: {str(e)}. Ensure your HF token has access to {llm_model}.", None
# Function to handle user query
def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap):
    global chat_history
    if not vector_store:
        return "Please process documents first.", chat_history
    if not qa_chain:
        return "Please initialize the QA chain.", chat_history
    if not question.strip():
        return "Please enter a valid question.", chat_history
    try:
        response = qa_chain({"question": question})["answer"]
        # Store (user message, bot message) pairs so gr.Chatbot renders them correctly
        chat_history.append((question, response))
        logger.info(f"Answered question: {question}")
        return response, chat_history
    except Exception as e:
        logger.error(f"Error answering question: {str(e)}")
        return f"Error answering question: {str(e)}", chat_history


# Function to export chat history
def export_chat():
    if not chat_history:
        return "No chat history to export.", None
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.txt"
        with open(filename, "w") as f:
            for user_msg, bot_msg in chat_history:
                f.write(f"User: {user_msg}\nBot: {bot_msg}\n\n")
        logger.info(f"Exported chat history to {filename}.")
        return f"Chat history exported to {filename}.", filename
    except Exception as e:
        logger.error(f"Error exporting chat history: {str(e)}")
        return f"Error exporting chat history: {str(e)}", None


# Function to reset the app
def reset_app():
    global vector_store, qa_chain, chat_history, memory
    try:
        vector_store = None
        qa_chain = None
        chat_history = []
        memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        if os.path.exists(PERSIST_DIRECTORY):
            shutil.rmtree(PERSIST_DIRECTORY)
            logger.info("Cleared ChromaDB directory on reset.")
        logger.info("App reset successfully.")
        return "App reset successfully.", None
    except Exception as e:
        logger.error(f"Error resetting app: {str(e)}")
        return f"Error resetting app: {str(e)}", None
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
    gr.Markdown("# DocTalk: Document Q&A Chatbot")
    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), select models, tune parameters, and ask questions!")

    with gr.Row():
        with gr.Column(scale=2):
            file_upload = gr.Files(label="Upload Documents", file_types=[".pdf", ".txt", ".docx", ".pptx"])
            with gr.Row():
                process_button = gr.Button("Process Documents")
                reset_button = gr.Button("Reset App")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=1):
            llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="Lightweight (Gemma-2B)")
            embedding_model = gr.Dropdown(choices=list(EMBEDDING_MODELS.keys()), label="Select Embedding Model", value="Lightweight (MiniLM-L6)")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
            chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
            chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
            init_button = gr.Button("Initialize QA Chain")

    gr.Markdown("## Chat Interface")
    question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
    answer = gr.Textbox(label="Answer", interactive=False)
    chat_display = gr.Chatbot(label="Chat History")
    export_button = gr.Button("Export Chat History")
    export_file = gr.File(label="Exported Chat File")

    # Event handlers
    process_button.click(
        fn=process_documents,
        inputs=[file_upload, chunk_size, chunk_overlap, embedding_model],
        outputs=[status, chat_display]
    )
    init_button.click(
        fn=initialize_qa_chain,
        inputs=[llm_model, temperature],
        outputs=[status, chat_display]
    )
    question.submit(
        fn=answer_question,
        inputs=[question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap],
        outputs=[answer, chat_display]
    )
    export_button.click(
        fn=export_chat,
        outputs=[status, export_file]
    )
    reset_button.click(
        fn=reset_app,
        outputs=[status, chat_display]
    )

demo.launch()
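# Example usage (a sketch, not part of the app; the filename and token value
# below are placeholders): run the script with a valid Hugging Face token
# exported in the environment, e.g.
#
#   export HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxx
#   python app.py
#
# Then open the local URL Gradio prints, upload one or more documents, click
# "Process Documents", then "Initialize QA Chain", and ask questions in the
# chat box. "Reset App" clears the vector store, chain, and chat history.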