import gradio as gr
import os
import time
from datetime import datetime
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_core.documents import Document
from pptx import Presentation
from io import BytesIO
import shutil
import logging
import chromadb
import tempfile
import requests
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment setup for Hugging Face token
os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN", "default-token")
if os.environ["HUGGINGFACEHUB_API_TOKEN"] == "default-token":
    logger.warning("HUGGINGFACEHUB_API_TOKEN not set. Some models may not work.")
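# On Hugging Face Spaces the token is typically provided as a repository secret,
# which is exposed to this app as an environment variable at runtime.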
# Model and embedding options
LLM_MODELS = {
    "Balanced (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Lightweight (Gemma-2B)": "google/gemma-2b-it",
    "High Accuracy (Llama-3-8B)": "meta-llama/Meta-Llama-3-8B-Instruct"
}
EMBEDDING_MODELS = {
    "Lightweight (MiniLM-L6)": "sentence-transformers/all-MiniLM-L6-v2",
    "Balanced (MPNet-Base)": "sentence-transformers/all-mpnet-base-v2",
    "High Accuracy (BGE-Large)": "BAAI/bge-large-en-v1.5"
}
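# Note: the Llama 3 and Gemma repositories are gated on Hugging Face, so the API
# token must belong to an account that has been granted access to them.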
# Global state
vector_store = None
qa_chain = None
chat_history = []
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
PERSIST_DIRECTORY = tempfile.mkdtemp()  # Use temporary directory for ChromaDB
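# The temp directory lives only for this process, so processed documents do not
# survive a Space restart and must be uploaded and processed again.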
# Custom PPTX loader
class PPTXLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        """Return one Document per non-empty slide, with the slide number in metadata."""
        docs = []
        try:
            with open(self.file_path, "rb") as f:
                prs = Presentation(BytesIO(f.read()))
            for slide_num, slide in enumerate(prs.slides, 1):
                text = ""
                for shape in slide.shapes:
                    if hasattr(shape, "text") and shape.text:
                        text += shape.text + "\n"
                if text.strip():
                    docs.append(
                        Document(
                            page_content=text,
                            metadata={"source": self.file_path, "slide": slide_num},
                        )
                    )
        except Exception as e:
            logger.error(f"Error loading PPTX {self.file_path}: {str(e)}")
            return []
        return docs
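# Example usage (hypothetical path): PPTXLoader("slides.pptx").load() returns a list
# of LangChain Document objects, one per non-empty slide.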
# Function to load documents
def load_documents(files):
    documents = []
    for file in files:
        file_path = file.name
        try:
            logger.info(f"Loading file: {file_path}")
            if file_path.endswith(".pdf"):
                loader = PyPDFLoader(file_path)
                documents.extend(loader.load())
            elif file_path.endswith(".txt"):
                loader = TextLoader(file_path)
                documents.extend(loader.load())
            elif file_path.endswith(".docx"):
                loader = Docx2txtLoader(file_path)
                documents.extend(loader.load())
            elif file_path.endswith(".pptx"):
                loader = PPTXLoader(file_path)
                documents.extend(loader.load())
        except Exception as e:
            logger.error(f"Error loading file {file_path}: {str(e)}")
            continue
    return documents
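# Note: files with unsupported extensions are skipped silently; every branch above
# yields Document objects, which is what the text splitter below expects.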
# Function to process documents and create vector store
def process_documents(files, chunk_size, chunk_overlap, embedding_model):
    global vector_store
    if not files:
        return "Please upload at least one document.", None

    # Clear existing vector store
    if os.path.exists(PERSIST_DIRECTORY):
        try:
            shutil.rmtree(PERSIST_DIRECTORY)
            logger.info("Cleared existing ChromaDB directory.")
        except Exception as e:
            logger.error(f"Error clearing ChromaDB directory: {str(e)}")
            return f"Error clearing vector store: {str(e)}", None
    os.makedirs(PERSIST_DIRECTORY, exist_ok=True)

    # Load documents
    documents = load_documents(files)
    if not documents:
        return "No valid documents loaded. Check file formats or content.", None

    # Split documents
    try:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=int(chunk_size),
            chunk_overlap=int(chunk_overlap),
            length_function=len
        )
        doc_splits = text_splitter.split_documents(documents)
        logger.info(f"Split {len(documents)} documents into {len(doc_splits)} chunks.")
    except Exception as e:
        logger.error(f"Error splitting documents: {str(e)}")
        return f"Error splitting documents: {str(e)}", None

    # Create embeddings
    try:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODELS[embedding_model])
    except Exception as e:
        logger.error(f"Error initializing embeddings for {embedding_model}: {str(e)}")
        return f"Error initializing embeddings: {str(e)}", None

    # Create vector store
    try:
        # Use an in-memory Chroma client to avoid filesystem issues
        collection_name = f"doctalk_collection_{int(time.time())}"
        client = chromadb.Client()
        vector_store = Chroma.from_documents(
            documents=doc_splits,
            embedding=embeddings,
            client=client,
            collection_name=collection_name
        )
        return f"Processed {len(documents)} documents into {len(doc_splits)} chunks.", None
    except Exception as e:
        logger.error(f"Error creating vector store: {str(e)}")
        return f"Error creating vector store: {str(e)}", None
# Function to initialize the QA chain
def initialize_qa_chain(llm_model, temperature):
    global qa_chain
    if not vector_store:
        return "Please process documents first.", None
    try:
        llm = HuggingFaceEndpoint(
            repo_id=LLM_MODELS[llm_model],
            task="text-generation",
            temperature=float(temperature),
            max_new_tokens=512,
            huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
            timeout=30
        )
        # Dynamically set k based on vector store size
        collection = vector_store._collection
        doc_count = collection.count()
        k = min(3, doc_count) if doc_count > 0 else 1
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vector_store.as_retriever(search_kwargs={"k": k}),
            memory=memory
        )
        logger.info(f"Initialized QA chain with {llm_model} and k={k}.")
        return "QA chain initialized successfully.", None
    except requests.exceptions.HTTPError as e:
        logger.error(f"HTTP error initializing QA chain for {llm_model}: {str(e)}")
        if "503" in str(e):
            return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'Balanced (Mixtral-8x7B)' or wait and retry.", None
        elif "403" in str(e):
            return f"Error: Access denied for {llm_model}. Ensure your HF token has access.", None
        return f"Error initializing QA chain: {str(e)}.", None
    except Exception as e:
        logger.error(f"Error initializing QA chain for {llm_model}: {str(e)}")
        return f"Error initializing QA chain: {str(e)}. Ensure your HF token has access to {llm_model}.", None
# Function to handle a user question
def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap):
    global chat_history
    if not vector_store:
        return "Please process documents first.", chat_history
    if not qa_chain:
        return "Please initialize the QA chain.", chat_history
    if not question.strip():
        return "Please enter a valid question.", chat_history
    try:
        response = qa_chain.invoke({"question": question})["answer"]
        chat_history.append({"role": "user", "content": question})
        chat_history.append({"role": "assistant", "content": response})
        logger.info(f"Answered question: {question}")
        return response, chat_history
    except requests.exceptions.HTTPError as e:
        logger.error(f"HTTP error answering question: {str(e)}")
        if "503" in str(e):
            return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'Balanced (Mixtral-8x7B)' or wait and retry.", chat_history
        elif "403" in str(e):
            return f"Error: Access denied for {llm_model}. Ensure your HF token has access.", chat_history
        return f"Error answering question: {str(e)}", chat_history
    except Exception as e:
        logger.error(f"Error answering question: {str(e)}")
        return f"Error answering question: {str(e)}", chat_history
# Function to export chat history
def export_chat():
    if not chat_history:
        return "No chat history to export.", None
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.txt"
        with open(filename, "w") as f:
            for message in chat_history:
                role = message["role"].capitalize()
                content = message["content"]
                f.write(f"{role}: {content}\n\n")
        logger.info(f"Exported chat history to {filename}.")
        return f"Chat history exported to {filename}.", filename
    except Exception as e:
        logger.error(f"Error exporting chat history: {str(e)}")
        return f"Error exporting chat history: {str(e)}", None
# Function to reset the app
def reset_app():
    global vector_store, qa_chain, chat_history, memory
    try:
        vector_store = None
        qa_chain = None
        chat_history = []
        memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        if os.path.exists(PERSIST_DIRECTORY):
            shutil.rmtree(PERSIST_DIRECTORY)
            os.makedirs(PERSIST_DIRECTORY, exist_ok=True)
            logger.info("Cleared ChromaDB directory on reset.")
        logger.info("App reset successfully.")
        return "App reset successfully.", None
    except Exception as e:
        logger.error(f"Error resetting app: {str(e)}")
        return f"Error resetting app: {str(e)}", None
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
    gr.Markdown("# DocTalk: Document Q&A Chatbot")
    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), select models, tune parameters, and ask questions!")

    with gr.Row():
        with gr.Column(scale=2):
            file_upload = gr.Files(label="Upload Documents", file_types=[".pdf", ".txt", ".docx", ".pptx"])
            with gr.Row():
                process_button = gr.Button("Process Documents")
                reset_button = gr.Button("Reset App")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=1):
            llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="Balanced (Mixtral-8x7B)")
            embedding_model = gr.Dropdown(choices=list(EMBEDDING_MODELS.keys()), label="Select Embedding Model", value="Lightweight (MiniLM-L6)")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
            chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
            chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
            init_button = gr.Button("Initialize QA Chain")

    gr.Markdown("## Chat Interface")
    question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
    answer = gr.Textbox(label="Answer", interactive=False)
    chat_display = gr.Chatbot(label="Chat History", type="messages")
    export_button = gr.Button("Export Chat History")
    export_file = gr.File(label="Exported Chat File")

    # Event handlers
    process_button.click(
        fn=process_documents,
        inputs=[file_upload, chunk_size, chunk_overlap, embedding_model],
        outputs=[status, chat_display]
    )
    init_button.click(
        fn=initialize_qa_chain,
        inputs=[llm_model, temperature],
        outputs=[status, chat_display]
    )
    question.submit(
        fn=answer_question,
        inputs=[question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap],
        outputs=[answer, chat_display]
    )
    export_button.click(
        fn=export_chat,
        outputs=[status, export_file]
    )
    reset_button.click(
        fn=reset_app,
        outputs=[status, chat_display]
    )

demo.launch()
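# On Spaces the default launch() settings are sufficient; for local testing you
# could pass share=True to demo.launch() to get a temporary public link.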