File size: 11,298 Bytes
7f463bd
 
 
10a2cab
 
 
e2f6806
7e132d2
7f463bd
 
 
 
 
 
 
e1a2492
 
7f463bd
 
 
 
1e04250
7f463bd
7e132d2
 
 
7f463bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcc25d2
7f463bd
 
 
 
 
 
 
 
35a9a28
 
 
 
6a4989e
35a9a28
 
7f463bd
 
 
35a9a28
 
 
 
 
 
 
 
 
 
7f463bd
 
 
35a9a28
 
 
 
 
 
 
 
 
7f463bd
cdb08ae
1ec2e26
 
 
 
 
cdb08ae
e0dfd56
35a9a28
 
 
 
82ce72b
35a9a28
 
e2f6806
10a2cab
92b2187
e0dfd56
cdb08ae
92293bb
bbdc102
92293bb
efa3471
cdb08ae
bbdc102
e192f86
35a9a28
10a2cab
e192f86
efa3471
10a2cab
35a9a28
10a2cab
 
1e04250
1ec2e26
cdb08ae
1ec2e26
914d635
8665cb1
1ec2e26
b7921f3
7f463bd
 
 
 
 
 
 
 
4d74bcb
1ec2e26
7f463bd
 
 
1ec2e26
7f463bd
 
 
 
1ec2e26
7f463bd
 
 
 
 
4c5877c
bbdc102
 
4c5877c
7f463bd
bbdc102
1ec2e26
7f463bd
 
bbdc102
7f463bd
 
 
1ec2e26
7f463bd
4c5877c
7f463bd
 
1ec2e26
4c5877c
7f463bd
 
 
 
1ec2e26
7f463bd
 
 
 
 
 
 
bbdc102
7f463bd
 
 
a8ed2f0
7f463bd
 
e05c932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f463bd
 
1ec2e26
7f463bd
 
 
 
1ec2e26
7f463bd
 
 
a8ed2f0
7f463bd
 
 
bbdc102
7f463bd
1ec2e26
7f463bd
 
bbdc102
a8ed2f0
7f463bd
bbdc102
a8ed2f0
7f463bd
bbdc102
a8ed2f0
bbdc102
 
a8ed2f0
7f463bd
 
 
 
 
 
 
 
 
 
911fa32
7f463bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e04250
7f463bd
f84926d
7f463bd
 
 
 
85a5770
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279

### RAG agent built with LangChain and LangGraph: hallucination and answer-usefulness checks, with a web-search fallback

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
import chromadb

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.tavily_search import TavilySearchResults

from langgraph.graph import END, StateGraph

from customllama3 import CustomLlama3

from typing_extensions import TypedDict
from typing import List
from langchain_core.documents import Document
import os
import re


