fehmikaya committed on
Commit
1ec2e26
·
verified ·
1 Parent(s): aaf2b1c

Update ragagent.py

Browse files
Files changed (1) hide show
  1. ragagent.py +29 -38
ragagent.py CHANGED
@@ -78,6 +78,12 @@ class RAGAgent():
78
  Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
79
  input_variables=["generation", "question"],
80
  )
 
 
 
 
 
 
81
 
82
  def __init__(self, docs):
83
  print("init")
@@ -94,29 +100,14 @@ class RAGAgent():
94
  collection_name="rag-chroma",
95
  embedding=embedding_function,
96
  )
97
- self._retriever = vectorstore.as_retriever()
98
- self._retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
99
- self._rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
100
- self._hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
101
- self._answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
102
- self._logs=""
103
-
104
- def get_retriever(self):
105
- return self._retriever
106
- def get_retrieval_grader(self):
107
- return self._retrieval_grader
108
- def get_rag_chain(self):
109
- return self._rag_chain
110
- def get_hallucination_grader(self):
111
- return self._hallucination_grader
112
- def get_answer_grader(self):
113
- return self._answer_grader
114
-
115
 
116
  def get_logs(self):
117
  return self._logs
118
  def add_log(self, log):
119
- self._logs += log + "\n"
120
 
121
  web_search_tool = TavilySearchResults(k=3)
122
 
@@ -127,26 +118,26 @@ class RAGAgent():
127
  documents: List[str]
128
 
129
  def retrieve(self, state):
130
- self.add_log("---RETRIEVE---")
131
  question = state["question"]
132
 
133
  # Retrieval
134
- documents = self.get_retriever.invoke(question)
135
  return {"documents": documents, "question": question}
136
 
137
  def generate(state):
138
- add_log("---GENERATE---")
139
  question = state["question"]
140
  documents = state["documents"]
141
 
142
  # RAG generation
143
- generation = get_rag_chain.invoke({"context": documents, "question": question})
144
  return {"documents": documents, "question": question, "generation": generation}
145
 
146
 
147
  def grade_documents(state):
148
 
