fehmikaya commited on
Commit
cdb08ae
·
verified ·
1 Parent(s): 914d635

Update ragagent.py

Browse files
Files changed (1) hide show
  1. ragagent.py +5 -12
ragagent.py CHANGED
@@ -82,13 +82,13 @@ class RAGAgent():
82
  Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
83
  input_variables=["generation", "question"],
84
  )
85
- '''
86
  def reset_chains():
87
  RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
88
  RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
89
  RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
90
  RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
91
- '''
92
 
93
  def __init__(self, docs):
94
  docs_list = [item for sublist in docs for item in sublist]
@@ -100,15 +100,13 @@ class RAGAgent():
100
 
101
  embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
102
  collection_name = re.sub(r'[^a-zA-Z0-9]', '', doc_splits[0].metadata.get('source'))
103
-
104
- # persistent_client = chromadb.PersistentClient(settings=Settings(allow_reset=True))
105
  persistent_client = chromadb.PersistentClient()
106
- # persistent_client.reset()
107
  if collection_name in [c.name for c in persistent_client.list_collections()]:
108
  print("\ndeleted: ",collection_name)
109
  persistent_client.delete_collection(collection_name)
110
 
111
- collection = persistent_client.create_collection(collection_name)
112
  print("\ncreated: ",collection_name)
113
 
114
  # Add to vectorDB
@@ -121,13 +119,8 @@ class RAGAgent():
121
  vectorstore.add_documents(doc_splits)
122
 
123
  RAGAgent.retriever = vectorstore.as_retriever()
124
- # RAGAgent.reset_chains()
125
  RAGAgent.logs=""
126
-
127
- RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
128
- RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
129
- RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
130
- RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
131
 
132
  def add_log(log):
133
  RAGAgent.logs += log + "\n"
 
82
  Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
83
  input_variables=["generation", "question"],
84
  )
85
+
86
  def reset_chains():
87
  RAGAgent.retrieval_grader = RAGAgent.retrieval_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
88
  RAGAgent.rag_chain = RAGAgent.answer_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | StrOutputParser()
89
  RAGAgent.hallucination_grader = RAGAgent.hallucination_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
90
  RAGAgent.answer_grader = RAGAgent.answer_grader_prompt | CustomLlama3(bearer_token = RAGAgent.HF_TOKEN) | JsonOutputParser()
91
+
92
 
93
  def __init__(self, docs):
94
  docs_list = [item for sublist in docs for item in sublist]
 
100
 
101
  embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
102
  collection_name = re.sub(r'[^a-zA-Z0-9]', '', doc_splits[0].metadata.get('source'))
 
 
103
  persistent_client = chromadb.PersistentClient()
104
+
105
  if collection_name in [c.name for c in persistent_client.list_collections()]:
106
  print("\ndeleted: ",collection_name)
107
  persistent_client.delete_collection(collection_name)
108
 
109
+ persistent_client.create_collection(collection_name)
110
  print("\ncreated: ",collection_name)
111
 
112
  # Add to vectorDB
 
119
  vectorstore.add_documents(doc_splits)
120
 
121
  RAGAgent.retriever = vectorstore.as_retriever()
122
+ RAGAgent.reset_chains()
123
  RAGAgent.logs=""
 
 
 
 
 
124
 
125
  def add_log(log):
126
  RAGAgent.logs += log + "\n"