fehmikaya committed on
Commit
b7921f3
·
verified ·
1 Parent(s): 35a6421

Update ragagent.py

Browse files
Files changed (1) hide show
  1. ragagent.py +44 -37
ragagent.py CHANGED
@@ -78,19 +78,7 @@ class RAGAgent():
78
  Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
79
  input_variables=["generation", "question"],
80
  )
81
-
82
- logs = ""
83
-
84
- print("RAGAgent()")
85
-
86
- def reset_agent():
87
- RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
88
- RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
89
- RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
90
- RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
91
- RAGAgent.logs = ""
92
- print("reset_agent")
93
-
94
  def __init__(self, docs):
95
  print("init")
96
  docs_list = [item for sublist in docs for item in sublist]
@@ -99,18 +87,37 @@ class RAGAgent():
99
  chunk_size=1000, chunk_overlap=200
100
  )
101
  doc_splits = text_splitter.split_documents(docs_list)
102
-
103
  embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
104
-
105
  # Add to vectorDB
106
  vectorstore = Chroma.from_documents(
107
  documents=doc_splits,
108
  collection_name="rag-chroma",
109
  embedding=embedding_function,
110
  )
111
- RAGAgent.retriever = vectorstore.as_retriever()
112
- RAGAgent.reset_agent()
113
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  web_search_tool = TavilySearchResults(k=3)
115
 
116
  class GraphState(TypedDict):
@@ -120,26 +127,26 @@ class RAGAgent():
120
  documents: List[str]
121
 
122
  def retrieve(state):
123
- RAGAgent.logs += "---RETRIEVE---\n"
124
  question = state["question"]
125
 
126
  # Retrieval
127
- documents = RAGAgent.retriever.invoke(question)
128
  return {"documents": documents, "question": question}
129
 
130
  def generate(state):
131
- RAGAgent.logs += "---GENERATE---\n"
132
  question = state["question"]
133
  documents = state["documents"]
134
 
135
  # RAG generation
136
- generation = RAGAgent.rag_chain.invoke({"context": documents, "question": question})
137
  return {"documents": documents, "question": question, "generation": generation}
138
 
139
 
140
  def grade_documents(state):
141
 
142
- RAGAgent.logs += "---CHECK DOCUMENT RELEVANCE TO QUESTION---\n"
143
  question = state["question"]
144
  documents = state["documents"]
145
 
@@ -148,25 +155,25 @@ class RAGAgent():
148
  web_search = "Yes"
149
 
150
  for d in documents:
151
- score = RAGAgent.retrieval_grader.invoke(
152
  {"question": question, "document": d.page_content}
153
  )
154
  grade = score["score"]
155
  # Document relevant
156
  if grade.lower() == "yes":
157
- RAGAgent.logs += "---GRADE: DOCUMENT RELEVANT---\n"
158
  filtered_docs.append(d)
159
  web_search = "No"
160
  # Document not relevant
161
  else:
162
- RAGAgent.logs += "---GRADE: DOCUMENT NOT RELEVANT---\n"
163
 
164
  return {"documents": filtered_docs, "question": question, "web_search": web_search}
165
 
166
 
167
  def web_search(state):
168
 
169
- RAGAgent.logs += "---WEB SEARCH---\n"
170
  question = state["question"]
171
  documents = state["documents"]
172
 
@@ -183,7 +190,7 @@ class RAGAgent():
183
 
184
  def decide_to_generate(state):
185
 
186
- RAGAgent.logs += "---ASSESS GRADED DOCUMENTS---\n"
187
  question = state["question"]
188
  web_search = state["web_search"]
189
  filtered_documents = state["documents"]
@@ -191,40 +198,40 @@ class RAGAgent():
191
  if web_search == "Yes":
192
  # All documents have been filtered check_relevance
193
  # We will re-generate a new query
194
- RAGAgent.logs += "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---\n"
195
  return "websearch"
196
  else:
197
  # We have relevant documents, so generate answer
198
- RAGAgent.logs += "---DECISION: GENERATE---\n"
199
  return "generate"
200
 
201
  def grade_generation_v_documents_and_question(state):
202
 
203
- RAGAgent.logs += "---CHECK HALLUCINATIONS---\n"
204
  question = state["question"]
205
  documents = state["documents"]
206
  generation = state["generation"]
207
 
208
- score = RAGAgent.hallucination_grader.invoke(
209
  {"documents": documents, "generation": generation}
210
  )
211
  grade = score["score"]
212
 
213
  # Check hallucination
214
  if grade == "yes":
215
- RAGAgent.logs += "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\n"
216
  # Check question-answering
217
  print("---GRADE GENERATION vs QUESTION---")
218
- score = RAGAgent.answer_grader.invoke({"question": question, "generation": generation})
219
  grade = score["score"]
220
  if grade == "yes":
221
- RAGAgent.logs += "---DECISION: GENERATION ADDRESSES QUESTION---\n"
222
  return "useful"
