import gradio as gr
import os
import time
from datetime import datetime
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_core.documents import Document
from pptx import Presentation
from io import BytesIO

# Environment setup for Hugging Face token
os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN", "your-hf-token-here")

# Model and embedding options
LLM_MODELS = {
    "Lightweight (Gemma-2B)": "google/gemma-2b-it",
    "Balanced (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"High Accuracy (Llama-3-8B)": "meta-llama/Llama-3-8b-hf" | |
}
EMBEDDING_MODELS = {
    "Lightweight (MiniLM-L6)": "sentence-transformers/all-MiniLM-L6-v2",
    "Balanced (MPNet-Base)": "sentence-transformers/all-mpnet-base-v2",
    "High Accuracy (BGE-Large)": "BAAI/bge-large-en-v1.5"
}
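
# The dictionaries above map UI labels to Hugging Face repo ids. Some hosted checkpoints
# (for example the Meta Llama models) are gated and may require accepting their license on
# the Hub with the account that owns HUGGINGFACEHUB_API_TOKEN before the Inference API
# will serve them.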

# Global state
vector_store = None
qa_chain = None
chat_history = []
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# Custom PPTX loader that returns LangChain Document objects (one per non-empty slide)
class PPTXLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        docs = []
        with open(self.file_path, "rb") as f:
            prs = Presentation(BytesIO(f.read()))
        for slide_num, slide in enumerate(prs.slides, 1):
            text = ""
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    text += shape.text + "\n"
            if text.strip():
                docs.append(Document(page_content=text, metadata={"source": self.file_path, "slide": slide_num}))
        return docs
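
# Example usage (a minimal sketch, assuming a local file named "slides.pptx" exists):
#
#     loader = PPTXLoader("slides.pptx")
#     for doc in loader.load():
#         print(doc.metadata["slide"], doc.page_content[:80])
#
# Each returned Document carries the slide text plus source/slide metadata, so PPTX content
# flows through the same splitting and embedding pipeline as the other formats.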

# Function to load documents from the uploaded files, dispatching on file extension
def load_documents(files):
    documents = []
    for file in files:
        file_path = file.name
        if file_path.endswith(".pdf"):
            loader = PyPDFLoader(file_path)
            documents.extend(loader.load())
        elif file_path.endswith(".txt"):
            loader = TextLoader(file_path)
            documents.extend(loader.load())
        elif file_path.endswith(".docx"):
            loader = Docx2txtLoader(file_path)
            documents.extend(loader.load())
        elif file_path.endswith(".pptx"):
            loader = PPTXLoader(file_path)
            documents.extend(loader.load())
    return documents

# Function to process documents and create vector store
def process_documents(files, chunk_size, chunk_overlap, embedding_model):
    global vector_store, qa_chain
    if not files:
        return "Please upload at least one document.", None
    # Load documents
    documents = load_documents(files)
    if not documents:
        return "No valid documents loaded.", None
    # Split documents
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=int(chunk_size),
        chunk_overlap=int(chunk_overlap),
        length_function=len
    )
    doc_splits = text_splitter.split_documents(documents)
    # Create embeddings
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODELS[embedding_model])
    # Create vector store
    try:
        vector_store = Chroma.from_documents(doc_splits, embeddings, persist_directory="./chroma_db")
        return f"Processed {len(documents)} documents into {len(doc_splits)} chunks.", None
    except Exception as e:
        return f"Error processing documents: {str(e)}", None

# Function to initialize QA chain
def initialize_qa_chain(llm_model, temperature):
    global qa_chain
    if vector_store is None:
        return "Please process documents before initializing the QA chain.", None
    try:
        llm = HuggingFaceEndpoint(
            repo_id=LLM_MODELS[llm_model],
            temperature=float(temperature),
            max_new_tokens=512,
            huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"]
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
            memory=memory
        )
        return "QA chain initialized successfully.", None
    except Exception as e:
        return f"Error initializing QA chain: {str(e)}", None

# Function to handle a user query (the model/parameter inputs are accepted to mirror the UI controls)
def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap):
    global chat_history
    if not vector_store or not qa_chain:
        return "Please upload documents and initialize the QA chain.", chat_history
    try:
        response = qa_chain.invoke({"question": question})["answer"]
        # Store (user, bot) pairs so gr.Chatbot renders each turn correctly
        chat_history.append((question, response))
        return response, chat_history
    except Exception as e:
        return f"Error answering question: {str(e)}", chat_history

# Function to export chat history
def export_chat():
    if not chat_history:
        return "No chat history to export.", None
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"chat_history_{timestamp}.txt"
    with open(filename, "w") as f:
        for user_msg, bot_msg in chat_history:
            f.write(f"User: {user_msg}\nBot: {bot_msg}\n\n")
    return f"Chat history exported to {filename}.", filename

# Function to reset the app
def reset_app():
    global vector_store, qa_chain, chat_history, memory
    vector_store = None
    qa_chain = None
    chat_history = []
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    if os.path.exists("./chroma_db"):
        import shutil
        shutil.rmtree("./chroma_db")
    return "App reset successfully.", None

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
    gr.Markdown("# DocTalk: Document Q&A Chatbot")
    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), select models, tune parameters, and ask questions!")
    with gr.Row():
        with gr.Column(scale=2):
            file_upload = gr.Files(label="Upload Documents", file_types=[".pdf", ".txt", ".docx", ".pptx"])
            with gr.Row():
                process_button = gr.Button("Process Documents")
                reset_button = gr.Button("Reset App")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=1):
            llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="Lightweight (Gemma-2B)")
            embedding_model = gr.Dropdown(choices=list(EMBEDDING_MODELS.keys()), label="Select Embedding Model", value="Lightweight (MiniLM-L6)")
            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Temperature")
            chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
            chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
            init_button = gr.Button("Initialize QA Chain")
    gr.Markdown("## Chat Interface")
    question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
    answer = gr.Textbox(label="Answer", interactive=False)
    chat_display = gr.Chatbot(label="Chat History")
    export_button = gr.Button("Export Chat History")
    export_file = gr.File(label="Exported Chat File")

    # Event handlers
    process_button.click(
        fn=process_documents,
        inputs=[file_upload, chunk_size, chunk_overlap, embedding_model],
        outputs=[status, chat_display]
    )
    init_button.click(
        fn=initialize_qa_chain,
        inputs=[llm_model, temperature],
        outputs=[status, chat_display]
    )
    question.submit(
        fn=answer_question,
        inputs=[question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap],
        outputs=[answer, chat_display]
    )
    export_button.click(
        fn=export_chat,
        outputs=[status, export_file]
    )
    reset_button.click(
        fn=reset_app,
        outputs=[status, chat_display]
    )

demo.launch()