import os
import gradio as gr
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import TextLoader
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
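
# Configuration is read from environment variables. The variable names below
# are the ones this script uses; the example values are illustrative only.
#   HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME   e.g. "hkunlp/instructor-large"
#   HUGGINGFACEHUB_LLM_REPO_ID             e.g. "google/flan-t5-xxl"
#   CHROMADB_PERSIST_DIRECTORY             e.g. "./chroma"
# HuggingFaceHub additionally expects HUGGINGFACEHUB_API_TOKEN to be set.

# Module-level state shared across Gradio callbacks.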
embeddings = None
qa_chain = None
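
# Load the instructor embedding model once and cache it in the module global,
# since instantiating the model is expensive.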
def load_embeddings():
    global embeddings
    if not embeddings:
        print("loading embeddings...")
        model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
        embeddings = HuggingFaceInstructEmbeddings(model_name=model_name)
    return embeddings
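
# Split an uploaded text file into overlapping character chunks so that each
# chunk can be embedded and indexed separately.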
def split_file(file, chunk_size, chunk_overlap):
    print('splitting file...', file.name, chunk_size, chunk_overlap)
    loader = TextLoader(file.name)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(documents)
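
# Each document gets its own Chroma persist directory, named after the file.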
def get_persist_directory(file_name):
    return os.path.join(os.environ['CHROMADB_PERSIST_DIRECTORY'], file_name)
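
# "Process" button handler: chunk the uploaded file, embed the chunks into a
# per-file Chroma collection on disk, and refresh the document dropdown.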
def process_file(file, chunk_size, chunk_overlap):
    file_name, _ = os.path.splitext(os.path.basename(file.name))
    persist_directory = get_persist_directory(file_name)
    print("persist directory", persist_directory)
    docs = split_file(file, chunk_size, chunk_overlap)
    embeddings = load_embeddings()
    vectordb = Chroma.from_documents(documents=docs, embedding=embeddings,
                                     collection_name=file_name,
                                     persist_directory=persist_directory)
    print(vectordb._client.list_collections())
    vectordb.persist()
    return 'Done!', gr.Dropdown.update(choices=get_vector_dbs(), value=file_name)
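
# Helpers for the "Document" dropdown: every subdirectory of the persist root
# corresponds to one processed file.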
def is_dir(root, name):
    path = os.path.join(root, name)
    return os.path.isdir(path)

def get_vector_dbs():
    root = os.environ['CHROMADB_PERSIST_DIRECTORY']
    if not os.path.exists(root):
        return []
    print('get vector dbs...', root)
    files = os.listdir(root)
    dirs = list(filter(lambda x: is_dir(root, x), files))
    print(dirs)
    return dirs
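
# Reopen a persisted Chroma collection with the shared embedding function.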
def load_vectordb(file_name):
    embeddings = load_embeddings()
    persist_directory = get_persist_directory(file_name)
    print(persist_directory)
    vectordb = Chroma(collection_name=file_name,
                      embedding_function=embeddings,
                      persist_directory=persist_directory)
    print(vectordb._client.list_collections())
    return vectordb
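
# (Re)build the conversational retrieval chain. Because a fresh
# ConversationBufferMemory is created on every rebuild, changing the document
# or the generation settings also resets the conversation history.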
def create_qa_chain(collection_name, temperature, max_length):
    print('creating qa chain...', collection_name, temperature, max_length)
    if not collection_name:
        return
    global qa_chain
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    llm = HuggingFaceHub(
        repo_id=os.environ["HUGGINGFACEHUB_LLM_REPO_ID"],
        model_kwargs={"temperature": temperature, "max_length": max_length}
    )
    vectordb = load_vectordb(collection_name)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vectordb.as_retriever(), memory=memory)
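
# Chat callbacks: submit_message appends the user turn and clears the textbox;
# bot then fills in the assistant reply for the last turn.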
def submit_message(bot_history, text):
    bot_history = bot_history + [(text, None)]
    return bot_history, ""

def bot(bot_history):
    global qa_chain
    if not qa_chain:
        # No chain yet (no document processed or selected), so answer
        # gracefully instead of raising on None.
        bot_history[-1][1] = "Please process or select a document first."
        return bot_history
    print(qa_chain, bot_history[-1][0])
    result = qa_chain.run(bot_history[-1][0])
    print(result)
    bot_history[-1][1] = result
    return bot_history

def clear_bot():
    return None
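
# Gradio UI: a "File" tab for ingesting documents and a "Bot" tab for chatting
# over the selected document.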
title = "QnA Chatbot"
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    with gr.Tab("File"):
        upload = gr.File(file_types=["text"], label="Upload File")
        chunk_size = gr.Slider(
            500, 5000, value=1000, step=100, label="Chunk Size")
        chunk_overlap = gr.Slider(0, 30, value=20, label="Chunk Overlap")
        process = gr.Button("Process")
        result = gr.Label()
    with gr.Tab("Bot"):
        with gr.Row():
            with gr.Column(scale=0.5):
                choices = get_vector_dbs()
                collection = gr.Dropdown(
                    choices, value=choices[0] if choices else None,
                    label="Document")
                temperature = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.05, label="Temperature")
                max_length = gr.Slider(
                    20, 1000, value=100, step=10, label="Max Length")
            with gr.Column(scale=0.5):
                chatbot = gr.Chatbot([], elem_id="chatbot").style(height=550)
                message = gr.Textbox(
                    show_label=False, placeholder="Ask me anything!")
                clear = gr.Button("Clear")
    process.click(
        process_file,
        [upload, chunk_size, chunk_overlap],
        [result, collection]
    )
    # Build an initial chain from the default settings, then rebuild it
    # whenever the document or the generation settings change.
    create_qa_chain(collection.value, temperature.value, max_length.value)
    collection.change(create_qa_chain, [collection, temperature, max_length])
    temperature.change(create_qa_chain, [collection, temperature, max_length])
    max_length.change(create_qa_chain, [collection, temperature, max_length])
    # Post the user message first, then run the chain to produce the reply.
    message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
        bot, chatbot, chatbot
    )
    clear.click(clear_bot, None, chatbot)
demo.title = title
demo.launch()