Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -14,62 +14,85 @@ from langchain.chains import RetrievalQA
|
|
14 |
from langchain_community.llms import HuggingFaceHub
|
15 |
|
16 |
# define constants
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
MISTRAL_MODEL1 = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
|
19 |
HF_MODEL1 = 'HuggingFaceH4/zephyr-7b-beta'
|
20 |
# define paths
|
21 |
vector_path = 'faiss_index'
|
|
|
|
|
22 |
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
|
23 |
|
24 |
-
def respond(message, history
|
25 |
-
#system_message,
|
26 |
-
#max_tokens,
|
27 |
-
#temperature,
|
28 |
-
#top_p
|
29 |
-
):
|
30 |
|
31 |
# Initialize your embedding model
|
32 |
-
|
|
|
|
|
33 |
|
34 |
# Load FAISS from relative path
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
raise FileNotFoundError("FAISS index not found in Space. Please upload it to faiss_index/")
|
39 |
|
40 |
# define retriever object
|
41 |
-
|
|
|
|
|
42 |
|
43 |
# initialse chatbot llm
|
44 |
llm = HuggingFaceHub(
|
45 |
repo_id=MISTRAL_MODEL1,
|
46 |
huggingfacehub_api_token=hf_token,
|
47 |
-
model_kwargs={"temperature": 0.
|
48 |
)
|
49 |
|
50 |
# create a RAG pipeline
|
51 |
-
|
|
|
|
|
|
|
52 |
#generate results
|
53 |
-
|
54 |
-
|
|
|
55 |
|
56 |
# remove the top instructions
|
57 |
instruction_prefix = (
|
58 |
"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
|
59 |
)
|
60 |
-
if
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
63 |
# Split question, Helpful Answer and Reason
|
64 |
-
|
65 |
-
|
66 |
-
)
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
-
formatted_responce = f'Question:{
|
73 |
yield formatted_responce
|
74 |
|
75 |
|
|
|
14 |
from langchain_community.llms import HuggingFaceHub
|
15 |
|
16 |
# define constants
|
17 |
+
# Embedding models
|
18 |
+
EMB_MODEL_bge = 'BAAI/bge-base-en-v1.5'
|
19 |
+
EMB_MODEL_gtr_t5 = 'sentence-transformers/gtr-t5-base'
|
20 |
+
EMB_MODEL_e5 = 'intfloat/e5-large-v2'
|
21 |
+
# Chat app model
|
22 |
MISTRAL_MODEL1 = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
|
23 |
HF_MODEL1 = 'HuggingFaceH4/zephyr-7b-beta'
|
24 |
# define paths
|
25 |
vector_path = 'faiss_index'
|
26 |
+
vector_path_2 = 'faiss_index_2'
|
27 |
+
vector_path_e5 = 'faiss_index_3'
|
28 |
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
|
29 |
|
30 |
+
def respond(message, history):
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# Initialize your embedding model
|
33 |
+
embedding_model_bge = HuggingFaceEmbeddings(model_name=EMB_MODEL_bge)
|
34 |
+
embedding_model_gtr_t5 = HuggingFaceEmbeddings(model_name=EMB_MODEL_gtr_t5)
|
35 |
+
embedding_model_e5 = HuggingFaceEmbeddings(model_name=EMB_MODEL_e5)
|
36 |
|
37 |
# Load FAISS from relative path
|
38 |
+
vectordb_bge = FAISS.load_local(vector_path_bge, embedding_model_bge, allow_dangerous_deserialization=True)
|
39 |
+
vectordb_gtr_t5 = FAISS.load_local(vector_path_gtr_t5, embedding_model_gtr_t5, allow_dangerous_deserialization=True)
|
40 |
+
vectordb_e5 = FAISS.load_local(vector_path_e5, embedding_model_e5, allow_dangerous_deserialization=True)
|
|
|
41 |
|
42 |
# define retriever object
|
43 |
+
retriever_bge = vectordb_bge.as_retriever(search_type="similarity", search_kwargs={"k": 5})
|
44 |
+
retriever_gtr_t5 = vectordb_gtr_t5.as_retriever(search_type="similarity", search_kwargs={"k": 5})
|
45 |
+
retriever_e5 = vectordb_e5.as_retriever(search_type="similarity", search_kwargs={"k": 5})
|
46 |
|
47 |
# initialse chatbot llm
|
48 |
llm = HuggingFaceHub(
|
49 |
repo_id=MISTRAL_MODEL1,
|
50 |
huggingfacehub_api_token=hf_token,
|
51 |
+
model_kwargs={"temperature": 0.7, "max_new_tokens": 512}
|
52 |
)
|
53 |
|
54 |
# create a RAG pipeline
|
55 |
+
qa_chain_bge = RetrievalQA.from_chain_type(llm=llm, retriever=retriever_bge)
|
56 |
+
qa_chain_gtr_t5 = RetrievalQA.from_chain_type(llm=llm, retriever=retriever_gtr_t5)
|
57 |
+
qa_chain_e5 = RetrievalQA.from_chain_type(llm=llm, retriever=retriever_e5)
|
58 |
+
|
59 |
#generate results
|
60 |
+
responce_bge = qa_chain_bge.invoke(message)['result']
|
61 |
+
responce_gtr_t5 = qa_chain_gtr_t5.invoke(message)['result']
|
62 |
+
responce_e5 = qa_chain_e5.invoke(message)['result']
|
63 |
|
64 |
# remove the top instructions
|
65 |
instruction_prefix = (
|
66 |
"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
|
67 |
)
|
68 |
+
if responce_bge.strip().startswith(instruction_prefix):
|
69 |
+
responce_bge = responce_bge.strip()[len(instruction_prefix):].strip()
|
70 |
+
if responce_gtr_t5.strip().startswith(instruction_prefix):
|
71 |
+
responce_gtr_t5 = responce_gtr_t5.strip()[len(instruction_prefix):].strip()
|
72 |
+
if responce_e5.strip().startswith(instruction_prefix):
|
73 |
+
responce_e5 = responce_e5.strip()[len(instruction_prefix):].strip()
|
74 |
+
|
75 |
# Split question, Helpful Answer and Reason
|
76 |
+
match_bge = re.search(r"^(.*?)(?:\n+)?Question:\s*(.*?)(?:\n+)?Helpful Answer:\s*(.*)", responce_bge, re.DOTALL)
|
77 |
+
match_gtr_t5 = re.search(r"^(.*?)(?:\n+)?Question:\s*(.*?)(?:\n+)?Helpful Answer:\s*(.*)", responce_gtr_t5, re.DOTALL)
|
78 |
+
match_e5 = re.search(r"^(.*?)(?:\n+)?Question:\s*(.*?)(?:\n+)?Helpful Answer:\s*(.*)", responce_e5, re.DOTALL)
|
79 |
+
|
80 |
+
if match_bge:
|
81 |
+
#original_text_bge = match_bge.group(1).strip()
|
82 |
+
question_bge = match_bge.group(2).strip()
|
83 |
+
answer_bge = match_bge.group(3).strip()
|
84 |
+
|
85 |
+
if match_gtr_t5:
|
86 |
+
#original_text_gtr_t5 = match_gtr_t5.group(1).strip()
|
87 |
+
#question_gtr_t5 = match_gtr_t5.group(2).strip()
|
88 |
+
answer_gtr_t5 = match_gtr_t5.group(3).strip()
|
89 |
+
|
90 |
+
if match_e5:
|
91 |
+
#original_text_e5 = match_e5.group(1).strip()
|
92 |
+
#question_e5 = match_e5.group(2).strip()
|
93 |
+
answer_e5 = match_e5.group(3).strip()
|
94 |
|
95 |
+
formatted_responce = f'Question:{question_bge}\nHelpful Answer Type 1:\n{answer_bge}\nHelpful Answer Type 2:\n{answer_gtr_t5}\nHelpful Answer Type 3:\n{answer_e5}'
|
96 |
yield formatted_responce
|
97 |
|
98 |
|