fehmikaya committed
Commit 35a9a28 · verified · 1 Parent(s): 28677c7

Update ragagent.py

Files changed (1): ragagent.py (+60 -58)
ragagent.py CHANGED
@@ -21,25 +21,6 @@ import os
 
 class RAGAgent():
 
-    def __init__(self, docs):
-        docs_list = [item for sublist in docs for item in sublist]
-
-        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-            chunk_size=1000, chunk_overlap=200
-        )
-        doc_splits = text_splitter.split_documents(docs_list)
-
-        embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-
-        # Add to vectorDB
-        vectorstore = Chroma.from_documents(
-            documents=doc_splits,
-            collection_name="rag-chroma",
-            embedding=embedding_function,
-        )
-        RAGAgent.retriever = vectorstore.as_retriever()
-        RAGAgent.logs = ""
-
     HF_TOKEN = os.getenv("HF_TOKEN")
     TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
 
@@ -62,54 +43,75 @@ class RAGAgent():
         """,
         input_variables=["question", "document"],
     )
-    retrieval_grader_llm = CustomLlama3(bearer_token=HF_TOKEN)
-    retrieval_grader = retrieval_grader_prompt | retrieval_grader_llm | JsonOutputParser()
 
     answer_prompt = PromptTemplate(
-        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
-        Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
-        Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|>
-        Question: {question}
-        Context: {context}
-        Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
-        input_variables=["question", "document"],
+        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
+        Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
+        Use three sentences maximum and keep the answer concise. <|eot_id|><|start_header_id|>user<|end_header_id|>
+        Question: {question}
+        Context: {context}
+        Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["question", "context"],
     )
-    answer_llm = CustomLlama3(bearer_token=HF_TOKEN)
-
-    # Post-processing
-    def format_docs(docs):
-        return "\n\n".join(doc.page_content for doc in docs)
-
-    rag_chain = answer_prompt | answer_llm | StrOutputParser()
 
     hallucination_prompt = PromptTemplate(
-        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether
-        an answer is grounded in / supported by a set of facts. Give a binary 'yes' or 'no' score to indicate
-        whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a
-        single key 'score' and no preamble or explanation. <|eot_id|><|start_header_id|>user<|end_header_id|>
-        Here are the facts:
-        \n ------- \n
-        {documents}
-        \n ------- \n
-        Here is the answer: {generation} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
-        input_variables=["generation", "documents"],
+        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether
+        an answer is grounded in / supported by a set of facts. Give a binary 'yes' or 'no' score to indicate
+        whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a
+        single key 'score' and no preamble or explanation. <|eot_id|><|start_header_id|>user<|end_header_id|>
+        Here are the facts:
+        \n ------- \n
+        {documents}
+        \n ------- \n
+        Here is the answer: {generation} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["generation", "documents"],
     )
-    hallucination_llm = CustomLlama3(bearer_token=HF_TOKEN)
-    hallucination_grader = hallucination_prompt | hallucination_llm | JsonOutputParser()
 
     answer_grader_prompt = PromptTemplate(
-        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
-        answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
-        useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
-        <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
-        \n ------- \n
-        {generation}
-        \n ------- \n
-        Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
-        input_variables=["generation", "question"],
+        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
+        answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
+        useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
+        <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
+        \n ------- \n
+        {generation}
+        \n ------- \n
+        Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["generation", "question"],
    )
-    answer_grader_llm = CustomLlama3(bearer_token=HF_TOKEN)
-    answer_grader = answer_grader_prompt | answer_grader_llm | JsonOutputParser()
+
+    # Placeholders until reset_agent() builds the chains.
+    retrieval_grader = rag_chain = hallucination_grader = answer_grader = None
+
+    logs = ""
+
+    print("RAGAgent()")
+
+    def reset_agent():
+        RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token=RAGAgent.HF_TOKEN) | JsonOutputParser()
+        RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token=RAGAgent.HF_TOKEN) | StrOutputParser()
+        RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token=RAGAgent.HF_TOKEN) | JsonOutputParser()
+        RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token=RAGAgent.HF_TOKEN) | JsonOutputParser()
+        RAGAgent.logs = ""
+        print("reset_agent")
+
+    def __init__(self, docs):
+        print("init")
+        docs_list = [item for sublist in docs for item in sublist]
+
+        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            chunk_size=1000, chunk_overlap=200
+        )
+        doc_splits = text_splitter.split_documents(docs_list)
+
+        embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+
+        # Add to vectorDB
+        vectorstore = Chroma.from_documents(
+            documents=doc_splits,
+            collection_name="rag-chroma",
+            embedding=embedding_function,
+        )
+        RAGAgent.retriever = vectorstore.as_retriever()
+        RAGAgent.reset_agent()
 
     web_search_tool = TavilySearchResults(k=3)
 
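For context, a minimal usage sketch of the refactored class, not part of the commit. It assumes the imports at the top of ragagent.py resolve, that CustomLlama3 is the project's Hugging Face inference wrapper, that HF_TOKEN and TAVILY_API_KEY are set in the environment, and that docs arrives as a list of per-source Document lists (e.g. from LangChain's WebBaseLoader, an assumption here); the URLs and the question are hypothetical.

from langchain_community.document_loaders import WebBaseLoader

from ragagent import RAGAgent

# One Document list per URL; __init__ flattens this list of lists.
urls = ["https://example.com/post-1", "https://example.com/post-2"]  # hypothetical
docs = [WebBaseLoader(url).load() for url in urls]

# __init__ builds the Chroma index, publishes RAGAgent.retriever on the
# class, and calls reset_agent() to construct the four prompt | llm | parser chains.
RAGAgent(docs)

question = "What is agent memory?"  # hypothetical
chunks = RAGAgent.retriever.invoke(question)

# Grade one retrieved chunk for relevance, then generate an answer.
grade = RAGAgent.retrieval_grader.invoke(
    {"question": question, "document": chunks[0].page_content}
)
answer = RAGAgent.rag_chain.invoke({"question": question, "context": chunks})

# Check that the answer is grounded in the chunks and useful for the question.
grounded = RAGAgent.hallucination_grader.invoke(
    {"documents": chunks, "generation": answer}
)
useful = RAGAgent.answer_grader.invoke({"question": question, "generation": answer})
print(grade, grounded, useful)

Note that the commit drops the format_docs helper, so {context} receives the document list's default string rendering unless the caller formats it first. Holding the chains as class attributes rebuilt by reset_agent() lets a new RAGAgent(docs) swap the corpus while the rest of the app keeps referencing the same class-level chain names.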