import io
import os
import xml.etree.ElementTree as ET
import zipfile

import faiss
import gradio as gr
import numpy as np
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_anthropic import ChatAnthropic
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from PyPDF2 import PdfReader

# Anthropic API key is read from the environment; a missing key is not
# detected here — it will surface as an auth error on the first request.
API_KEY = os.getenv('CLAUDE_API_KEY')

llm = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0.5, max_tokens=8192, anthropic_api_key=API_KEY)

# Local sentence-transformers model used to embed document chunks for FAISS.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Conversation memory shared by every retrieval chain built in chat();
# cleared by reset_chat().
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# Populated by process_files(); stays None until documents are embedded.
vector_store = None


def process_file(file_path):
    """Load one document file and wrap its text in a LangChain Document.

    Supports .txt (read as UTF-8), .docx (text extracted from the Word XML
    inside the zip container), and .pdf (PyPDF2 page text).

    Args:
        file_path: Path to the file on disk.

    Returns:
        A one-element list ``[Document]`` on success, or ``None`` when the
        extension is unsupported or extraction fails (the error is printed).
    """
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    try:
        if ext == '.txt':
            with open(file_path, 'r', encoding='utf-8') as file:
                text = file.read()
        elif ext == '.docx':
            # Bug fix: a .docx is a zip archive, so decoding its raw bytes
            # as UTF-8 yields binary garbage. Extract the real text from
            # the main document part instead.
            text = extract_docx_text(file_path)
        elif ext == '.pdf':
            with open(file_path, 'rb') as file:
                pdf_reader = PdfReader(file)
                # Call extract_text() once per page instead of twice.
                pages = (page.extract_text() for page in pdf_reader.pages)
                text = '\n'.join(p for p in pages if p)
        else:
            print(f"Unsupported file type: {ext}")
            return None

        return [Document(page_content=text, metadata={"source": file_path})]
    except Exception as e:
        print(f"Error processing file {file_path}: {str(e)}")
        return None


def extract_docx_text(file_path):
    """Extract plain text from a .docx file using only the stdlib.

    Reads ``word/document.xml`` from the zip container and joins the text
    runs (``<w:t>``) of each paragraph (``<w:p>``), one paragraph per line.

    Args:
        file_path: Path (or file-like object) of the .docx archive.

    Returns:
        The document text with paragraphs separated by newlines.
    """
    ns = '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}'
    with zipfile.ZipFile(file_path) as archive:
        root = ET.fromstring(archive.read('word/document.xml'))
    paragraphs = []
    for paragraph in root.iter(f'{ns}p'):
        runs = [node.text for node in paragraph.iter(f'{ns}t') if node.text]
        paragraphs.append(''.join(runs))
    return '\n'.join(paragraphs)


def process_files(file_list, progress=gr.Progress()):
    """Embed the given files into a FAISS index stored in ``vector_store``.

    Loads each supported file (.txt/.docx/.pdf), splits the resulting
    documents into chunks, embeds them, assigns the index to the global
    ``vector_store``, and saves it to the "faiss_index" directory.

    Args:
        file_list: Uploaded file objects (each exposing a ``.name`` path).
        progress: Gradio progress reporter (default instance is Gradio's
            documented pattern for progress tracking).

    Returns:
        A human-readable status message for the UI.
    """
    global vector_store
    documents = []
    total_files = len(file_list)

    # Guard against an empty list: avoids ZeroDivisionError below and
    # matches the "nothing processed" message callers already expect.
    if not file_list:
        return "No documents were successfully processed. Please check your files and try again."

    for i, file in enumerate(file_list):
        # Keep file-loading progress within [0, 0.5] so the bar does not
        # jump backwards when the later stages report 0.5/0.7/0.9.
        progress(0.5 * (i + 1) / total_files, f"Processing file {i + 1} of {total_files}")
        if file.name.lower().endswith(('.txt', '.docx', '.pdf')):
            docs = process_file(file.name)
            if docs:
                documents.extend(docs)

    if not documents:
        return "No documents were successfully processed. Please check your files and try again."

    progress(0.5, "Splitting text")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)

    progress(0.7, "Creating embeddings")
    vector_store = FAISS.from_documents(texts, embeddings)

    progress(0.9, "Saving vector store")
    vector_store.save_local("faiss_index")

    progress(1.0, "Completed")
    return f"Embedding process completed and database created. Processed {len(documents)} files. You can now start chatting!"


def chat(message, history):
    """Answer ``message`` using retrieval over the loaded vector store.

    A fresh ConversationalRetrievalChain is built on every call so the
    retriever always reflects the current global ``vector_store``.

    Args:
        message: The user's question.
        history: Chat history passed through to the chain.

    Returns:
        The chain's answer string, or a prompt to load documents first
        when no index is available yet.
    """
    global vector_store
    if vector_store is None:
        return "Please load documents or an existing index first."

    retriever = vector_store.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever, memory=memory)

    response = qa_chain.invoke({"question": message, "chat_history": history})
    return response['answer']


def reset_chat():
    """Wipe the shared conversation memory and return an empty history.

    ``memory.clear()`` mutates the module-level object in place, so no
    ``global`` declaration is needed.
    """
    memory.clear()
    return []


with gr.Blocks() as demo:
    gr.Markdown("# Document-based Chatbot")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Select Source Knowledge Documents",
                file_count="multiple",
                file_types=[".pdf", ".docx", ".txt"],
            )
            process_button = gr.Button("Process Files")

    output = gr.Textbox(label="Processing Output")

    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    send = gr.Button("Send")
    clear = gr.Button("Clear")


    def process_selected_files(files):
        """Run the embedding pipeline, or complain when nothing was picked."""
        if not files:
            return "No files selected. Please select files and try again."
        return process_files(files)


    process_button.click(process_selected_files, file_input, output)


    def respond(message, chat_history):
        """Append the (user, bot) exchange to the history and clear the box."""
        answer = chat(message, chat_history)
        chat_history.append((message, answer))
        return "", chat_history


    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    send.click(respond, [msg, chatbot], [msg, chatbot])
    clear.click(reset_chat, None, chatbot)

# Start the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()