223
  else:
224
- RAGAgent.logs += "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---\n"
225
  return "not useful"
226
  else:
227
- RAGAgent.logs += "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---\n"
228
  return "not supported"
229
 
230
  workflow = StateGraph(GraphState)
 
78
  Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
79
  input_variables=["generation", "question"],
80
  )
81
+
 
 
 
 
 
 
 
 
 
 
 
 
82
  def __init__(self, docs):
83
  print("init")
84
  docs_list = [item for sublist in docs for item in sublist]
 
87
  chunk_size=1000, chunk_overlap=200
88
  )
89
  doc_splits = text_splitter.split_documents(docs_list)
 
90
  embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
 
91
  # Add to vectorDB
92
  vectorstore = Chroma.from_documents(
93
  documents=doc_splits,
94
  collection_name="rag-chroma",
95
  embedding=embedding_function,
96
  )
97
+ self._retriever = vectorstore.as_retriever()
98
+ self._retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
99
+ self._rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
100
+ self._hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
101
+ self._answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
102
+ self._logs=""
103
+
104
+ def get_retriever(self):
105
+ return self._retriever
106
+ def get_retrieval_grader(self):
107
+ return self._retrieval_grader
108
+ def get_rag_chain(self):
109
+ return self._rag_chain
110
+ def get_hallucination_grader(self):
111
+ return self._hallucination_grader
112
+ def get_answer_grader(self):
113
+ return self._answer_grader
114
+
115
+
116
+ def get_logs(self):
117
+ return self._logs
118
+ def add_log(log):
119
+ self._logs += log + "\n"
120
+
121
  web_search_tool = TavilySearchResults(k=3)
122
 
123
  class GraphState(TypedDict):
 
127
  documents: List[str]
128
 
129
  def retrieve(state):
130
+ add_log("---RETRIEVE---")
131
  question = state["question"]
132
 
133
  # Retrieval
134
+ documents = get_retriever.invoke(question)
135
  return {"documents": documents, "question": question}
136
 
137
  def generate(state):
138
+ add_log("---GENERATE---")
139
  question = state["question"]
140
  documents = state["documents"]
141
 
142
  # RAG generation
143
+ generation = get_rag_chain.invoke({"context": documents, "question": question})
144
  return {"documents": documents, "question": question, "generation": generation}
145
 
146
 
147
  def grade_documents(state):
148
 
149
+ add_log("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
150
  question = state["question"]
151
  documents = state["documents"]
152
 
 
155
  web_search = "Yes"
156
 
157
  for d in documents:
158
+ score = get_retrieval_grader.invoke(
159
  {"question": question, "document": d.page_content}
160
  )
161
  grade = score["score"]
162
  # Document relevant
163
  if grade.lower() == "yes":
164
+ add_log("---GRADE: DOCUMENT RELEVANT---")
165
  filtered_docs.append(d)
166
  web_search = "No"
167
  # Document not relevant
168
  else:
169
+ add_log("---GRADE: DOCUMENT NOT RELEVANT---")
170
 
171
  return {"documents": filtered_docs, "question": question, "web_search": web_search}
172
 
173
 
174
  def web_search(state):
175
 
176
+ add_log("---WEB SEARCH---")
177
  question = state["question"]
178
  documents = state["documents"]
179
 
 
190
 
191
  def decide_to_generate(state):
192
 
193
+ add_log("---ASSESS GRADED DOCUMENTS---")
194
  question = state["question"]
195
  web_search = state["web_search"]
196
  filtered_documents = state["documents"]
 
198
  if web_search == "Yes":
199
  # All documents have been filtered check_relevance
200
  # We will re-generate a new query
201
+ add_log("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
202
  return "websearch"
203
  else:
204
  # We have relevant documents, so generate answer
205
+ add_log("---DECISION: GENERATE---")
206
  return "generate"
207
 
208
  def grade_generation_v_documents_and_question(state):
209
 
210
+ add_log("---CHECK HALLUCINATIONS---")
211
  question = state["question"]
212
  documents = state["documents"]
213
  generation = state["generation"]
214
 
215
+ score = get_hallucination_grader.invoke(
216
  {"documents": documents, "generation": generation}
217
  )
218
  grade = score["score"]
219
 
220
  # Check hallucination
221
  if grade == "yes":
222
+ add_log("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
223
  # Check question-answering
224
  print("---GRADE GENERATION vs QUESTION---")
225
+ score = get_answer_grader.invoke({"question": question, "generation": generation})
226
  grade = score["score"]
227
  if grade == "yes":
228
+ add_log("---DECISION: GENERATION ADDRESSES QUESTION---")
229
  return "useful"
230
  else:
231
+ add_log("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
232
  return "not useful"
233
  else:
234
+ add_log("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
235
  return "not supported"
236
 
237
  workflow = StateGraph(GraphState)