Update ragagent.py
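This change replaces the class-level chains, logs buffer, and reset_agent() hook with per-instance attributes built in __init__ and exposed through getter methods, and reroutes the graph nodes' print() tracing into an add_log() buffer that can be read back with get_logs().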
ragagent.py  CHANGED  (+44 -37)
@@ -78,19 +78,7 @@ class RAGAgent():
     Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
         input_variables=["generation", "question"],
     )
-
-    logs = ""
-
-    print("RAGAgent()")
-
-    def reset_agent():
-        RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
-        RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
-        RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
-        RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
-        RAGAgent.logs = ""
-        print("reset_agent")
-
+
     def __init__(self, docs):
         print("init")
         docs_list = [item for sublist in docs for item in sublist]
@@ -99,18 +87,37 @@ class RAGAgent():
             chunk_size=1000, chunk_overlap=200
         )
         doc_splits = text_splitter.split_documents(docs_list)
-
         embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-
         # Add to vectorDB
         vectorstore = Chroma.from_documents(
             documents=doc_splits,
             collection_name="rag-chroma",
             embedding=embedding_function,
         )
-
-        RAGAgent.retriever = vectorstore.as_retriever()
-
+        self._retriever = vectorstore.as_retriever()
+        self._retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
+        self._rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
+        self._hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
+        self._answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
+        self._logs=""
+
+    def get_retriever(self):
+        return self._retriever
+    def get_retrieval_grader(self):
+        return self._retrieval_grader
+    def get_rag_chain(self):
+        return self._rag_chain
+    def get_hallucination_grader(self):
+        return self._hallucination_grader
+    def get_answer_grader(self):
+        return self._answer_grader
+
+
+    def get_logs(self):
+        return self._logs
+    def add_log(log):
+        self._logs += log + "\n"
+
     web_search_tool = TavilySearchResults(k=3)
 
     class GraphState(TypedDict):
@@ -120,26 +127,26 @@ class RAGAgent():
         documents: List[str]
 
     def retrieve(state):
-        print("---RETRIEVE---")
+        add_log("---RETRIEVE---")
         question = state["question"]
 
         # Retrieval
-        documents = RAGAgent.retriever.invoke(question)
+        documents = get_retriever.invoke(question)
         return {"documents": documents, "question": question}
 
     def generate(state):
-        print("---GENERATE---")
+        add_log("---GENERATE---")
         question = state["question"]
         documents = state["documents"]
 
         # RAG generation
-        generation = RAGAgent.rag_chain.invoke({"context": documents, "question": question})
+        generation = get_rag_chain.invoke({"context": documents, "question": question})
         return {"documents": documents, "question": question, "generation": generation}
 
 
     def grade_documents(state):
 
-        print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
+        add_log("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
         question = state["question"]
         documents = state["documents"]
 
@@ -148,25 +155,25 @@ class RAGAgent():
         web_search = "Yes"
 
         for d in documents:
-            score = RAGAgent.retrieval_grader.invoke(
+            score = get_retrieval_grader.invoke(
                 {"question": question, "document": d.page_content}
             )
            grade = score["score"]
             # Document relevant
             if grade.lower() == "yes":
-                print("---GRADE: DOCUMENT RELEVANT---")
+                add_log("---GRADE: DOCUMENT RELEVANT---")
                 filtered_docs.append(d)
                 web_search = "No"
             # Document not relevant
             else:
-                print("---GRADE: DOCUMENT NOT RELEVANT---")
+                add_log("---GRADE: DOCUMENT NOT RELEVANT---")
 
         return {"documents": filtered_docs, "question": question, "web_search": web_search}
 
 
     def web_search(state):
 
-        print("---WEB SEARCH---")
+        add_log("---WEB SEARCH---")
         question = state["question"]
         documents = state["documents"]
 
@@ -183,7 +190,7 @@ class RAGAgent():
 
     def decide_to_generate(state):
 
-        print("---ASSESS GRADED DOCUMENTS---")
+        add_log("---ASSESS GRADED DOCUMENTS---")
         question = state["question"]
         web_search = state["web_search"]
         filtered_documents = state["documents"]
@@ -191,40 +198,40 @@ class RAGAgent():
         if web_search == "Yes":
             # All documents have been filtered check_relevance
             # We will re-generate a new query
-            print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
+            add_log("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
             return "websearch"
         else:
             # We have relevant documents, so generate answer
-            print("---DECISION: GENERATE---")
+            add_log("---DECISION: GENERATE---")
             return "generate"
 
     def grade_generation_v_documents_and_question(state):
 
-        print("---CHECK HALLUCINATIONS---")
+        add_log("---CHECK HALLUCINATIONS---")
         question = state["question"]
         documents = state["documents"]
         generation = state["generation"]
 
-        score = RAGAgent.hallucination_grader.invoke(
+        score = get_hallucination_grader.invoke(
             {"documents": documents, "generation": generation}
         )
         grade = score["score"]
 
         # Check hallucination
         if grade == "yes":
-            print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+            add_log("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
             # Check question-answering
             print("---GRADE GENERATION vs QUESTION---")
-            score = RAGAgent.answer_grader.invoke({"question": question, "generation": generation})
+            score = get_answer_grader.invoke({"question": question, "generation": generation})
             grade = score["score"]
             if grade == "yes":
-                print("---DECISION: GENERATION ADDRESSES QUESTION---")
+                add_log("---DECISION: GENERATION ADDRESSES QUESTION---")
                 return "useful"
             else:
-                print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+                add_log("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
                 return "not useful"
         else:
-            print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+            add_log("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
             return "not supported"
 
     workflow = StateGraph(GraphState)
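A caveat on the new code as committed: add_log is defined without a self parameter, and the node functions refer to add_log(...) and to get_retriever.invoke(...), get_rag_chain.invoke(...), etc. as bare names, which reference the method objects without binding or calling them and would likely fail with a NameError or TypeError once the graph actually runs. Below is a minimal sketch of the pattern the commit appears to be aiming for; the retriever and grader passed to __init__ are hypothetical stand-ins, since the real ones are built from Chroma, the prompts, and CustomLlama3, which are not reproduced here.

# Sketch only: instance attributes + getters + a string log buffer.
# `retriever` and `retrieval_grader` stand in for vectorstore.as_retriever()
# and the prompt | CustomLlama3 | parser chains from the real ragagent.py.

class RAGAgent:
    def __init__(self, retriever, retrieval_grader):
        self._retriever = retriever
        self._retrieval_grader = retrieval_grader
        self._logs = ""

    def get_retriever(self):
        return self._retriever

    def get_logs(self):
        return self._logs

    def add_log(self, log):  # `self` is required, unlike the committed version
        self._logs += log + "\n"

    def retrieve(self, state):
        # Call the getter, then invoke the retriever it returns; the committed
        # `get_retriever.invoke(question)` never calls the getter itself.
        self.add_log("---RETRIEVE---")
        question = state["question"]
        documents = self.get_retriever().invoke(question)
        return {"documents": documents, "question": question}

With bound methods like these, the nodes can still be registered on the graph (e.g. workflow.add_node("retrieve", agent.retrieve)), and callers can read the trace back with agent.get_logs() instead of scraping stdout.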