149
- add_log("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
150
  question = state["question"]
151
  documents = state["documents"]
152
 
@@ -155,25 +146,25 @@ class RAGAgent():
155
  web_search = "Yes"
156
 
157
  for d in documents:
158
- score = get_retrieval_grader.invoke(
159
  {"question": question, "document": d.page_content}
160
  )
161
  grade = score["score"]
162
  # Document relevant
163
  if grade.lower() == "yes":
164
- add_log("---GRADE: DOCUMENT RELEVANT---")
165
  filtered_docs.append(d)
166
  web_search = "No"
167
  # Document not relevant
168
  else:
169
- add_log("---GRADE: DOCUMENT NOT RELEVANT---")
170
 
171
  return {"documents": filtered_docs, "question": question, "web_search": web_search}
172
 
173
 
174
  def web_search(state):
175
 
176
- add_log("---WEB SEARCH---")
177
  question = state["question"]
178
  documents = state["documents"]
179
 
@@ -190,7 +181,7 @@ class RAGAgent():
190
 
191
  def decide_to_generate(state):
192
 
193
- add_log("---ASSESS GRADED DOCUMENTS---")
194
  question = state["question"]
195
  web_search = state["web_search"]
196
  filtered_documents = state["documents"]
@@ -198,40 +189,40 @@ class RAGAgent():
198
  if web_search == "Yes":
199
  # All documents have been filtered check_relevance
200
  # We will re-generate a new query
201
- add_log("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
202
  return "websearch"
203
  else:
204
  # We have relevant documents, so generate answer
205
- add_log("---DECISION: GENERATE---")
206
  return "generate"
207
 
208
  def grade_generation_v_documents_and_question(state):
209
 
210
- add_log("---CHECK HALLUCINATIONS---")
211
  question = state["question"]
212
  documents = state["documents"]
213
  generation = state["generation"]
214
 
215
- score = get_hallucination_grader.invoke(
216
  {"documents": documents, "generation": generation}
217
  )
218
  grade = score["score"]
219
 
220
  # Check hallucination
221
  if grade == "yes":
222
- add_log("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
223
  # Check question-answering
224
  print("---GRADE GENERATION vs QUESTION---")
225
- score = get_answer_grader.invoke({"question": question, "generation": generation})
226
  grade = score["score"]
227
  if grade == "yes":
228
- add_log("---DECISION: GENERATION ADDRESSES QUESTION---")
229
  return "useful"
230
  else:
231
- add_log("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
232
  return "not useful"
233
  else:
234
- add_log("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
235
  return "not supported"
236
 
237
  workflow = StateGraph(GraphState)
 
78
  Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
79
  input_variables=["generation", "question"],
80
  )
81
+
82
+ def reset_chains():
83
+ RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
84
+ RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
85
+ RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
86
+ RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
87
 
88
  def __init__(self, docs):
89
  print("init")
 
100
  collection_name="rag-chroma",
101
  embedding=embedding_function,
102
  )
103
+ RAGAgent.retriever = vectorstore.as_retriever()
104
+ RAGAgent.reset_chains()
105
+ RAGAgent.logs=""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  def get_logs(self):
108
  return self._logs
109
  def add_log(self, log):
110
+ RAGAgent.logs += log + "\n"
111
 
112
  # Class-level web-search tool shared by all instances; k=3 presumably caps
  # results at the top 3 hits — TODO confirm against the TavilySearchResults API.
  web_search_tool = TavilySearchResults(k=3)
113
 
 
118
  documents: List[str]
119
 
120
  def retrieve(self, state):
121
+ RAGAgent.add_log("---RETRIEVE---")
122
  question = state["question"]
123
 
124
  # Retrieval
125
+ documents = RAGAgent.retriever.invoke(question)
126
  return {"documents": documents, "question": question}
127
 
128
  def generate(state):
129
+ RAGAgent.add_log("---GENERATE---")
130
  question = state["question"]
131
  documents = state["documents"]
132
 
133
  # RAG generation
134
+ generation = RAGAgent.rag_chain.invoke({"context": documents, "question": question})
135
  return {"documents": documents, "question": question, "generation": generation}
136
 
137
 
138
  def grade_documents(state):
139
 
140
+ RAGAgent.add_log("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
141
  question = state["question"]
142
  documents = state["documents"]
143
 
 
146
  web_search = "Yes"
147
 
148
  for d in documents:
149
+ score = RAGAgent.retrieval_grader.invoke(
150
  {"question": question, "document": d.page_content}
151
  )
152
  grade = score["score"]
153
  # Document relevant
154
  if grade.lower() == "yes":
155
+ RAGAgent.add_log("---GRADE: DOCUMENT RELEVANT---")
156
  filtered_docs.append(d)
157
  web_search = "No"
158
  # Document not relevant
159
  else:
160
+ RAGAgent.add_log("---GRADE: DOCUMENT NOT RELEVANT---")
161
 
162
  return {"documents": filtered_docs, "question": question, "web_search": web_search}
163
 
164
 
165
  def web_search(state):
166
 
167
+ RAGAgent.add_log("---WEB SEARCH---")
168
  question = state["question"]
169
  documents = state["documents"]
170
 
 
181
 
182
  def decide_to_generate(state):
183
 
184
+ RAGAgent.add_log("---ASSESS GRADED DOCUMENTS---")
185
  question = state["question"]
186
  web_search = state["web_search"]
187
  filtered_documents = state["documents"]
 
189
  if web_search == "Yes":
190
  # All documents have been filtered check_relevance
191
  # We will re-generate a new query
192
+ RAGAgent.add_log("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
193
  return "websearch"
194
  else:
195
  # We have relevant documents, so generate answer
196
+ RAGAgent.add_log("---DECISION: GENERATE---")
197
  return "generate"
198
 
199
  def grade_generation_v_documents_and_question(state):
200
 
201
+ RAGAgent.add_log("---CHECK HALLUCINATIONS---")
202
  question = state["question"]
203
  documents = state["documents"]
204
  generation = state["generation"]
205
 
206
+ score = RAGAgent.hallucination_grader.invoke(
207
  {"documents": documents, "generation": generation}
208
  )
209
  grade = score["score"]
210
 
211
  # Check hallucination
212
  if grade == "yes":
213
+ RAGAgent.add_log("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
214
  # Check question-answering
215
  print("---GRADE GENERATION vs QUESTION---")
216
+ score = RAGAgent.answer_grader.invoke({"question": question, "generation": generation})
217
  grade = score["score"]
218
  if grade == "yes":
219
+ RAGAgent.add_log("---DECISION: GENERATION ADDRESSES QUESTION---")
220
  return "useful"
221
  else:
222
+ RAGAgent.add_log("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
223
  return "not useful"
224
  else:
225
+ RAGAgent.add_log("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
226
  return "not supported"
227
 
228
  # Build the LangGraph state machine over GraphState; nodes and edges are
  # presumably added after this point (outside this view) — confirm wiring.
  workflow = StateGraph(GraphState)