# DocuQA-Chat / app.py
import os
import gradio as gr
# faiss-cpu must be installed for langchain's FAISS vector store, but it needs
# no direct import here. python-docx is aliased so it doesn't clash with
# langchain's Document class.
from docx import Document as DocxDocument
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_core.documents import Document
from PyPDF2 import PdfReader
from langchain_anthropic import ChatAnthropic
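# Likely pip dependencies for this app (an assumption inferred from the imports above):
#   gradio, langchain, langchain-community, langchain-huggingface,
#   langchain-anthropic, sentence-transformers, faiss-cpu, PyPDF2, python-docx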
# Anthropic Claude is the chat model; the API key is read from the environment.
API_KEY = os.getenv('CLAUDE_API_KEY')
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0.5, max_tokens=8192, anthropic_api_key=API_KEY)

# Sentence-transformer embeddings for the FAISS index; buffer memory holds the chat history.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# Populated by process_files() once documents have been embedded.
vector_store = None

def process_file(file_path):
    """Extract plain text from a .txt, .docx, or .pdf file and wrap it in a Document."""
    _, ext = os.path.splitext(file_path)
    try:
        if ext.lower() == '.txt':
            with open(file_path, 'r', encoding='utf-8') as file:
                text = file.read()
        elif ext.lower() == '.docx':
            # .docx files are zip archives, so decoding the raw bytes as UTF-8
            # would yield mostly garbage; python-docx parses the structure instead.
            doc = DocxDocument(file_path)
            text = '\n'.join(paragraph.text for paragraph in doc.paragraphs)
        elif ext.lower() == '.pdf':
            with open(file_path, 'rb') as file:
                pdf_reader = PdfReader(file)
                # extract_text() can return None or '' for image-only pages; skip those.
                page_texts = (page.extract_text() for page in pdf_reader.pages)
                text = '\n'.join(page_text for page_text in page_texts if page_text)
        else:
            print(f"Unsupported file type: {ext}")
            return None
        return [Document(page_content=text, metadata={"source": file_path})]
    except Exception as e:
        print(f"Error processing file {file_path}: {str(e)}")
        return None

def process_files(file_list, progress=gr.Progress()):
    """Read the uploaded files, split them into chunks, and build the FAISS index."""
    global vector_store
    documents = []
    total_files = len(file_list)
    for i, file in enumerate(file_list):
        # File reading fills the first half of the progress bar; the splitting
        # and embedding stages continue from 0.6 so the bar never jumps backwards.
        progress((i + 1) / total_files * 0.5, f"Processing file {i + 1} of {total_files}")
        if file.name.lower().endswith(('.txt', '.docx', '.pdf')):
            docs = process_file(file.name)
            if docs:
                documents.extend(docs)
    if not documents:
        return "No documents were successfully processed. Please check your files and try again."
    progress(0.6, "Splitting text")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    progress(0.7, "Creating embeddings")
    vector_store = FAISS.from_documents(texts, embeddings)
    progress(0.9, "Saving vector store")
    vector_store.save_local("faiss_index")
    progress(1.0, "Completed")
    return f"Embedding process completed and database created. Processed {len(documents)} files. You can now start chatting!"

def chat(message, history):
    """Answer a question against the indexed documents via a retrieval chain."""
    global vector_store
    if vector_store is None:
        return "Please load documents or an existing index first."
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        vector_store.as_retriever(),
        memory=memory
    )
    # The shared ConversationBufferMemory supplies chat_history to the chain,
    # so only the question is passed in; Gradio's tuple history is unused here.
    result = qa_chain.invoke({"question": message})
    return result['answer']

def reset_chat():
    """Clear the conversation memory and empty the chatbot display."""
    memory.clear()
    return []

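# The error message in chat() mentions loading "an existing index", but nothing
# above reloads the index that save_local() writes to disk. A minimal sketch of
# such a loader, assuming the default "faiss_index" directory used by
# process_files(); the helper name load_existing_index is hypothetical and not
# wired into the UI below.
def load_existing_index():
    global vector_store
    if os.path.isdir("faiss_index"):
        # allow_dangerous_deserialization is required by recent langchain
        # releases because FAISS indexes are pickled on disk; only load
        # indexes you created yourself.
        vector_store = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
        return "Existing FAISS index loaded."
    return "No saved index found at ./faiss_index."
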
with gr.Blocks() as demo:
    gr.Markdown("# Document-based Chatbot")
    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Select Source Knowledge Documents", file_count="multiple", file_types=[".pdf", ".docx", ".txt"])
            process_button = gr.Button("Process Files")
            output = gr.Textbox(label="Processing Output")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    send = gr.Button("Send")
    clear = gr.Button("Clear")

    def process_selected_files(files):
        if files:
            return process_files(files)
        else:
            return "No files selected. Please select files and try again."

    process_button.click(process_selected_files, file_input, output)

    def respond(message, chat_history):
        bot_message = chat(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    send.click(respond, [msg, chatbot], [msg, chatbot])
    clear.click(reset_chat, None, chatbot)

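# To run locally (assuming the dependencies listed near the imports are installed):
#   export CLAUDE_API_KEY=<your Anthropic API key>
#   python app.py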
if __name__ == "__main__":
    demo.launch()