acumplid committed
Commit e27df4e
1 Parent(s): af12306

modified ui and rag

Files changed (2)
  1. app.py +72 -34
  2. rag.py +58 -33
app.py CHANGED
@@ -8,10 +8,17 @@ from rag import RAG
 from utils import setup
 
 MAX_NEW_TOKENS = 700
-SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
+SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="False") == "True"
+import logging
+
+logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)s] - %(message)s')
 
 setup()
 
+print("Loading RAG model...")
+print("Show model parameters in UI: ", SHOW_MODEL_PARAMETERS_IN_UI)
+
+# Load the RAG model
 rag = RAG(
     vs_hf_repo_path=os.getenv("VS_REPO_NAME"),
     vectorstore_path=os.getenv("VECTORSTORE_PATH"),
@@ -40,6 +47,9 @@ def generate(prompt, model_parameters):
 
 
 def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
+    """
+    Function to handle the input and call the RAG model for inference.
+    """
     if input_.strip() == "":
         gr.Warning("Not possible to inference an empty input")
         return None
@@ -89,41 +99,53 @@ def clear():
 
 def gradio_app():
     with gr.Blocks(theme=theme) as demo:
+        # App Description
+        # =====================================================================================================================================
        with gr.Row():
            with gr.Column():
+
                gr.Markdown(
-                    """# Demo de Retrieval-Augmented Generation per la Viquipèdia
-                    🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
-                    en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
-                    fent servir només la informació existent en els documents del repositori.
+                    # """# Demo de Retrieval-Augmented Generation per la Viquipèdia
+                    # 🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
+                    # en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
+                    # fent servir només la informació existent en els documents del repositori.
 
-                    🎯 **Objectiu:** Aquest és un demostrador amb Viquipèdia i genera la resposta fent servir el model salamandra-7b-instruct.
+                    # 🎯 **Objectiu:** Aquest és un demostrador amb Viquipèdia i genera la resposta fent servir el model salamandra-7b-instruct.
 
-                    ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
-                    Si us plau, tingueu-ho en compte quan exploreu aquest recurs. El model en inferencia asociat a aquesta demo de desenvolupament no funciona continuament. Si vol fer proves,
-                    contacteu amb nosaltres a Langtech.
-                    """
+                    # ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
+                    # Si us plau, tingueu-ho en compte quan exploreu aquest recurs. El model en inferencia asociat a aquesta demo de desenvolupament no funciona continuament. Si vol fer proves,
+                    # contacteu amb nosaltres a Langtech.
+                    # """
                )
-        with gr.Row(equal_height=True):
-            with gr.Column(variant="panel"):
+
+
+        # with gr.Row(equal_height=True):
+        with gr.Row(equal_height=False):
+            # User Input
+            # =====================================================================================================================================
+            with gr.Column(scale=2, variant="panel"):
+
                input_ = Textbox(
-                    lines=11,
+                    lines=5,
                    label="Input",
                    placeholder="Qui va crear la guerra de les Galaxies ?",
                )
-                with gr.Row(variant="panel"):
-                    clear_btn = Button(
-                        "Clear",
-                    )
+
+
+                # with gr.Column(variant="panel"):
+                with gr.Row(variant="default"):
+                    # with gr.Row(variant="panel"):
+                    clear_btn = Button("Clear",)
                    submit_btn = Button("Submit", variant="primary", interactive=False)
 
-                with gr.Row(variant="panel"):
-                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
+                # with gr.Row(variant="panel"):
+                with gr.Row(variant="default"):
+                    with gr.Accordion("Model parameters (not used)", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        num_chunks = Slider(
                            minimum=1,
                            maximum=6,
                            step=1,
-                            value=2,
+                            value=5,
                            label="Number of chunks"
                        )
                        max_new_tokens = Slider(
@@ -166,14 +188,29 @@ def gradio_app():
 
                        parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
 
-            with gr.Column(variant="panel"):
+                # Add Examples manually
+                gr.Examples(
+                    examples=[
+                        ["Qui va crear la guerra de les Galaxies?"],
+                        ["Quin era el nom real de Voltaire?"],
+                        ["Què fan al BSC?"]
+                    ],
+                    inputs=[input_],  # only inputs
+                )
+
+            # Output
+            # =====================================================================================================================================
+            with gr.Column(scale=10, variant="panel"):
+
                output = Textbox(
                    lines=10,
+                    max_lines=25,
                    label="Output",
                    interactive=False,
                    show_copy_button=True
                )
-                with gr.Accordion("Sources and context:", open=False):
+
+                with gr.Accordion("Sources and context:", open=False, visible=False):
                    source_context = gr.Markdown(
                        label="Sources",
                        show_label=False,
@@ -186,8 +223,9 @@ def gradio_app():
                        # autoscroll=False,
                        # show_copy_button=True
                    )
-
 
+        # Event Handlers
+        # =====================================================================================================================================
        input_.change(
            fn=change_interactive,
            inputs=[input_],
@@ -219,20 +257,20 @@ def gradio_app():
            outputs=[output, source_context, context_evaluation],
            api_name="get-results"
        )
+        # =====================================================================================================================================
 
-        with gr.Row():
-            with gr.Column(scale=0.5):
-                gr.Examples(
-                    examples=[
-                        ["""Qui va crear la guerra de les Galaxies ?"""],
-                    ],
-                    inputs=input_,
-                    outputs=[output, source_context, context_evaluation],
-                    fn=submit_input,
-                )
+        # # Output
+        # with gr.Row():
+        #     with gr.Column(scale=0.5):
+        #         gr.Examples(
+        #             examples=[["""Qui va crear la guerra de les Galaxies ?"""],],
+        #             inputs=input_,
+        #             outputs=[output, source_context, context_evaluation],
+        #             fn=submit_input,
+        #         )
 
    demo.launch(show_api=True)
 
-
+
if __name__ == "__main__":
    gradio_app()
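
Note: with this change the model-parameters accordion is hidden by default, because SHOW_MODEL_PARAMETERS_IN_UI now defaults to "False" and is enabled only when the environment variable is the literal string "True". A minimal standalone sketch of that parsing rule follows; the parsing line is copied from the diff, while the shell invocation at the end is an illustrative assumption, not part of the repo:

import os

# Only the exact string "True" enables the accordion; anything else, including an unset variable, leaves it hidden.
SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="False") == "True"
print(SHOW_MODEL_PARAMETERS_IN_UI)

# Hypothetical usage: SHOW_MODEL_PARAMETERS_IN_UI=True python app.py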
rag.py CHANGED
@@ -10,6 +10,10 @@ from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
+logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)s] - %(message)s')
+# logging.getLogger().setLevel(logging.INFO)
+
+
 class RAG:
     NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."
 
@@ -26,11 +30,15 @@
        self.rerank_number_contexts = rerank_number_contexts
 
        # load vectore store
-        hf_vectorstore = snapshot_download(repo_id=vs_hf_repo_path)
-
        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
-        self.vectore_store = FAISS.load_local(hf_vectorstore, embeddings, allow_dangerous_deserialization=True)
-        # self.vectore_store = FAISS.load_local(self.vectorstore_path, embeddings, allow_dangerous_deserialization=True)#, allow_dangerous_deserialization=True)
+
+        if vs_hf_repo_path:
+            hf_vectorstore = snapshot_download(repo_id=vs_hf_repo_path)
+            self.vectore_store = FAISS.load_local(hf_vectorstore, embeddings, allow_dangerous_deserialization=True)
+        else:
+            self.vectore_store = FAISS.load_local(self.vectorstore_path, embeddings, allow_dangerous_deserialization=True)
+
+
        logging.info("RAG loaded!")
        logging.info( self.vectore_store)
 
@@ -44,44 +52,52 @@
 
        tokenizer = AutoTokenizer.from_pretrained(rerank_model)
        model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
+        logging.info("Rerank model loaded!")
 
        def get_score(query, passage):
            """Calculate the relevance score of a passage with respect to a query."""
 
-
            inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
-
+            print("Inputs: ", inputs)
 
            with torch.no_grad():
                outputs = model(**inputs)
 
-
            logits = outputs.logits
-
-
            score = logits.view(-1, ).float()
 
+            print("Score: ", score)
 
            return score
 
        scores = [get_score(instruction, c[0].page_content) for c in contexts]
+
+        print("Scores: ", scores)
+
        combined = list(zip(contexts, scores))
        sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
        sorted_texts, _ = zip(*sorted_combined)
 
        return sorted_texts[:number_of_contexts]
 
-    def get_context(self, instruction, number_of_contexts=2):
+
+    def get_context(self, instruction, number_of_contexts=3):
        """Retrieve the most relevant contexts for a given instruction."""
+
        logging.info("RETRIEVE DOCUMENTS")
-        documentos = self.vectore_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)
-        # logging.info(documentos)
-        logging.info("RERANK DOCUMENTS")
-        documentos = self.rerank_contexts(instruction, documentos, number_of_contexts=number_of_contexts)
-        # logging.info(documentos)
-        print("Reranked documents")
-        return documentos
+        documents_retrieved = self.vectore_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)
+        logging.info(f"Documents retrieved: {len(documents_retrieved)}")
+
+        if self.rerank_model:
+            logging.info("RERANK DOCUMENTS")
+            documents_reranked = self.rerank_contexts(instruction, documents_retrieved, number_of_contexts=number_of_contexts)
+        else:
+            logging.info("NO RERANKING")
+            documents_reranked = documents_retrieved[:number_of_contexts]
+
+        return documents_reranked
 
+
    def predict_dolly(self, instruction, context, model_parameters):
 
        api_key = os.getenv("HF_TOKEN")
@@ -155,26 +171,35 @@
    def get_response(self, prompt: str, model_parameters: dict) -> str:
        try:
            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
-            text_context, full_context, source = self.beautiful_context(docs)
-            print("#"*100)
-            logging.info("text_context")
-            logging.info(text_context)
-
-            print("#"*100)
-            logging.info("full context")
-            logging.info(full_context)
-
-            print("#"*100)
-            logging.info("source")
-            logging.info(source)
-
-            del model_parameters["NUM_CHUNKS"]
-
-            response = self.predict_completion(prompt, text_context, model_parameters)
 
+            response = ""
+
+            for i, (doc, score) in enumerate(docs):
+
+                response += "\n\n" + "="*100
+                response += f"\nDocument {i+1}"
+                response += "\n" + "="*100
+                response += f"\nScore: {score:.5f}"
+                response += f"\nTitle: {doc.metadata['title']}"
+                response += f"\nURL: {doc.metadata['url']}"
+                response += f"\nID: {doc.metadata['id']}"
+                response += f"\nStart index: {doc.metadata['start_index']}"
+                # response += f"\nSource: {doc.metadata['src']}"
+                # response += f"\nRedirected: {doc.metadata['redirected']}"
+                # url = doc.metadata['url']
+                # response += f"\nRevision ID: {url}"
+                # response += f'\nURL: <a href="{url}" target="_blank">{url}</a><br>'
+                response += "\n" + "-"*100 + "\n"
+                response += f"\nContent:\n"
+                response += doc.page_content
+
+            full_context = ""
+            source = []
+
            if not response:
                return self.NO_ANSWER_MESSAGE
 
            return response, full_context, source
+
        except Exception as err:
            print(err)
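
For reference, rerank_contexts scores each (query, passage) pair with the sequence-classification rerank model and keeps the highest-scoring chunks, and get_context now only takes that path when a rerank model is configured, otherwise falling back to the raw similarity order. Below is a minimal standalone sketch of the scoring step, mirroring the committed get_score helper; the checkpoint name and the example strings are placeholders, not values from this Space:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

rerank_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # placeholder cross-encoder, not necessarily the one this Space uses
tokenizer = AutoTokenizer.from_pretrained(rerank_model)
model = AutoModelForSequenceClassification.from_pretrained(rerank_model)

def get_score(query, passage):
    # Tokenize the query/passage pair and read the relevance logit, as in RAG.rerank_contexts.get_score.
    inputs = tokenizer(query, passage, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits.view(-1).float()

# Hypothetical usage: a higher score means the passage is judged more relevant to the query.
print(get_score("Qui va crear la guerra de les Galaxies?", "George Lucas va crear La Guerra de les Galàxies."))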