acumplid committed · Commit e27df4e · 1 parent: af12306
modified ui and rag
app.py
CHANGED
@@ -8,10 +8,17 @@ from rag import RAG
 from utils import setup
 
 MAX_NEW_TOKENS = 700
-SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="
+SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="False") == "True"
+import logging
+
+logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)s] - %(message)s')
 
 setup()
 
+print("Loading RAG model...")
+print("Show model parameters in UI: ", SHOW_MODEL_PARAMETERS_IN_UI)
+
+# Load the RAG model
 rag = RAG(
     vs_hf_repo_path=os.getenv("VS_REPO_NAME"),
     vectorstore_path=os.getenv("VECTORSTORE_PATH"),
@@ -40,6 +47,9 @@ def generate(prompt, model_parameters):
 
 
 def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
+    """
+    Function to handle the input and call the RAG model for inference.
+    """
     if input_.strip() == "":
         gr.Warning("Not possible to inference an empty input")
         return None
@@ -89,41 +99,53 @@ def clear():
 
 def gradio_app():
     with gr.Blocks(theme=theme) as demo:
+        # App Description
+        # =====================================================================================================================================
        with gr.Row():
            with gr.Column():
+
                gr.Markdown(
-                    """# Demo de Retrieval-Augmented Generation per la Viquipèdia
-                    🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
-                    en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
-                    fent servir només la informació existent en els documents del repositori.
+                    # """# Demo de Retrieval-Augmented Generation per la Viquipèdia
+                    # 🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
+                    # en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
+                    # fent servir només la informació existent en els documents del repositori.
 
-                    🎯 **Objectiu:** Aquest és un demostrador amb Viquipèdia i genera la resposta fent servir el model salamandra-7b-instruct.
+                    # 🎯 **Objectiu:** Aquest és un demostrador amb Viquipèdia i genera la resposta fent servir el model salamandra-7b-instruct.
 
-                    ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
-                    Si us plau, tingueu-ho en compte quan exploreu aquest recurs. El model en inferencia asociat a aquesta demo de desenvolupament no funciona continuament. Si vol fer proves,
-                    contacteu amb nosaltres a Langtech.
-                    """
+                    # ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
+                    # Si us plau, tingueu-ho en compte quan exploreu aquest recurs. El model en inferencia asociat a aquesta demo de desenvolupament no funciona continuament. Si vol fer proves,
+                    # contacteu amb nosaltres a Langtech.
+                    # """
                )
-
-
+
+
+        # with gr.Row(equal_height=True):
+        with gr.Row(equal_height=False):
+            # User Input
+            # =====================================================================================================================================
+            with gr.Column(scale=2, variant="panel"):
+
                input_ = Textbox(
-                    lines=
+                    lines=5,
                    label="Input",
                    placeholder="Qui va crear la guerra de les Galaxies ?",
                )
-
-
-
-
+
+
+                # with gr.Column(variant="panel"):
+                with gr.Row(variant="default"):
+                    # with gr.Row(variant="panel"):
+                    clear_btn = Button("Clear",)
                    submit_btn = Button("Submit", variant="primary", interactive=False)
 
-        with gr.Row(variant="panel"):
-
+                # with gr.Row(variant="panel"):
+                with gr.Row(variant="default"):
+                    with gr.Accordion("Model parameters (not used)", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        num_chunks = Slider(
                            minimum=1,
                            maximum=6,
                            step=1,
-                            value=
+                            value=5,
                            label="Number of chunks"
                        )
                        max_new_tokens = Slider(
@@ -166,14 +188,29 @@ def gradio_app():
 
                parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
 
-
+                # Add Examples manually
+                gr.Examples(
+                    examples=[
+                        ["Qui va crear la guerra de les Galaxies?"],
+                        ["Quin era el nom real de Voltaire?"],
+                        ["Què fan al BSC?"]
+                    ],
+                    inputs=[input_], # only inputs
+                )
+
+            # Output
+            # =====================================================================================================================================
+            with gr.Column(scale=10, variant="panel"):
+
                output = Textbox(
                    lines=10,
+                    max_lines=25,
                    label="Output",
                    interactive=False,
                    show_copy_button=True
                )
-
+
+                with gr.Accordion("Sources and context:", open=False, visible=False):
                    source_context = gr.Markdown(
                        label="Sources",
                        show_label=False,
@@ -186,8 +223,9 @@ def gradio_app():
                        # autoscroll=False,
                        # show_copy_button=True
                    )
-
 
+        # Event Handlers
+        # =====================================================================================================================================
        input_.change(
            fn=change_interactive,
            inputs=[input_],
@@ -219,20 +257,20 @@ def gradio_app():
            outputs=[output, source_context, context_evaluation],
            api_name="get-results"
        )
+        # =====================================================================================================================================
 
-
-
-
-
-
-
-
-
-
-        )
+        # # Output
+        # with gr.Row():
+        #     with gr.Column(scale=0.5):
+        #         gr.Examples(
+        #             examples=[["""Qui va crear la guerra de les Galaxies ?"""],],
+        #             inputs=input_,
+        #             outputs=[output, source_context, context_evaluation],
+        #             fn=submit_input,
+        #         )
 
    demo.launch(show_api=True)
 
-
+
if __name__ == "__main__":
    gradio_app()
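For reference, a minimal sketch of the UI pattern this commit introduces in app.py: an environment flag decides whether the model-parameter accordion is rendered, and gr.Examples pre-fills the input box. The `echo` handler, layout scales, and slider below are illustrative placeholders, not the Space's actual wiring.

```python
import os
import gradio as gr

# Same parsing convention as SHOW_MODEL_PARAMETERS_IN_UI in the diff above.
SHOW_PARAMS = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="False") == "True"

def echo(text, num_chunks):
    # Placeholder handler standing in for the RAG call.
    return f"(would retrieve {num_chunks} chunks for) {text}"

with gr.Blocks() as demo:
    with gr.Row(equal_height=False):
        with gr.Column(scale=2, variant="panel"):
            input_ = gr.Textbox(lines=5, label="Input")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
            # The whole accordion is hidden unless the env flag is set.
            with gr.Accordion("Model parameters", open=False, visible=SHOW_PARAMS):
                num_chunks = gr.Slider(minimum=1, maximum=6, step=1, value=5,
                                       label="Number of chunks")
            # Examples only pre-fill the input; they do not trigger inference.
            gr.Examples(
                examples=[["Qui va crear la guerra de les Galaxies?"]],
                inputs=[input_],
            )
        with gr.Column(scale=10, variant="panel"):
            output = gr.Textbox(lines=10, max_lines=25, label="Output", interactive=False)

    submit_btn.click(fn=echo, inputs=[input_, num_chunks], outputs=[output])
    clear_btn.click(fn=lambda: ("", ""), inputs=[], outputs=[input_, output])

if __name__ == "__main__":
    demo.launch()
```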
rag.py
CHANGED
@@ -10,6 +10,10 @@ from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
+logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)s] - %(message)s')
+# logging.getLogger().setLevel(logging.INFO)
+
+
 class RAG:
     NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."
 
@@ -26,11 +30,15 @@ class RAG:
        self.rerank_number_contexts = rerank_number_contexts
 
        # load vectore store
-        hf_vectorstore = snapshot_download(repo_id=vs_hf_repo_path)
-
        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
-
-
+
+        if vs_hf_repo_path:
+            hf_vectorstore = snapshot_download(repo_id=vs_hf_repo_path)
+            self.vectore_store = FAISS.load_local(hf_vectorstore, embeddings, allow_dangerous_deserialization=True)
+        else:
+            self.vectore_store = FAISS.load_local(self.vectorstore_path, embeddings, allow_dangerous_deserialization=True)
+
+
        logging.info("RAG loaded!")
        logging.info( self.vectore_store)
 
@@ -44,44 +52,52 @@ class RAG:
 
        tokenizer = AutoTokenizer.from_pretrained(rerank_model)
        model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
+        logging.info("Rerank model loaded!")
 
        def get_score(query, passage):
            """Calculate the relevance score of a passage with respect to a query."""
 
-
            inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
-
+            print("Inputs: ", inputs)
 
            with torch.no_grad():
                outputs = model(**inputs)
 
-
            logits = outputs.logits
-
-
            score = logits.view(-1, ).float()
 
+            print("Score: ", score)
 
            return score
 
        scores = [get_score(instruction, c[0].page_content) for c in contexts]
+
+        print("Scores: ", scores)
+
        combined = list(zip(contexts, scores))
        sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
        sorted_texts, _ = zip(*sorted_combined)
 
        return sorted_texts[:number_of_contexts]
 
-
+
+    def get_context(self, instruction, number_of_contexts=3):
        """Retrieve the most relevant contexts for a given instruction."""
+
        logging.info("RETRIEVE DOCUMENTS")
-
-
-
-
-
-
-
+        documents_retrieved = self.vectore_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)
+        logging.info(f"Documents retrieved: {len(documents_retrieved)}")
+
+        if self.rerank_model:
+            logging.info("RERANK DOCUMENTS")
+            documents_reranked = self.rerank_contexts(instruction, documents_retrieved, number_of_contexts=number_of_contexts)
+        else:
+            logging.info("NO RERANKING")
+            documents_reranked = documents_retrieved[:number_of_contexts]
+
+        return documents_reranked
 
+
    def predict_dolly(self, instruction, context, model_parameters):
 
        api_key = os.getenv("HF_TOKEN")
@@ -155,26 +171,35 @@ class RAG:
    def get_response(self, prompt: str, model_parameters: dict) -> str:
        try:
            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
-            text_context, full_context, source = self.beautiful_context(docs)
-            print("#"*100)
-            logging.info("text_context")
-            logging.info(text_context)
-
-            print("#"*100)
-            logging.info("full context")
-            logging.info(full_context)
-
-            print("#"*100)
-            logging.info("source")
-            logging.info(source)
-
-            del model_parameters["NUM_CHUNKS"]
-
-            response = self.predict_completion(prompt, text_context, model_parameters)
 
+            response = ""
+
+            for i, (doc, score) in enumerate(docs):
+
+                response += "\n\n" + "="*100
+                response += f"\nDocument {i+1}"
+                response += "\n" + "="*100
+                response += f"\nScore: {score:.5f}"
+                response += f"\nTitle: {doc.metadata['title']}"
+                response += f"\nURL: {doc.metadata['url']}"
+                response += f"\nID: {doc.metadata['id']}"
+                response += f"\nStart index: {doc.metadata['start_index']}"
+                # response += f"\nSource: {doc.metadata['src']}"
+                # response += f"\nRedirected: {doc.metadata['redirected']}"
+                # url = doc.metadata['url']
+                # response += f"\nRevision ID: {url}"
+                # response += f'\nURL: <a href="{url}" target="_blank">{url}</a><br>'
+                response += "\n" + "-"*100 + "\n"
+                response += f"\nContent:\n"
+                response += doc.page_content
+
+            full_context = ""
+            source = []
+
            if not response:
                return self.NO_ANSWER_MESSAGE
 
            return response, full_context, source
+
        except Exception as err:
            print(err)