File size: 2,415 Bytes
e636070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import chromadb
from .BaseDB import BaseDB
import random
import string
import os
from tqdm import tqdm

class ChromaDB(BaseDB):
    """Document store backed by Chroma collections.

    Collections opened or created through this object are cached in
    ``self.collections`` under their name; ``search`` and ``delete`` only
    operate on collections present in that cache.
    """

    def __init__(self, embedding, save_type="persistent"):
        """Create the Chroma client.

        Args:
            embedding: Embedding function handed to every collection this
                object opens or creates.
            save_type: ``"persistent"`` stores data in a ``chromadb_saves``
                directory next to this file; any other value uses
                ``chromadb.Client()`` instead.
        """
        self.collections = {}
        self.embedding = embedding
        # Defined in both modes so reading the attribute never raises.
        self.path = None

        if save_type == "persistent":
            base_dir = os.path.dirname(os.path.abspath(__file__))
            self.path = os.path.join(base_dir, "./chromadb_saves/")
            self.client = chromadb.PersistentClient(path=self.path)
        else:
            self.client = chromadb.Client()

    def init_from_data(self, data, db_name):
        """Open collection ``db_name``, creating and filling it if it is new.

        An already existing collection is only opened; ``data`` is not
        written to it. A new collection receives every item of ``data`` as
        one document whose id is the item's position as a string
        (``"0"``, ``"1"``, ...).

        Args:
            data: Sequence of document texts.
            db_name: Name of the collection.
        """
        # Depending on the chromadb release, list_collections() yields either
        # Collection objects or plain name strings; accept both.
        existing_names = {
            getattr(entry, "name", entry)
            for entry in self.client.list_collections()
        }
        if db_name in existing_names:
            self.collections[db_name] = self.client.get_collection(
                name=db_name,
                embedding_function=self.embedding,
            )
            return

        collection = self.client.create_collection(
            name=db_name,
            embedding_function=self.embedding,
        )
        self.collections[db_name] = collection
        if len(data) != 0:
            # The collection was just created, so there is nothing to
            # update: documents have to be added, starting with the first.
            for position, document in enumerate(tqdm(data)):
                collection.add(
                    documents=[document],
                    ids=[str(position)],
                )
        return

    def search(self, query, n_results, db_name):
        """Return up to ``n_results`` documents matching ``query``.

        Args:
            query: Query text.
            n_results: Maximum number of documents wanted; capped at the
                number of documents in the collection.
            db_name: Name of a collection previously loaded by this object.

        Returns:
            The list of document texts for the query, or ``[]`` when the
            collection is not loaded or nothing can be returned.
        """
        if db_name not in self.collections:
            return []
        collection = self.collections[db_name]
        n_results = min(collection.count(), n_results)
        if n_results < 1:
            return []
        results = collection.query(query_texts=[query], n_results=n_results)
        return results['documents'][0]

    def add(self, text, idx, db_name=""):
        """Insert ``text`` under id ``idx``, replacing any document with that id.

        Args:
            text: Document text to store.
            idx: Document id; converted with ``str()`` so it matches the
                string ids written by ``init_from_data``.
            db_name: Name of the target collection. A collection that this
                object has not loaded yet is opened, or created if missing.
        """
        collection = self.collections.get(db_name)
        if collection is None:
            # The collection may already exist in the client (e.g. persisted
            # by an earlier run) without being cached here, so a plain
            # create_collection would fail on the name clash.
            collection = self.client.get_or_create_collection(
                name=db_name,
                embedding_function=self.embedding,
            )
            self.collections[db_name] = collection

        # upsert covers both cases: new id -> insert, known id -> overwrite.
        collection.upsert(
            documents=[text],
            ids=[str(idx)],
        )

    def delete(self, idx, db_name):
        """Remove the document with id ``idx`` from collection ``db_name``.

        Args:
            idx: Document id; converted with ``str()`` like in ``add``.
            db_name: Name of a collection previously loaded by this object.

        Raises:
            KeyError: If ``db_name`` has not been loaded by this object.
        """
        self.collections[db_name].delete(ids=[str(idx)])