Spaces:
Sleeping
Sleeping
Update ragagent.py
Browse files- ragagent.py +60 -58
ragagent.py
CHANGED
@@ -21,25 +21,6 @@ import os
|
|
21 |
|
22 |
class RAGAgent():
|
23 |
|
24 |
-
def __init__(self, docs):
|
25 |
-
docs_list = [item for sublist in docs for item in sublist]
|
26 |
-
|
27 |
-
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
28 |
-
chunk_size=1000, chunk_overlap=200
|
29 |
-
)
|
30 |
-
doc_splits = text_splitter.split_documents(docs_list)
|
31 |
-
|
32 |
-
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
|
33 |
-
|
34 |
-
# Add to vectorDB
|
35 |
-
vectorstore = Chroma.from_documents(
|
36 |
-
documents=doc_splits,
|
37 |
-
collection_name="rag-chroma",
|
38 |
-
embedding=embedding_function,
|
39 |
-
)
|
40 |
-
RAGAgent.retriever = vectorstore.as_retriever()
|
41 |
-
RAGAgent.logs = ""
|
42 |
-
|
43 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
44 |
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
|
45 |
|
@@ -62,54 +43,75 @@ class RAGAgent():
|
|
62 |
""",
|
63 |
input_variables=["question", "document"],
|
64 |
)
|
65 |
-
retrieval_grader_llm = CustomLlama3(bearer_token = HF_TOKEN)
|
66 |
-
retrieval_grader = retrieval_grader_prompt | retrieval_grader_llm | JsonOutputParser()
|
67 |
|
68 |
answer_prompt = PromptTemplate(
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
)
|
77 |
-
answer_llm = CustomLlama3(bearer_token = HF_TOKEN)
|
78 |
-
|
79 |
-
# Post-processing
|
80 |
-
def format_docs(docs):
|
81 |
-
return "\n\n".join(doc.page_content for doc in docs)
|
82 |
-
|
83 |
-
rag_chain = answer_prompt | answer_llm | StrOutputParser()
|
84 |
|
85 |
hallucination_prompt = PromptTemplate(
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
)
|
97 |
-
hallucination_llm = CustomLlama3(bearer_token = HF_TOKEN)
|
98 |
-
hallucination_grader = hallucination_prompt | hallucination_llm | JsonOutputParser()
|
99 |
|
100 |
answer_grader_prompt = PromptTemplate(
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
)
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
web_search_tool = TavilySearchResults(k=3)
|
115 |
|
|
|
21 |
|
22 |
class RAGAgent():
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
# Value of the HF_TOKEN environment variable (None when unset); reset_agent()
# passes it to CustomLlama3 as bearer_token.
HF_TOKEN = os.getenv("HF_TOKEN")
|
25 |
# Value of the TAVILY_API_KEY environment variable (None when unset).
# NOTE(review): presumably consumed by the Tavily search tool — confirm.
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
|
26 |
|
|
|
43 |
""",
|
44 |
input_variables=["question", "document"],
|
45 |
)
|
|
|
|
|
46 |
|
47 |
# Prompt for the answer-generation chain (rag_chain): the model is asked to
# answer {question} from the retrieved {context} in at most three sentences.
answer_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    # Must list exactly the placeholders used in the template above; the
    # template references {context}, not {document}.
    input_variables=["question", "context"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
# Prompt for the hallucination grader: given {documents} (the facts) and
# {generation} (the answer), the model must reply with a JSON object holding a
# single 'score' key whose value is 'yes' or 'no'.
hallucination_prompt = PromptTemplate(
    input_variables=["generation", "documents"],
    template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether
an answer is grounded in / supported by a set of facts. Give a binary 'yes' or 'no' score to indicate
whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a
single key 'score' and no preamble or explanation. <|eot_id|><|start_header_id|>user<|end_header_id|>
Here are the facts:
\n ------- \n
{documents}
\n ------- \n
Here is the answer: {generation} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
)
|
|
|
|
|
69 |
|
70 |
# Prompt for the answer grader: given {generation} (the answer) and
# {question}, the model must reply with a JSON object holding a single 'score'
# key whose value is 'yes' or 'no'.
answer_grader_prompt = PromptTemplate(
    input_variables=["generation", "question"],
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
<|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
\n ------- \n
{generation}
\n ------- \n
Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
)
|
81 |
+
|
82 |
+
# Chain slots start out as None; reset_agent() assigns the real chains.
# A chained assignment is required here: tuple-unpacking from a bare None
# raises TypeError as soon as this statement runs.
retrieval_grader = rag_chain = hallucination_grader = answer_grader = None

# Log string; reset_agent() sets it back to the empty string.
logs = ""

print("RAGAgent()")
|
87 |
+
|
88 |
+
@staticmethod
def reset_agent():
    """Rebuild the class-level LLM chains and clear the class-level log.

    Each chain pipes one of the class's prompt templates into its own new
    ``CustomLlama3`` instance and then into an output parser
    (``JsonOutputParser`` for the three graders, ``StrOutputParser`` for
    ``rag_chain``). Everything is stored on ``RAGAgent`` itself rather than
    on an instance, so the state is shared by all instances.

    Declared as a static method so that it can be called both as
    ``RAGAgent.reset_agent()`` and on an instance.
    """
    # HF_TOKEN is assigned in the RAGAgent class body; names bound there are
    # not in scope inside a method, so it has to be read through the class.
    token = RAGAgent.HF_TOKEN

    RAGAgent.retrieval_grader = (
        RAGAgent.retrieval_grader_prompt
        | CustomLlama3(bearer_token=token)
        | JsonOutputParser()
    )
    RAGAgent.rag_chain = (
        RAGAgent.answer_prompt
        | CustomLlama3(bearer_token=token)
        | StrOutputParser()
    )
    RAGAgent.hallucination_grader = (
        RAGAgent.hallucination_prompt
        | CustomLlama3(bearer_token=token)
        | JsonOutputParser()
    )
    RAGAgent.answer_grader = (
        RAGAgent.answer_grader_prompt
        | CustomLlama3(bearer_token=token)
        | JsonOutputParser()
    )
    RAGAgent.logs = ""
    print("reset_agent")
|
95 |
+
|
96 |
+
def __init__(self, docs):
    """Index ``docs`` in a Chroma collection and reset the shared agent state.

    Args:
        docs: An iterable of iterables; it is flattened by one level before
            being handed to the text splitter's ``split_documents``.

    The flattened documents are split with a tiktoken-based recursive
    splitter (chunk_size=1000, chunk_overlap=200), embedded with the
    "all-MiniLM-L6-v2" sentence-transformer model and stored in the
    "rag-chroma" collection. The resulting retriever is stored on the
    ``RAGAgent`` class (not on ``self``), after which
    ``RAGAgent.reset_agent()`` is called.
    """
    print("init")

    # Flatten one level of nesting.
    flattened = []
    for group in docs:
        flattened.extend(group)

    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000,
        chunk_overlap=200,
    )
    chunks = splitter.split_documents(flattened)

    embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    # Add to vectorDB
    store = Chroma.from_documents(
        documents=chunks,
        collection_name="rag-chroma",
        embedding=embedder,
    )
    RAGAgent.retriever = store.as_retriever()
    RAGAgent.reset_agent()
|
115 |
|
116 |
# Tavily search tool instance, constructed with k=3.
# NOTE(review): k presumably caps the number of search results — confirm.
web_search_tool = TavilySearchResults(k=3)
|
117 |
|