Spaces:
Sleeping
Sleeping
Update ragagent.py
Browse files- ragagent.py +5 -12
ragagent.py
CHANGED
@@ -82,13 +82,13 @@ class RAGAgent():
|
|
82 |
Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
|
83 |
input_variables=["generation", "question"],
|
84 |
)
|
85 |
-
|
86 |
def reset_chains():
|
87 |
RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
88 |
RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
|
89 |
RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
90 |
RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
91 |
-
|
92 |
|
93 |
def __init__(self, docs):
|
94 |
docs_list = [item for sublist in docs for item in sublist]
|
@@ -100,15 +100,13 @@ class RAGAgent():
|
|
100 |
|
101 |
embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
102 |
collection_name = re.sub(r'[^a-zA-Z0-9]', '', doc_splits[0].metadata.get('source'))
|
103 |
-
|
104 |
-
# persistent_client = chromadb.PersistentClient(settings=Settings(allow_reset=True))
|
105 |
persistent_client = chromadb.PersistentClient()
|
106 |
-
|
107 |
if collection_name in [c.name for c in persistent_client.list_collections()]:
|
108 |
print("\ndeleted: ",collection_name)
|
109 |
persistent_client.delete_collection(collection_name)
|
110 |
|
111 |
-
|
112 |
print("\ncreated: ",collection_name)
|
113 |
|
114 |
# Add to vectorDB
|
@@ -121,13 +119,8 @@ class RAGAgent():
|
|
121 |
vectorstore.add_documents(doc_splits)
|
122 |
|
123 |
RAGAgent.retriever = vectorstore.as_retriever()
|
124 |
-
|
125 |
RAGAgent.logs=""
|
126 |
-
|
127 |
-
RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
128 |
-
RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
|
129 |
-
RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
130 |
-
RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
131 |
|
132 |
def add_log(log):
|
133 |
RAGAgent.logs += log + "\n"
|
|
|
82 |
Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
|
83 |
input_variables=["generation", "question"],
|
84 |
)
|
85 |
+
|
86 |
def reset_chains():
|
87 |
RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
88 |
RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
|
89 |
RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
90 |
RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
|
91 |
+
|
92 |
|
93 |
def __init__(self, docs):
|
94 |
docs_list = [item for sublist in docs for item in sublist]
|
|
|
100 |
|
101 |
embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
102 |
collection_name = re.sub(r'[^a-zA-Z0-9]', '', doc_splits[0].metadata.get('source'))
|
|
|
|
|
103 |
persistent_client = chromadb.PersistentClient()
|
104 |
+
|
105 |
if collection_name in [c.name for c in persistent_client.list_collections()]:
|
106 |
print("\ndeleted: ",collection_name)
|
107 |
persistent_client.delete_collection(collection_name)
|
108 |
|
109 |
+
persistent_client.create_collection(collection_name)
|
110 |
print("\ncreated: ",collection_name)
|
111 |
|
112 |
# Add to vectorDB
|
|
|
119 |
vectorstore.add_documents(doc_splits)
|
120 |
|
121 |
RAGAgent.retriever = vectorstore.as_retriever()
|
122 |
+
RAGAgent.reset_chains()
|
123 |
RAGAgent.logs=""
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
def add_log(log):
|
126 |
RAGAgent.logs += log + "\n"
|