Update ragagent.py
ragagent.py  +29 -38

ragagent.py  CHANGED
@@ -78,6 +78,12 @@ class RAGAgent():
     Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
         input_variables=["generation", "question"],
     )
+
+    def reset_chains():
+        RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
+        RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
+        RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
+        RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
 
     def __init__(self, docs):
         print("init")
@@ -94,29 +100,14 @@ class RAGAgent():
             collection_name="rag-chroma",
             embedding=embedding_function,
         )
-        self._retriever = vectorstore.as_retriever()
-        self._retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
-        self._rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
-        self._hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
-        self._answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
-        self._logs=""
-
-    def get_retriever(self):
-        return self._retriever
-    def get_retrieval_grader(self):
-        return self._retrieval_grader
-    def get_rag_chain(self):
-        return self._rag_chain
-    def get_hallucination_grader(self):
-        return self._hallucination_grader
-    def get_answer_grader(self):
-        return self._answer_grader
-
+        RAGAgent.retriever = vectorstore.as_retriever()
+        RAGAgent.reset_chains()
+        RAGAgent.logs=""
 
     def get_logs(self):
         return self._logs
     def add_log(self, log):
-        self._logs += log + "\n"
+        RAGAgent.logs += log + "\n"
 
     web_search_tool = TavilySearchResults(k=3)
 
@@ -127,26 +118,26 @@ class RAGAgent():
         documents: List[str]
 
     def retrieve(self, state):
-
+        RAGAgent.add_log("---RETRIEVE---")
         question = state["question"]
 
         # Retrieval
-        documents = 
+        documents = RAGAgent.retriever.invoke(question)
         return {"documents": documents, "question": question}
 
     def generate(state):
-        add_log("---GENERATE---")
+        RAGAgent.add_log("---GENERATE---")
         question = state["question"]
         documents = state["documents"]
 
         # RAG generation
-        generation = 
+        generation = RAGAgent.rag_chain.invoke({"context": documents, "question": question})
         return {"documents": documents, "question": question, "generation": generation}
 
 
     def grade_documents(state):
 
-        add_log("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
+        RAGAgent.add_log("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
         question = state["question"]
         documents = state["documents"]
 
@@ -155,25 +146,25 @@ class RAGAgent():
         web_search = "Yes"
 
         for d in documents:
-            score = 
+            score = RAGAgent.retrieval_grader.invoke(
                 {"question": question, "document": d.page_content}
             )
             grade = score["score"]
             # Document relevant
             if grade.lower() == "yes":
-                add_log("---GRADE: DOCUMENT RELEVANT---")
+                RAGAgent.add_log("---GRADE: DOCUMENT RELEVANT---")
                 filtered_docs.append(d)
                 web_search = "No"
             # Document not relevant
             else:
-                add_log("---GRADE: DOCUMENT NOT RELEVANT---")
+                RAGAgent.add_log("---GRADE: DOCUMENT NOT RELEVANT---")
 
         return {"documents": filtered_docs, "question": question, "web_search": web_search}
 
 
     def web_search(state):
 
-        add_log("---WEB SEARCH---")
+        RAGAgent.add_log("---WEB SEARCH---")
         question = state["question"]
         documents = state["documents"]
 
@@ -190,7 +181,7 @@ class RAGAgent():
 
     def decide_to_generate(state):
 
-        add_log("---ASSESS GRADED DOCUMENTS---")
+        RAGAgent.add_log("---ASSESS GRADED DOCUMENTS---")
         question = state["question"]
         web_search = state["web_search"]
         filtered_documents = state["documents"]
@@ -198,40 +189,40 @@ class RAGAgent():
         if web_search == "Yes":
             # All documents have been filtered check_relevance
             # We will re-generate a new query
-            add_log("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
+            RAGAgent.add_log("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
             return "websearch"
         else:
             # We have relevant documents, so generate answer
-            add_log("---DECISION: GENERATE---")
+            RAGAgent.add_log("---DECISION: GENERATE---")
            return "generate"
 
     def grade_generation_v_documents_and_question(state):
 
-        add_log("---CHECK HALLUCINATIONS---")
+        RAGAgent.add_log("---CHECK HALLUCINATIONS---")
         question = state["question"]
         documents = state["documents"]
         generation = state["generation"]
 
-        score = 
+        score = RAGAgent.hallucination_grader.invoke(
            {"documents": documents, "generation": generation}
        )
        grade = score["score"]
 
        # Check hallucination
        if grade == "yes":
-            add_log("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+            RAGAgent.add_log("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
            # Check question-answering
            print("---GRADE GENERATION vs QUESTION---")
-            score = 
+            score = RAGAgent.answer_grader.invoke({"question": question, "generation": generation})
            grade = score["score"]
            if grade == "yes":
-                add_log("---DECISION: GENERATION ADDRESSES QUESTION---")
+                RAGAgent.add_log("---DECISION: GENERATION ADDRESSES QUESTION---")
                return "useful"
            else:
-                add_log("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+                RAGAgent.add_log("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
                return "not useful"
        else:
-            add_log("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+            RAGAgent.add_log("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
            return "not supported"
 
    workflow = StateGraph(GraphState)
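
In short, the commit drops the per-instance attributes and getters (self._retriever, get_retriever(), self._logs, ...) in favour of class-level attributes (RAGAgent.retriever, RAGAgent.logs) plus a reset_chains() helper that rebuilds the LLM chains, so the LangGraph node functions can reach everything through the class itself. Below is only an illustrative, runnable sketch of that class-attribute pattern; ToyRAGAgent, its stand-in retriever and chain are hypothetical and not part of the Space's code.

# Hypothetical sketch of the class-attribute pattern adopted by this commit.
# Plain callables stand in for the real Chroma retriever and the
# prompt | CustomLlama3 | parser chains; only the structure is the point.

class ToyRAGAgent:
    # class-level slots shared by every node function
    # (mirrors RAGAgent.retriever, RAGAgent.rag_chain, RAGAgent.logs)
    retriever = None
    rag_chain = None
    logs = ""

    def reset_chains():
        # the real reset_chains() rebuilds the LangChain runnables;
        # here a lambda stands in for prompt | llm | parser
        ToyRAGAgent.rag_chain = (
            lambda question, docs: f"answer to {question!r} from {len(docs)} doc(s)"
        )

    def __init__(self, docs):
        # a closure over an in-memory list stands in for vectorstore.as_retriever()
        ToyRAGAgent.retriever = lambda question: list(docs)
        ToyRAGAgent.reset_chains()
        ToyRAGAgent.logs = ""

    def add_log(self, log):
        ToyRAGAgent.logs += log + "\n"

    def retrieve(self, state):
        self.add_log("---RETRIEVE---")
        question = state["question"]
        documents = ToyRAGAgent.retriever(question)
        return {"documents": documents, "question": question}

    def generate(self, state):
        self.add_log("---GENERATE---")
        generation = ToyRAGAgent.rag_chain(state["question"], state["documents"])
        return {**state, "generation": generation}


if __name__ == "__main__":
    agent = ToyRAGAgent(["doc one", "doc two"])
    state = agent.retrieve({"question": "what changed?"})
    state = agent.generate(state)
    print(state["generation"])
    print(ToyRAGAgent.logs)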
|