import os
import random
import string

import chromadb
from tqdm import tqdm

from .BaseDB import BaseDB


class ChromaDB(BaseDB):
    """Chroma-backed document store with one collection per ``db_name``.

    Every collection embeds its documents with the ``embedding`` function
    given at construction. Opened collection handles are cached in
    ``self.collections`` (name -> Chroma collection).
    """

    def __init__(self, embedding, save_type="persistent"):
        """Create the underlying Chroma client.

        Args:
            embedding: Chroma embedding function used for every collection.
            save_type: ``"persistent"`` stores data on disk in a
                ``chromadb_saves/`` directory next to this module; any other
                value uses ``chromadb.Client()``.
        """
        self.collections = {}
        self.embedding = embedding
        base_dir = os.path.dirname(os.path.abspath(__file__))
        if save_type == "persistent":
            self.path = os.path.join(base_dir, "./chromadb_saves/")
            self.client = chromadb.PersistentClient(path=self.path)
        else:
            self.client = chromadb.Client()

    def _get_collection(self, db_name):
        """Return the cached collection for *db_name*, opening/creating it on first use."""
        collection = self.collections.get(db_name)
        if collection is None:
            # get_or_create also picks up collections persisted by an earlier
            # run, which create_collection would reject as already existing.
            collection = self.client.get_or_create_collection(
                name=db_name, embedding_function=self.embedding
            )
            self.collections[db_name] = collection
        return collection

    def init_from_data(self, data, db_name):
        """Store the documents of *data* in collection *db_name*.

        ``data[i]`` is stored under id ``str(i)``. Indices below the
        collection's current document count are assumed to be stored already
        and are skipped, so re-running after an interruption only embeds the
        remaining documents.

        Args:
            data: Indexable sequence of document strings.
            db_name: Name of the Chroma collection to fill.
        """
        collection = self._get_collection(db_name)
        # Resume exactly at count() (not count()+1, which skipped a document,
        # including data[0] on a fresh collection). upsert is required because
        # update() ignores ids that do not exist yet.
        for i in tqdm(range(collection.count(), len(data))):
            collection.upsert(documents=[data[i]], ids=[str(i)])

    def search(self, query, n_results, db_name):
        """Return up to *n_results* documents most similar to *query*.

        Returns an empty list if *db_name* has not been loaded in this
        instance (via ``init_from_data`` or ``add``), or if the collection
        is empty or ``n_results`` is below 1.
        """
        collection = self.collections.get(db_name)
        if collection is None:
            return []
        # Never ask Chroma for more results than the collection holds.
        n_results = min(collection.count(), n_results)
        if n_results < 1:
            return []
        results = collection.query(query_texts=[query], n_results=n_results)
        return results["documents"][0]

    def add(self, text, idx, db_name=""):
        """Insert *text* under id *idx* in *db_name*, replacing any existing document.

        The collection is opened or created as needed. *idx* is converted
        with ``str`` to match the ids written by ``init_from_data``.
        """
        self._get_collection(db_name).upsert(documents=[text], ids=[str(idx)])

    def delete(self, idx, db_name):
        """Remove the document with id *idx* from collection *db_name*.

        The collection is opened (or created, if absent) first, so this also
        works for collections persisted by an earlier run.
        """
        self._get_collection(db_name).delete(ids=[str(idx)])