# maya-persistence / src / chromaIntf.py
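"""Chroma vector-store interface for maya-persistence.

Wraps a Chroma collection behind a small async API: restoring/backing up the
store via Dropbox, adding chunked text with date/source metadata, and
answering queries through a LangChain SelfQueryRetriever.
"""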
from langchain.vectorstores import Chroma
from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.chroma import ChromaTranslator
from llm.llmFactory import LLMFactory
from datetime import datetime
import baseInfra.dropbox_handler as dbh
from baseInfra.dbInterface import DbInterface
from uuid import UUID
from langchain.text_splitter import RecursiveCharacterTextSplitter
import logging, asyncio
logger=logging.getLogger("root")
class myChromaTranslator(ChromaTranslator):
    """ChromaTranslator limited to the operators this store supports."""

    allowed_operators = ["$and", "$or"]
    """Subset of allowed logical operators."""

    allowed_comparators = ["$eq", "$ne", "$gt", "$gte", "$lt", "$lte",
                           "$contains", "$not_contains", "$in", "$nin"]
    """Subset of allowed comparators."""
class ChromaIntf:
    """Async facade over a Chroma vector store with self-query retrieval."""
def __init__(self):
self.db_interface=DbInterface()
model_name = "BAAI/bge-large-en-v1.5"
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
self.embedding = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs={'device': 'cpu'},
encode_kwargs=encode_kwargs
)
self.persist_db_directory = 'db'
self.persist_docs_directory = "persistence-docs"
self.logger_file = "persistence.log"
        loop = asyncio.get_event_loop()
        try:
            loop.run_until_complete(dbh.restoreFolder(self.persist_db_directory))
            loop.run_until_complete(dbh.restoreFolder(self.persist_docs_directory))
        except Exception:
            # On a brand-new setup there is no backup to restore yet.
            logger.warning("Restore failed; probably a brand-new setup with no existing folders")
docs = [
Document(
page_content="this is test doc",
metadata={"timestamp":1696743148.474055,"ID":"2000-01-01 15:57:11::664165-test","source":"test"},
id="2000-01-01 15:57:11::664165-test"
),
]
self.vectorstore = Chroma.from_documents(documents=docs,
embedding=self.embedding,
persist_directory=self.persist_db_directory)
        # Metadata schema for stored documents:
        #   timestamp --> time when added
        #   source    --> notes/references/web/youtube/book/conversation; default "conversation"
        #   title     --> title of the document; "conversation" when source is conversation; default blank
        #   author    --> defaults to blank
        #   Year/Month/Day/Hour/Minute --> integer components of the timestamp
        self.metadata_field_info = [
            AttributeInfo(
                name="timestamp",
                description="ISO-format datetime string of when the entry was added; should not be used for queries",
                type="string",
            ),
            AttributeInfo(
                name="Year",
                description="Year from the date when the entry was added, in YYYY format",
                type="integer",
            ),
            AttributeInfo(
                name="Month",
                description="Month from the date when the entry was added, from 1-12",
                type="integer",
            ),
            AttributeInfo(
                name="Day",
                description="Day of month from the date when the entry was added, from 1-31",
                type="integer",
            ),
            AttributeInfo(
                name="Hour",
                description="Hour from the timestamp when the entry was added",
                type="integer",
            ),
            AttributeInfo(
                name="Minute",
                description="Minute from the timestamp when the entry was added",
                type="integer",
            ),
            AttributeInfo(
                name="source",
                description="Type of entry",
                type="string or list[string]",
            ),
            AttributeInfo(
                name="title",
                description="Title or subject of the entry",
                type="string",
            ),
            AttributeInfo(
                name="author",
                description="Author of the entry",
                type="string",
            ),
        ]
        self.document_content_description = "Information to store for retrieval from LLM based chatbot"
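        # The field info and content description above are injected into the
        # prompt the SelfQueryRetriever uses to turn a question into a
        # structured (query, metadata filter) pair.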
lf=LLMFactory()
#self.llm=lf.get_llm("executor2")
self.llm=lf.get_llm("executor3")
        self.retriever = SelfQueryRetriever.from_llm(
            self.llm,
            self.vectorstore,
            self.document_content_description,
            self.metadata_field_info,
            structured_query_translator=myChromaTranslator(),
            verbose=True,
        )
    async def getRelevantDocs(self, query: str, kwargs: dict):
        """Fetch documents relevant to query and add them to the cache.

        This should also post the result to firebase. A "search_type" key in
        kwargs overrides the retriever's search type; every other key is
        passed through as a search kwarg (e.g. k, filter).
        """
        logger.debug("retriever state %s", self.retriever.search_kwargs)
        logger.debug("retriever state %s", self.retriever.search_type)
        try:
            for key in kwargs.keys():
                if "search_type" in key:
                    self.retriever.search_type = kwargs[key]
                else:
                    self.retriever.search_kwargs[key] = kwargs[key]
        except Exception:
            logger.exception("Setting search args failed", exc_info=True)
        retVal = []
        try:
            retVal = self.retriever.get_relevant_documents(query)
        except Exception:
            logger.exception("Exception occurred:", exc_info=True)
        value = []
        excludeMeta = True
        print(str(len(retVal)))
        try:
            for item in retVal:
                if excludeMeta:
                    v = item.page_content + " \n"
                else:
                    v = "Info:" + item.page_content + " "
                    for key in item.metadata.keys():
                        if key != "ID":
                            v += key + ":" + str(item.metadata[key]) + " "
                value.append(v)
            self.db_interface.add_to_cache(input=query, value=value)
        except Exception:
            # Fallback for items that are plain dicts rather than Documents.
            value = []
            for item in retVal:
                if excludeMeta:
                    v = item['page_content'] + " \n"
                else:
                    v = "Info:" + item['page_content'] + " "
                    for key in item['metadata'].keys():
                        if key != "ID":
                            v += key + ":" + str(item['metadata'][key]) + " "
                value.append(v)
            self.db_interface.add_to_cache(input=query, value=value)
        return retVal
    async def addText(self, inStr: str, metadata):
        """Split inStr into chunks and add them to the vector store.

        Expected metadata fields (all optional):
          timestamp --> time when added
          source    --> notes/references/web/youtube/book/conversation; default "conversation"
          title     --> title of the document; "conversation" when source is conversation; default blank
          author    --> defaults to blank
        """
        ##TODO: Preprocess inStr to remove any html, markdown tags etc.
        metadata = metadata.dict()
        if "timestamp" not in metadata.keys():
            # Keep a datetime object here; it is serialized back to isoformat below.
            metadata['timestamp'] = datetime.now()
        else:
            metadata['timestamp'] = datetime.fromisoformat(metadata['timestamp'])
        if "source" not in metadata.keys():
            metadata['source'] = "conversation"
        if "title" not in metadata.keys():
            metadata["title"] = ""
            if metadata["source"] == "conversation":
                metadata["title"] = "conversation"
        if "author" not in metadata.keys():
            metadata["author"] = ""
        #TODO: If a url is present in the input, or when splitting is needed, we'll have to change
        #      how the ID (and maybe the filename used to store the information) is formulated.
        metadata['ID'] = metadata['timestamp'].strftime("%Y-%m-%d %H-%M-%S") + "-" + metadata['title']
        metadata['Year'] = metadata['timestamp'].year
        metadata['Month'] = metadata['timestamp'].month
        metadata['Day'] = metadata['timestamp'].day
        metadata['Hour'] = metadata['timestamp'].hour
        metadata['Minute'] = metadata['timestamp'].minute
        metadata['timestamp'] = metadata['timestamp'].isoformat()
        print("Metadata is:")
        print(metadata)
        with open("./docs/" + metadata['ID'] + ".txt", "w") as fd:
            fd.write(inStr)
        print("written to file", inStr)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=50,
            length_function=len,
            is_separator_regex=False)
        docs = text_splitter.create_documents([inStr], [metadata])
        # Give every chunk after the first a unique ID suffix (__1, __2, ...).
        partNumber = 0
        for doc in docs:
            if partNumber > 0:
                doc.metadata['ID'] += f"__{partNumber}"
            partNumber += 1
            print(f"{partNumber} follows:")
            print(doc)
        try:
            ids = [doc.metadata['ID'] for doc in docs]
            print("ids are:")
            print(ids)
            return await self.vectorstore.aadd_documents(docs, ids=ids)
        except Exception:
            logger.exception("exception in adding", exc_info=True)
            # Fallback: add everything under the base ID.
            return await self.vectorstore.aadd_documents(docs, ids=[metadata['ID']])
    async def listDocs(self):
        """Return the raw contents of the underlying Chroma collection."""
        collection = self.vectorstore._client.get_collection(
            self.vectorstore._LANGCHAIN_DEFAULT_COLLECTION_NAME,
            embedding_function=self.embedding)
        return collection.get()
    async def persist(self):
        """Flush the vector store to disk and back up state to Dropbox."""
        self.vectorstore.persist()
        await dbh.backupFile(self.logger_file)
        await dbh.backupFolder(self.persist_db_directory)
        return await dbh.backupFolder(self.persist_docs_directory)
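    # backupFolder here mirrors the restoreFolder calls in __init__, so the
    # store and raw documents survive a restart of the service.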
    def _uuid(self, uuid_str: str) -> UUID:
        """Parse uuid_str into a UUID, raising ValueError if it is malformed."""
        try:
            return UUID(uuid_str)
        except ValueError:
            raise ValueError(f"Could not parse {uuid_str} as a UUID")
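
# A minimal usage sketch, assuming the Dropbox handler, DbInterface and
# LLMFactory dependencies are configured and a ./docs directory exists.
# "Meta" is a hypothetical stand-in for the pydantic model addText expects
# (addText calls metadata.dict()).
if __name__ == "__main__":
    from pydantic import BaseModel

    class Meta(BaseModel):
        # Hypothetical metadata model, for illustration only.
        source: str = "notes"
        title: str = "smoke test"

    # Construct outside any running event loop: __init__ itself calls
    # loop.run_until_complete() to restore backups.
    intf = ChromaIntf()

    async def _demo():
        await intf.addText("hello from the smoke test", Meta())
        docs = await intf.getRelevantDocs("smoke test", {"k": 1})
        print(docs)

    asyncio.run(_demo())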