import gradio as gr
import os
import time
from datetime import datetime
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_core.documents import Document
from pptx import Presentation
from io import BytesIO
# Environment setup for Hugging Face token
os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN", "your-hf-token-here")
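# Note: the fallback string above is only a placeholder. When the app runs as a
# Hugging Face Space, the real token is expected to be set as a repository secret
# and picked up through os.getenv; it should never be hard-coded here.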
# Model and embedding options
LLM_MODELS = {
    "Lightweight (Gemma-2B)": "google/gemma-2b-it",
    "Balanced (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "High Accuracy (Llama-3-8B)": "meta-llama/Llama-3-8b-hf"
}
EMBEDDING_MODELS = {
    "Lightweight (MiniLM-L6)": "sentence-transformers/all-MiniLM-L6-v2",
    "Balanced (MPNet-Base)": "sentence-transformers/all-mpnet-base-v2",
    "High Accuracy (BGE-Large)": "BAAI/bge-large-en-v1.5"
}
# Global state
vector_store = None
qa_chain = None
chat_history = []
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
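# ConversationBufferMemory keeps the full exchange history and feeds it back to the
# chain under the "chat_history" key; ConversationalRetrievalChain uses it to rewrite
# follow-up questions into standalone queries before retrieval.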
# Custom PPTX loader
class PPTXLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        docs = []
        with open(self.file_path, "rb") as f:
            prs = Presentation(BytesIO(f.read()))
        for slide_num, slide in enumerate(prs.slides, 1):
            # Collect the text from every shape on the slide
            text = ""
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    text += shape.text + "\n"
            if text.strip():
                # Return LangChain Document objects so the text splitter can consume them
                docs.append(Document(page_content=text, metadata={"source": self.file_path, "slide": slide_num}))
        return docs
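# Rough usage sketch (hypothetical file name): load() yields one Document per
# non-empty slide, with the slide number kept in metadata so an answer can be
# traced back to its slide.
#
#   docs = PPTXLoader("talk.pptx").load()
#   for doc in docs:
#       print(doc.metadata["slide"], doc.page_content[:80])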
# Function to load documents
def load_documents(files):
    documents = []
    for file in files:
        file_path = file.name
        if file_path.endswith(".pdf"):
            loader = PyPDFLoader(file_path)
            documents.extend(loader.load())
        elif file_path.endswith(".txt"):
            loader = TextLoader(file_path)
            documents.extend(loader.load())
        elif file_path.endswith(".docx"):
            loader = Docx2txtLoader(file_path)
            documents.extend(loader.load())
        elif file_path.endswith(".pptx"):
            loader = PPTXLoader(file_path)
            documents.extend(loader.load())
    return documents
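# Note: this assumes gr.Files hands back file objects exposing a .name path (the
# behaviour of some Gradio versions); other versions return plain path strings, in
# which case file_path = file.name would simply become file_path = file. Files with
# an unsupported extension are silently skipped.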
# Function to process documents and create vector store
def process_documents(files, chunk_size, chunk_overlap, embedding_model):
    global vector_store, qa_chain
    if not files:
        return "Please upload at least one document.", None

    # Load documents
    documents = load_documents(files)
    if not documents:
        return "No valid documents loaded.", None

    # Split documents
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=int(chunk_size),
        chunk_overlap=int(chunk_overlap),
        length_function=len
    )
    doc_splits = text_splitter.split_documents(documents)

    # Create embeddings
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODELS[embedding_model])

    # Create vector store
    try:
        vector_store = Chroma.from_documents(doc_splits, embeddings, persist_directory="./chroma_db")
        return f"Processed {len(documents)} documents into {len(doc_splits)} chunks.", None
    except Exception as e:
        return f"Error processing documents: {str(e)}", None
# Function to initialize QA chain
def initialize_qa_chain(llm_model, temperature):
    global qa_chain
    if vector_store is None:
        return "Please process documents before initializing the QA chain.", None
    try:
        llm = HuggingFaceEndpoint(
            repo_id=LLM_MODELS[llm_model],
            temperature=float(temperature),
            max_new_tokens=512,
            huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"]
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
            memory=memory
        )
        return "QA chain initialized successfully.", None
    except Exception as e:
        return f"Error initializing QA chain: {str(e)}", None
# Function to handle user query
def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap):
    global chat_history
    if not vector_store or not qa_chain:
        return "Please upload documents and initialize the QA chain.", chat_history
    try:
        response = qa_chain.invoke({"question": question})["answer"]
        # Store the exchange as a (user, bot) pair so gr.Chatbot can render it
        chat_history.append((question, response))
        return response, chat_history
    except Exception as e:
        return f"Error answering question: {str(e)}", chat_history
# Function to export chat history
def export_chat():
    if not chat_history:
        return "No chat history to export.", None
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"chat_history_{timestamp}.txt"
    with open(filename, "w") as f:
        for user_msg, bot_msg in chat_history:
            f.write(f"User: {user_msg}\nBot: {bot_msg}\n\n")
    return f"Chat history exported to {filename}.", filename
# Function to reset the app
def reset_app():
    global vector_store, qa_chain, chat_history, memory
    vector_store = None
    qa_chain = None
    chat_history = []
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    if os.path.exists("./chroma_db"):
        import shutil
        shutil.rmtree("./chroma_db")
    return "App reset successfully.", None
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
    gr.Markdown("# DocTalk: Document Q&A Chatbot")
    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), select models, tune parameters, and ask questions!")

    with gr.Row():
        with gr.Column(scale=2):
            file_upload = gr.Files(label="Upload Documents", file_types=[".pdf", ".txt", ".docx", ".pptx"])
            with gr.Row():
                process_button = gr.Button("Process Documents")
                reset_button = gr.Button("Reset App")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=1):
            llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="Lightweight (Gemma-2B)")
            embedding_model = gr.Dropdown(choices=list(EMBEDDING_MODELS.keys()), label="Select Embedding Model", value="Lightweight (MiniLM-L6)")
            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Temperature")
            chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
            chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
            init_button = gr.Button("Initialize QA Chain")

    gr.Markdown("## Chat Interface")
    question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
    answer = gr.Textbox(label="Answer", interactive=False)
    chat_display = gr.Chatbot(label="Chat History")
    export_button = gr.Button("Export Chat History")
    export_file = gr.File(label="Exported Chat File")

    # Event handlers
    process_button.click(
        fn=process_documents,
        inputs=[file_upload, chunk_size, chunk_overlap, embedding_model],
        outputs=[status, chat_display]
    )
    init_button.click(
        fn=initialize_qa_chain,
        inputs=[llm_model, temperature],
        outputs=[status, chat_display]
    )
    question.submit(
        fn=answer_question,
        inputs=[question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap],
        outputs=[answer, chat_display]
    )
    export_button.click(
        fn=export_chat,
        outputs=[status, export_file]
    )
    reset_button.click(
        fn=reset_app,
        outputs=[status, chat_display]
    )

demo.launch()