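# QnA Chatbot: a Gradio app that indexes an uploaded text file into a Chroma
# vector store and answers questions over it with a LangChain
# ConversationalRetrievalChain backed by a HuggingFace Hub LLM.
#
# Required environment variables (as used below):
#   HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME  - instruct embedding model name
#   CHROMADB_PERSIST_DIRECTORY            - root directory for persisted vectors
#   HUGGINGFACEHUB_LLM_REPO_ID            - Hub repo id of the chat LLM
# HuggingFaceHub also reads HUGGINGFACEHUB_API_TOKEN from the environment
# when no token is passed explicitly.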
import os
import gradio as gr

from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import TextLoader
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain


def load_embeddings():
    # Load the instruction-tuned embedding model named in the environment.
    print("Loading embeddings...")
    model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
    return HuggingFaceInstructEmbeddings(model_name=model_name)


def split_file(file, chunk_size, chunk_overlap):
    # Load the uploaded text file and split it into overlapping chunks.
    print('splitting file', file.name)
    loader = TextLoader(file.name)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(documents)


def get_persist_directory(file_name):
    # Each document gets its own Chroma directory, named after the file.
    return os.path.join(os.environ['CHROMADB_PERSIST_DIRECTORY'], file_name)


def process_file(file, chunk_size, chunk_overlap):
    # Split, embed, and persist the uploaded file as a named Chroma collection.
    docs = split_file(file, chunk_size, chunk_overlap)
    embeddings = load_embeddings()

    file_name, _ = os.path.splitext(os.path.basename(file.name))
    persist_directory = get_persist_directory(file_name)
    print("persist directory", persist_directory)
    vectordb = Chroma.from_documents(documents=docs, embedding=embeddings,
                                     collection_name=file_name, persist_directory=persist_directory)
    print(vectordb._client.list_collections())
    vectordb.persist()
    return 'Done!'


def is_dir(root, name):
    path = os.path.join(root, name)
    return os.path.isdir(path)


def get_vector_dbs():
    # List the persisted collections; each directory under the Chroma root
    # corresponds to one processed document.
    root = os.environ['CHROMADB_PERSIST_DIRECTORY']
    if not os.path.exists(root):
        return []

    files = os.listdir(root)
    dirs = [name for name in files if is_dir(root, name)]
    print(dirs)
    return dirs


def load_vectordb(file_name):
    # Reopen a previously persisted Chroma collection with the same embeddings.
    embeddings = load_embeddings()

    persist_directory = get_persist_directory(file_name)
    print(persist_directory)
    vectordb = Chroma(collection_name=file_name,
                      embedding_function=embeddings, persist_directory=persist_directory)
    print(vectordb._client.list_collections())
    return vectordb


def create_qa_chain(collection_name, temperature, max_length):
    print('creating qa chain...')
    # Note: the chain (and its memory) is rebuilt for every message, so chat
    # history does not carry over between questions.
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    llm = HuggingFaceHub(
        repo_id=os.environ["HUGGINGFACEHUB_LLM_REPO_ID"],
        model_kwargs={"temperature": temperature, "max_length": max_length}
    )
    vectordb = load_vectordb(collection_name)
    return ConversationalRetrievalChain.from_llm(llm=llm, retriever=vectordb.as_retriever(), memory=memory)


def submit_message(bot_history, text):
    # Append the user message with an empty reply slot and clear the textbox;
    # a list (not a tuple) is used so bot() can fill the reply in place.
    bot_history = bot_history + [[text, None]]
    return bot_history, ""


def bot(bot_history, collection_name, temperature, max_length):
    qa = create_qa_chain(collection_name, temperature, max_length)
    # Answer the latest user message and fill in the pending reply slot.
    response = qa.run(bot_history[-1][0])
    bot_history[-1][1] = response
    return bot_history


def clear_bot():
    # Returning None resets the Chatbot component.
    return None


title = "QnA Chatbot"

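# Two-tab UI: "File" splits and indexes an uploaded document; "Bot" chats
# over a previously indexed one.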
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")

    with gr.Tab("File"):
        upload = gr.File(file_types=["text"], label="Upload File")
        chunk_size = gr.Slider(
            500, 5000, value=1000, step=100, label="Chunk Size")
        chunk_overlap = gr.Slider(0, 30, value=20, label="Chunk Overlap")
        process = gr.Button("Process")
        result = gr.Label()

    with gr.Tab("Bot"):
        with gr.Row():
            with gr.Column(scale=0.5):
                collection = gr.Dropdown(
                    choices=get_vector_dbs(), label="Document")
                temperature = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.05, label="Temperature")
                max_length = gr.Slider(20, 1000, value=64, label="Max Length")

            with gr.Column(scale=0.5):
                chatbot = gr.Chatbot([], elem_id="chatbot").style(height=550)
                message = gr.Textbox(
                    show_label=False, placeholder="Ask me anything!")
                clear = gr.Button("Clear")

    process.click(process_file, [upload, chunk_size, chunk_overlap], result)

    message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
        bot, [chatbot, collection, temperature, max_length], chatbot
    )
    clear.click(clear_bot, None, chatbot)

demo.title = title

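# Gradio serves on http://127.0.0.1:7860 by default; pass share=True or
# server_name="0.0.0.0" to launch() to expose the app beyond localhost.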
demo.launch()