sergeipetrov commited on
Commit
36d8b50
·
verified ·
1 Parent(s): 16d11d9

Update src/vector_db.py

Browse files
Files changed (1) hide show
  1. src/vector_db.py +11 -13
src/vector_db.py CHANGED
@@ -17,9 +17,7 @@ class VectorDB:
17
  db_location = ''
18
 
19
  def __init__(self, emb_model, db_location, actions_list_file_path, num_sub_vectors, batch_size):
20
- self.emb_model = emb_model
21
- self.db_location = db_location
22
-
23
  emb_config = AutoConfig.from_pretrained(emb_model)
24
  emb_dimension = emb_config.hidden_size
25
 
@@ -50,7 +48,7 @@ class VectorDB:
50
  pa.field(self.name_column, pa.string())
51
  ]
52
  )
53
- tbl = db.create_table(self.table_name, schema=schema, mode="overwrite")
54
 
55
 
56
  df = pd.read_csv(actions_list_file_path)
@@ -76,23 +74,23 @@ class VectorDB:
76
  tbl.add(df)
77
  except:
78
  print(f"batch {i} was skipped")
 
 
 
79
  print("Vector generation done.")
80
 
81
 
82
- def get_embedding_db_as_pandas(self):
83
- db = lancedb.connect(self.db_location)
84
- tbl = db.open_table(self.table_name)
85
- return tbl.to_pandas()
86
 
87
 
88
 
89
  def retrieve_prefiltered_hits(self, query, k):
90
- db = lancedb.connect(".lancedb")
91
- table = db.open_table(self.table_name)
92
- retriever = SentenceTransformer(self.emb_model)
93
 
94
- query_vec = retriever.encode(query)
95
- documents = table.search(query_vec, vector_column_name=self.vector_column).limit(k).to_list()
96
  names = [doc[self.name_column] for doc in documents]
97
  descriptions = [doc[self.description_column] for doc in documents]
98
 
 
17
  db_location = ''
18
 
19
  def __init__(self, emb_model, db_location, actions_list_file_path, num_sub_vectors, batch_size):
20
+ self.retriever = SentenceTransformer(emb_model)
 
 
21
  emb_config = AutoConfig.from_pretrained(emb_model)
22
  emb_dimension = emb_config.hidden_size
23
 
 
48
  pa.field(self.name_column, pa.string())
49
  ]
50
  )
51
+ tbl = db.create_table(table_name, schema=schema, mode="overwrite")
52
 
53
 
54
  df = pd.read_csv(actions_list_file_path)
 
74
  tbl.add(df)
75
  except:
76
  print(f"batch {i} was skipped")
77
+
78
+ self.db = db
79
+ self.table = tbl
80
  print("Vector generation done.")
81
 
82
 
83
+ # def get_embedding_db_as_pandas(self):
84
+ # db = lancedb.connect(self.db_location)
85
+ # tbl = db.open_table(self.table_name)
86
+ # return tbl.to_pandas()
87
 
88
 
89
 
90
  def retrieve_prefiltered_hits(self, query, k):
 
 
 
91
 
92
+ query_vec = self.retriever.encode(query)
93
+ documents = self.table.search(query_vec, vector_column_name=self.vector_column).limit(k).to_list()
94
  names = [doc[self.name_column] for doc in documents]
95
  descriptions = [doc[self.description_column] for doc in documents]
96