# File size: 5,531 Bytes
# 7ba5af2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (removed: bare line-number residue left over from file extraction)
import os
import re
import zipfile

import faiss
import gradio as gr
import numpy as np
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_anthropic import ChatAnthropic
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from PyPDF2 import PdfReader

# SECURITY: an Anthropic API key was previously hard-coded here and committed
# to source control.  Read it from the environment instead; the leaked key
# must be revoked and rotated.
API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0.5, max_tokens=8192, anthropic_api_key=API_KEY)

# Local sentence-transformer model used to embed document chunks.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Conversation memory shared by every retrieval chain built in chat();
# cleared by reset_chat().
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# Populated by process_files() or load_existing_index(); None until then.
vector_store = None


def process_file(file_path):
    """Read one .txt, .docx, or .pdf file and wrap its text in a Document.

    Returns a single-element list of Document (so callers can ``extend()``
    directly), or None for unsupported extensions or read failures.
    """
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    try:
        if ext == '.txt':
            with open(file_path, 'r', encoding='utf-8') as file:
                text = file.read()
        elif ext == '.docx':
            # BUG FIX: a .docx file is a ZIP archive, not UTF-8 text; decoding
            # the raw bytes (the old approach) produced binary garbage.  The
            # document body lives in word/document.xml inside the archive, so
            # extract that and strip the XML markup instead.
            with zipfile.ZipFile(file_path) as archive:
                xml_content = archive.read('word/document.xml').decode('utf-8', errors='ignore')
            # Paragraph closings become newlines; every other tag is dropped.
            xml_content = re.sub(r'</w:p>', '\n', xml_content)
            text = re.sub(r'<[^>]+>', '', xml_content)
        elif ext == '.pdf':
            with open(file_path, 'rb') as file:
                pdf_reader = PdfReader(file)
                # Pages with no extractable text (e.g. scanned images) are skipped.
                text = '\n'.join([page.extract_text() for page in pdf_reader.pages if page.extract_text()])
        else:
            print(f"Unsupported file type: {ext}")
            return None

        return [Document(page_content=text, metadata={"source": file_path})]
    except Exception as e:
        # Best-effort: report and skip the bad file rather than abort the batch.
        print(f"Error processing file {file_path}: {str(e)}")
        return None


def process_files(file_list, progress=gr.Progress()):
    """Embed the selected files into a FAISS index saved under ``faiss_index``.

    Sets the module-level ``vector_store`` and returns a status string for
    the UI.  ``progress=gr.Progress()`` is Gradio's documented injection
    pattern, not an accidental mutable default.
    """
    global vector_store
    documents = []
    total_files = len(file_list)

    for i, file in enumerate(file_list):
        # BUG FIX: the loop used to drive the bar all the way to 1.0 and then
        # the later stages reported 0.5/0.7/0.9, making progress jump
        # backwards.  File reading now occupies only the first half of the bar.
        progress(0.5 * (i + 1) / total_files, f"Processing file {i + 1} of {total_files}")
        if file.name.lower().endswith(('.txt', '.docx', '.pdf')):
            docs = process_file(file.name)
            if docs:
                documents.extend(docs)

    if not documents:
        return "No documents were successfully processed. Please check your files and try again."

    progress(0.5, "Splitting text")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)

    progress(0.7, "Creating embeddings")
    vector_store = FAISS.from_documents(texts, embeddings)

    progress(0.9, "Saving vector store")
    vector_store.save_local("faiss_index")

    progress(1.0, "Completed")
    return f"Embedding process completed and database created. Processed {len(documents)} files. You can now start chatting!"


def load_existing_index(folder_path):
    """Load a previously saved FAISS index into the module-level store.

    Returns a human-readable status string for the UI output textbox.
    """
    global vector_store
    try:
        required = [os.path.join(folder_path, name) for name in ("index.faiss", "index.pkl")]
        if not all(os.path.exists(path) for path in required):
            return (f"Error: FAISS index files not found in {folder_path}."
                    " Please ensure both 'index.faiss' and 'index.pkl' are present.")

        # allow_dangerous_deserialization: the pickle sidecar is trusted here
        # because it was produced locally by process_files().
        vector_store = FAISS.load_local(folder_path, embeddings, allow_dangerous_deserialization=True)
        return f"Successfully loaded existing index from {folder_path}."
    except Exception as e:
        return f"Error loading index: {str(e)}"


def chat(message, history):
    """Answer one user message via retrieval-augmented generation.

    Requires a vector store to have been built or loaded first; otherwise
    returns an instructional message for the UI.
    """
    global vector_store
    if vector_store is None:
        return "Please load documents or an existing index first."

    retriever = vector_store.as_retriever()
    chain = ConversationalRetrievalChain.from_llm(llm, retriever, memory=memory)

    response = chain.invoke({"question": message, "chat_history": history})
    return response['answer']


def reset_chat():
    """Wipe the shared conversation memory and return an empty chat history.

    The empty list clears the Gradio Chatbot component it is wired to.
    (No ``global`` statement needed: ``memory`` is only mutated in place.)
    """
    memory.clear()
    return []


# Gradio UI: file ingestion / index loading on top, chat interface below.
with gr.Blocks() as demo:
    gr.Markdown("# Document-based Chatbot")

    with gr.Row():
        # Left column: upload fresh documents and build a new index.
        with gr.Column():
            file_input = gr.File(label="Select Files", file_count="multiple", file_types=[".pdf", ".docx", ".txt"])
            process_button = gr.Button("Process Files")
        # Right column: point at an already-saved FAISS index on disk.
        with gr.Column():
            index_folder = gr.Textbox(label="Existing Index Folder Path",
                                      value="C:\\Works\\Data\\projects\\Python\\QA_Chatbot\\faiss_index")
            load_index_button = gr.Button("Load Existing Index")

    # Status line shared by both the "process" and "load index" actions.
    output = gr.Textbox(label="Processing Output")

    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    send = gr.Button("Send")
    clear = gr.Button("Clear")


    def process_selected_files(files):
        # Guard against clicking "Process Files" with nothing selected.
        if files:
            return process_files(files)
        else:
            return "No files selected. Please select files and try again."


    def load_selected_index(folder_path):
        # Thin wrapper so the button handler has a UI-local name.
        return load_existing_index(folder_path)


    process_button.click(process_selected_files, file_input, output)
    load_index_button.click(load_selected_index, index_folder, output)


    def respond(message, chat_history):
        # Run the QA chain, append the (user, bot) pair, and clear the textbox
        # (the empty string returned first is wired back into `msg`).
        bot_message = chat(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history


    # Enter key and Send button share the same handler.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    send.click(respond, [msg, chatbot], [msg, chatbot])
    clear.click(reset_chat, None, chatbot)

if __name__ == "__main__":
    demo.launch()