# Hugging Face Space source; page status when captured: "Runtime error".
import os
import xml.etree.ElementTree as ET
import zipfile

import faiss
import gradio as gr
import numpy as np
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_anthropic import ChatAnthropic
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from PyPDF2 import PdfReader
# Anthropic credentials come from the environment; the value is None when
# CLAUDE_API_KEY is not set.
API_KEY = os.environ.get('CLAUDE_API_KEY')

# Chat model that answers the user's questions.
llm = ChatAnthropic(
    model="claude-3-5-sonnet-20240620",
    temperature=0.5,
    max_tokens=8192,
    anthropic_api_key=API_KEY,
)

# Sentence-transformer model used to embed document chunks for FAISS.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# Conversation memory shared by every chain built in chat(); it is stored
# under the "chat_history" key and emptied by reset_chat().
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
)

# FAISS index; stays None until process_files() has built one.
vector_store = None
# XML namespace of WordprocessingML elements inside word/document.xml.
_WORD_NS = "{http://schemas.openxmlformats.org/wordprocessingml/2006/main}"


def _read_docx_text(file_path):
    """Return the paragraph text of a .docx file, one paragraph per line.

    A .docx file is a ZIP archive whose main body lives in
    ``word/document.xml``; text runs are stored in ``w:t`` elements grouped
    under ``w:p`` paragraph elements.
    """
    with zipfile.ZipFile(file_path) as archive:
        root = ET.fromstring(archive.read("word/document.xml"))
    paragraphs = []
    for paragraph in root.iter(f"{_WORD_NS}p"):
        runs = [node.text for node in paragraph.iter(f"{_WORD_NS}t") if node.text]
        if runs:
            paragraphs.append("".join(runs))
    return "\n".join(paragraphs)


def process_file(file_path):
    """Load one .txt, .docx or .pdf file and wrap its text in a Document.

    Args:
        file_path: Path of the file to read. The extension (case-insensitive)
            selects the extraction strategy.

    Returns:
        A single-element list holding a ``Document`` whose metadata records
        the source path, or ``None`` when the extension is unsupported, the
        file yields no text, or reading/parsing fails (the error is printed).
    """
    _, ext = os.path.splitext(file_path)
    suffix = ext.lower()
    try:
        if suffix == '.txt':
            with open(file_path, 'r', encoding='utf-8') as file:
                text = file.read()
        elif suffix == '.docx':
            # Decoding the raw bytes as UTF-8 would index compressed ZIP data;
            # extract the real text from the document XML instead.
            text = _read_docx_text(file_path)
        elif suffix == '.pdf':
            with open(file_path, 'rb') as file:
                # Extract each page once, while the underlying file is open.
                page_texts = [page.extract_text() for page in PdfReader(file).pages]
            text = '\n'.join(page_text for page_text in page_texts if page_text)
        else:
            print(f"Unsupported file type: {ext}")
            return None
        if not text.strip():
            # An empty document would contribute no chunks to the index.
            print(f"No extractable text in file {file_path}")
            return None
        return [Document(page_content=text, metadata={"source": file_path})]
    except Exception as e:
        # Best-effort loader: report the problem and let the caller skip the file.
        print(f"Error processing file {file_path}: {str(e)}")
        return None
def process_files(file_list, progress=gr.Progress()):
    """Build the FAISS index from uploaded files and save it locally.

    Args:
        file_list: Uploaded files, given either as path strings or as objects
            whose ``name`` attribute is the path on disk.
        progress: Gradio progress reporter, called with a fraction and a
            description at each stage.

    Returns:
        A status message for the UI. On success the module-level
        ``vector_store`` holds the new index, which is also written to the
        ``faiss_index`` directory.
    """
    global vector_store
    nothing_processed = (
        "No documents were successfully processed. Please check your files and try again."
    )
    supported_suffixes = ('.txt', '.docx', '.pdf')

    file_list = file_list or []
    documents = []
    total_files = len(file_list)
    for i, file in enumerate(file_list):
        # File loading fills the first half of the bar so the later stages
        # (0.5, 0.7, 0.9, 1.0) never move it backwards.
        progress(0.5 * (i + 1) / total_files, f"Processing file {i + 1} of {total_files}")
        path = file if isinstance(file, str) else file.name
        if path.lower().endswith(supported_suffixes):
            docs = process_file(path)
            if docs:
                documents.extend(docs)
    if not documents:
        return nothing_processed

    progress(0.5, "Splitting text")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    if not texts:
        # Nothing to embed; do not attempt to build an index from zero chunks.
        return nothing_processed

    progress(0.7, "Creating embeddings")
    vector_store = FAISS.from_documents(texts, embeddings)
    progress(0.9, "Saving vector store")
    vector_store.save_local("faiss_index")
    progress(1.0, "Completed")
    return f"Embedding process completed and database created. Processed {len(documents)} files. You can now start chatting!"
def chat(message, history):
    """Answer a question using the indexed documents.

    Args:
        message: The user's question.
        history: Chat history from the UI, passed to the chain under the
            ``chat_history`` key.

    Returns:
        The chain's ``answer`` string, or a prompt to load documents first
        when no index has been built yet.
    """
    if vector_store is None:
        return "Please load documents or an existing index first."

    retriever = vector_store.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever, memory=memory)

    # NOTE(review): the chain also has `memory` attached under the same
    # "chat_history" key — confirm which source of history LangChain uses.
    payload = {"question": message, "chat_history": history}
    return qa_chain.invoke(payload)['answer']
def reset_chat():
    """Empty the shared conversation memory and return a blank chat history."""
    # `memory` is only mutated here, never rebound, so no `global` is needed.
    memory.clear()
    return list()
# Gradio UI: an upload/processing panel plus a chat panel.
# NOTE(review): the nesting below was reconstructed from a flattened paste;
# confirm that the chat widgets belong outside the upload row.
with gr.Blocks() as demo:
    gr.Markdown("# Document-based Chatbot")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Select Source Knowledge Documents",
                file_count="multiple",
                file_types=[".pdf", ".docx", ".txt"],
            )
            process_button = gr.Button("Process Files")
            output = gr.Textbox(label="Processing Output")

    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    send = gr.Button("Send")
    clear = gr.Button("Clear")

    def process_selected_files(files, progress=gr.Progress()):
        """Index the uploaded files, or report that none were selected.

        ``progress`` is declared on this handler because Gradio only tracks a
        ``gr.Progress`` default that appears in the event handler's own
        signature; it is then forwarded to ``process_files``.
        """
        if not files:
            return "No files selected. Please select files and try again."
        return process_files(files, progress)

    process_button.click(process_selected_files, file_input, output)

    def respond(message, chat_history):
        """Append the bot's reply to the history and clear the input box."""
        chat_history = chat_history or []  # tolerate an unset chatbot value
        bot_message = chat(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history

    # Pressing Enter and clicking "Send" trigger the same handler.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    send.click(respond, [msg, chatbot], [msg, chatbot])
    clear.click(reset_chat, None, chatbot)
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()