mehmet0001 committed on
Commit
f52dfc6
·
verified ·
1 Parent(s): b7e76a9

Update create_database.py

Browse files
Files changed (1) hide show
  1. create_database.py +29 -27
create_database.py CHANGED
@@ -1,28 +1,30 @@
1
- from sentence_transformers import SentenceTransformer
2
- import chromadb
3
-
4
- def split_list(list_,chunk_size):
5
- return [list_[i:i+chunk_size] for i in range(0,len(list_),chunk_size)]
6
-
7
- def create_database(txt):
8
- class EmbeddingFn:
9
- def __init__(self,model_name):
10
- self.model = SentenceTransformer(model_name)
11
-
12
- def __call__(self,input):
13
- return self.model.encode(input).tolist()
14
-
15
- embedding_fn = EmbeddingFn("sentence-transformers/all-mpnet-base-v2")
16
-
17
- ids = [str(i) for i in range(len(txt))]
18
-
19
- chroma_cli = chromadb.Client()
20
- collection = chroma_cli.create_collection("chat-with-docs",embedding_function=embedding_fn)
21
-
22
- txt = split_list(txt,5000)
23
- ids = split_list(ids,5000)
24
-
25
- for txt_chunk,ids_chunk in zip(txt,ids):
26
- collection.add(documents=txt_chunk,ids=ids_chunk)
27
-
 
 
28
  return collection
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import chromadb
3
+
4
def split_list(list_, chunk_size: int) -> list:
    """Split a sequence into consecutive slices of at most ``chunk_size`` items.

    Args:
        list_: Any sliceable sequence that supports ``len()`` (list, tuple,
            str, ...). It is not modified.
        chunk_size: Maximum number of items per slice; must be >= 1.

    Returns:
        A list of slices in original order. Every slice has ``chunk_size``
        items except possibly the last, which holds the remainder. An empty
        input yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unrelated-looking range() error, so both
    # are rejected up front with an explicit message.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")

    return [
        list_[start:start + chunk_size]
        for start in range(0, len(list_), chunk_size)
    ]
6
+
7
def create_database(txt, batch_size=5000):
    """Create a Chroma collection named "chat-with-docs" from text documents.

    Each document gets its positional index (as a string) as its id, and the
    documents are embedded with the
    ``sentence-transformers/all-mpnet-base-v2`` model.

    Args:
        txt: Sequence of document strings to store.
        batch_size: Maximum number of documents sent to ``collection.add`` in
            one call; must be >= 1. Defaults to 5000, the value this function
            previously hard-coded.

    Returns:
        The populated Chroma collection.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail fast, before the (slow) model load below.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    class EmbeddingFn:
        """Callable that embeds texts with a SentenceTransformer model."""

        def __init__(self, model_name):
            self.model = SentenceTransformer(model_name)

        # NOTE(review): the parameter is deliberately left named `input`
        # (shadowing the builtin); Chroma appears to check the embedding
        # function's signature, so do not rename without confirming.
        def __call__(self, input):
            # encode() output is converted with .tolist() before being
            # handed back to Chroma.
            return self.model.encode(input).tolist()

    embedding_fn = EmbeddingFn("sentence-transformers/all-mpnet-base-v2")

    # One id per document: "0", "1", ... in input order.
    ids = [str(i) for i in range(len(txt))]

    # Presumably resets Chroma's process-wide client cache so that repeated
    # calls start from a fresh client and create_collection() does not trip
    # over state left by an earlier call - TODO confirm against Chroma docs.
    chromadb.api.client.SharedSystemClient.clear_system_cache()

    chroma_cli = chromadb.Client()
    collection = chroma_cli.create_collection(
        "chat-with-docs",
        embedding_function=embedding_fn,
    )

    # Documents and ids are chunked identically, so zip() keeps each
    # document paired with its own id while adding in bounded batches.
    txt_batches = split_list(txt, batch_size)
    id_batches = split_list(ids, batch_size)
    for txt_chunk, ids_chunk in zip(txt_batches, id_batches):
        collection.add(documents=txt_chunk, ids=ids_chunk)

    return collection