class RAGAgent:
    """Retrieval-augmented QA agent wired as a LangGraph state machine.

    All state lives on the class itself (``RAGAgent.retriever``,
    ``RAGAgent.logs`` and the LLM chains). Constructing an instance therefore
    indexes ``docs`` into Chroma and reconfigures the single shared agent.

    The graph node functions below are plain functions taking ``state``; they
    have no ``self``. They are passed directly to ``workflow.add_node`` in the
    class body and reach shared state through ``RAGAgent.<attr>``. Call them
    via the class (``RAGAgent.add_log(...)``), not via an instance.

    Graph flow: retrieve -> grade_documents -> (generate | websearch ->
    generate). The generation is then graded for grounding and usefulness.
    Every grading outcome currently routes to END.
    """

    HF_TOKEN = os.getenv("HF_TOKEN")
    TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

    # streamlit is not imported in this module, so ``st.error``/``st.stop``
    # cannot be used here. Fail fast with the same message instead of a
    # NameError.
    if HF_TOKEN is None:
        raise RuntimeError("API key not found. Please set the HF_TOKEN secret in your Hugging Face Space.")
    if TAVILY_API_KEY is None:
        raise RuntimeError("API key not found. Please set the TAVILY_API_KEY secret in your Hugging Face Space.")

    # Human-readable trace of the graph run; reset by __init__, appended by add_log.
    logs = ""

    retrieval_grader_prompt = PromptTemplate(
        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance
        of a retrieved document to a user question. If the document contains keywords related to the user question,
        grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation. The JSON format should be exactly: {{"score": "yes"}} or {{"score": "no"}} \n
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        Here is the retrieved document: \n\n {document} \n\n
        Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """,
        input_variables=["question", "document"],
    )

    answer_prompt = PromptTemplate(
        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks. 
        Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. 
        Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|>
        Question: {question} 
        Context: {document} 
        Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
        input_variables=["question", "document"],
    )

    hallucination_prompt = PromptTemplate(
        template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether 
        an answer is grounded in / supported by a set of facts. Give a binary 'yes' or 'no' score to indicate 
        whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a 
        single key 'score' and no preamble or explanation. <|eot_id|><|start_header_id|>user<|end_header_id|>
        Here are the facts:
        \n ------- \n
        {documents} 
        \n ------- \n
        Here is the answer: {generation}  <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
        input_variables=["generation", "documents"],
    )

    answer_grader_prompt = PromptTemplate(
        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an 
        answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is 
        useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
         <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
        \n ------- \n
        {generation} 
        \n ------- \n
        Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
        input_variables=["generation", "question"],
    )

    def reset_chains():
        """(Re)build the four prompt -> CustomLlama3 -> parser chains on the class.

        Each chain gets its own ``CustomLlama3`` instance, authenticated with
        ``RAGAgent.HF_TOKEN``.
        """
        def _llm():
            return CustomLlama3(bearer_token=RAGAgent.HF_TOKEN)

        RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | _llm() | JsonOutputParser()
        RAGAgent.rag_chain = RAGAgent.answer_prompt | _llm() | StrOutputParser()
        RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | _llm() | JsonOutputParser()
        RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | _llm() | JsonOutputParser()

    def _is_yes(score):
        """Return True when a grader's parsed output carries a "yes" score.

        The graders are prompted for ``{"score": "yes"|"no"}``, but LLM output
        is not guaranteed to follow that exactly. Case and surrounding
        whitespace are ignored, and a missing key or any other value counts
        as "no".
        """
        if isinstance(score, dict):
            score = score.get("score", "")
        return str(score).strip().lower() == "yes"

    def _format_docs(documents):
        """Join the ``page_content`` of ``documents`` into one prompt-ready string."""
        return "\n\n".join(
            getattr(doc, "page_content", str(doc)) for doc in (documents or [])
        )

    def _collection_name(source):
        """Derive a Chroma-safe collection name from a document ``source``.

        Non-alphanumeric characters are stripped. Names longer than 63
        characters are shortened and given a short hash suffix, so distinct
        long sources stay distinct. Names shorter than 3 characters fall back
        to a fixed default. This keeps names inside the 3-63 character range
        that chromadb validates (at least in older releases).
        """
        import hashlib

        name = re.sub(r'[^a-zA-Z0-9]', '', str(source or ""))
        if len(name) > 63:
            digest = hashlib.sha1(name.encode("utf-8")).hexdigest()[:8]
            name = name[:55] + digest
        elif len(name) < 3:
            name = "ragdocuments"
        return name

    def __init__(self, docs):
        """Index ``docs`` into a fresh Chroma collection and reset shared state.

        Args:
            docs: An iterable of iterables of LangChain ``Document`` objects.
                It is flattened, then split into ~512-token chunks.

        Raises:
            ValueError: If splitting yields no chunks to index.
        """
        docs_list = [item for sublist in docs for item in sublist]

        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=512, chunk_overlap=20
        )
        doc_splits = text_splitter.split_documents(docs_list)
        if not doc_splits:
            raise ValueError("No document content to index.")

        embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        # The collection is keyed by the first chunk's source, so re-indexing
        # the same source replaces its previous collection.
        collection_name = RAGAgent._collection_name(doc_splits[0].metadata.get('source'))
        persistent_client = chromadb.PersistentClient()

        # Older chromadb returns Collection objects here; newer returns plain names.
        existing = {getattr(c, "name", c) for c in persistent_client.list_collections()}
        if collection_name in existing:
            print("\nDELETED COLLECTION: ", collection_name)
            persistent_client.delete_collection(collection_name)

        persistent_client.create_collection(collection_name)
        print("\nCREATED COLLECTION: ", collection_name)

        # Add to vectorDB
        vectorstore = Chroma(
            client=persistent_client,
            collection_name=collection_name,
            embedding_function=embedding_function,
        )
        vectorstore.add_documents(doc_splits)

        RAGAgent.retriever = vectorstore.as_retriever()
        RAGAgent.reset_chains()
        RAGAgent.logs = ""

    def add_log(log):
        """Append one line to the shared ``RAGAgent.logs`` trace."""
        RAGAgent.logs += log + "\n"

    # ``max_results`` is this tool's result-count option; ``k`` is not one of its fields.
    web_search_tool = TavilySearchResults(max_results=3)

    class GraphState(TypedDict):
        """State dict shared between the graph nodes."""

        question: str
        generation: str
        web_search: str  # "Yes" when no retrieved document was graded relevant, else "No"
        documents: List[Document]

    def retrieve(state):
        """Graph node: fetch candidate documents for the question from the vector store."""
        RAGAgent.add_log("---RETRIEVE---")
        question = state["question"]

        documents = RAGAgent.retriever.invoke(question)
        return {"documents": documents, "question": question}

    def grade_documents(state):
        """Graph node: keep only the documents the LLM grader marks as relevant.

        ``web_search`` is set to "Yes" when no document survives grading.
        That value makes ``decide_to_generate`` route to the web-search node.
        """
        RAGAgent.add_log("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        question = state["question"]
        documents = state["documents"]

        filtered_docs = []
        print("\n---- QUESTION: ", question)

        for d in documents:
            print("\n---- DOCUMENT: ", d.page_content)
            score = RAGAgent.retrieval_grader.invoke(
                {"question": question, "document": d.page_content}
            )
            print("\n---- SCORE: ", score)
            if RAGAgent._is_yes(score):
                RAGAgent.add_log("---GRADE: DOCUMENT RELEVANT---")
                filtered_docs.append(d)
            else:
                RAGAgent.add_log("---GRADE: DOCUMENT NOT RELEVANT---")

        web_search = "No" if filtered_docs else "Yes"
        return {"documents": filtered_docs, "question": question, "web_search": web_search}

    def decide_to_generate(state):
        """Routing function: return "websearch" or "generate" based on ``state["web_search"]``."""
        RAGAgent.add_log("---ASSESS GRADED DOCUMENTS---")

        if state.get("web_search") == "Yes":
            # No retrieved document was relevant: supplement with web results.
            RAGAgent.add_log("---DOCUMENTS NOT RELEVANT, INCLUDE WEB SEARCH---")
            return "websearch"
        # We have relevant documents, so generate an answer from them.
        RAGAgent.add_log("---DOCUMENTS RELEVANT, GENERATE---")
        return "generate"

    def generate(state):
        """Graph node: answer the question from the current documents with the RAG chain."""
        RAGAgent.add_log("---GENERATE---")
        question = state["question"]
        documents = state["documents"]

        # Pass plain page text rather than the repr of a list of Document objects.
        generation = RAGAgent.rag_chain.invoke(
            {"document": RAGAgent._format_docs(documents), "question": question}
        )
        return {"documents": documents, "question": question, "generation": generation}

    def web_search(state):
        """Graph node: run a Tavily search and append its text as one extra Document."""
        RAGAgent.add_log("---WEB SEARCH RUNNING---")
        question = state["question"]
        # Work on a copy so the incoming state's list is not mutated in place.
        documents = list(state.get("documents") or [])

        results = RAGAgent.web_search_tool.invoke({"query": question})
        if isinstance(results, str):
            # NOTE(review): the tool may return a plain string (e.g. an error
            # description) instead of a list of result dicts; log and skip it
            # rather than crash on d["content"].
            RAGAgent.add_log("---WEB SEARCH RETURNED NO RESULTS: " + results + "---")
            results = []

        web_results = "\n".join(
            d.get("content", "") for d in results if isinstance(d, dict)
        )
        if web_results:
            documents.append(Document(page_content=web_results))
        return {"documents": documents, "question": question}

    def grade_generation_v_documents_and_question(state):
        """Routing function: grade the generation for grounding, then usefulness.

        Returns:
            "not supported" if the hallucination grader rejects the answer.
            "not useful" if the answer is grounded but does not address the
            question. "useful" otherwise.
        """
        RAGAgent.add_log("---CHECK HALLUCINATIONS---")
        question = state["question"]
        documents = state["documents"]
        generation = state["generation"]

        score = RAGAgent.hallucination_grader.invoke(
            {"documents": RAGAgent._format_docs(documents), "generation": generation}
        )

        if RAGAgent._is_yes(score):
            RAGAgent.add_log("---GENERATION IS GROUNDED IN DOCUMENTS---")
            score = RAGAgent.answer_grader.invoke({"question": question, "generation": generation})
            if RAGAgent._is_yes(score):
                RAGAgent.add_log("---GENERATION ADDRESSES QUESTION---")
                result = "useful"
            else:
                RAGAgent.add_log("---GENERATION DOES NOT ADDRESS QUESTION---")
                result = "not useful"
        else:
            RAGAgent.add_log("---GENERATION IS NOT GROUNDED IN DOCUMENTS---")
            result = "not supported"

        RAGAgent.add_log("\n--------END--------\n")
        return result

    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("websearch", web_search)  # web search
    workflow.add_node("retrieve", retrieve)  # retrieve
    workflow.add_node("grade_documents", grade_documents)  # grade documents
    workflow.add_node("generate", generate)  # generate

    # Build graph
    workflow.set_entry_point("retrieve")

    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "websearch": "websearch",
            "generate": "generate",
        },
    )
    workflow.add_edge("websearch", "generate")
    # Every grading outcome ends the run; the commented targets show where a
    # retry loop could be wired back in.
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": END,  # "generate",
            "useful": END,
            "not useful": END,  # "websearch",
        },
    )

    # Compile
    app = workflow.compile()