AjayP13 committed on
Commit
5a402b7
·
verified ·
1 Parent(s): 06f9264

Update instruction_template_retriever.py

Browse files
Files changed (1) hide show
  1. instruction_template_retriever.py +54 -3
instruction_template_retriever.py CHANGED
@@ -1,5 +1,7 @@
1
  import itertools
2
  import json
 
 
3
 
4
  from datasets import load_dataset
5
  import faiss
@@ -168,7 +170,9 @@ def unuse_gaussian_coverage_pooling(m):
168
 
169
  class InstructionTemplateRetriever:
170
  FINETEMPLATES_REVISION = "831ab22c90f9da011bd972585afdf609f40fa54b"
171
- RETRIEVAL_EMBEDDING_NAME = "fineinstructions/instruction_template_retrieval_embedding"
 
 
172
  RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"
173
 
174
  def __init__(
@@ -222,6 +226,21 @@ class InstructionTemplateRetriever:
222
  elif torch.backends.mps.is_available():
223
  self.m = self.m.to("mps")
224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  def _filter_rows(self, rows, filter_string):
226
  if not rows:
227
  return []
@@ -233,7 +252,14 @@ class InstructionTemplateRetriever:
233
  return rows
234
 
235
  def search(
236
- self, document, filters="", search_k=20000, max_results=250, deduplicate=True
 
 
 
 
 
 
 
237
  ):
238
  """
239
  Given a document
@@ -246,6 +272,31 @@ class InstructionTemplateRetriever:
246
  deduplicate (bool): Deduplicate results between coverage sections.
247
  """
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  # Search FAISS index
250
  vecs = self.m.encode([document], normalize_embeddings=False).reshape(
251
  -1, self.m[0].auto_model.config.hidden_size
@@ -284,7 +335,7 @@ class InstructionTemplateRetriever:
284
  "score": s.item(),
285
  **d_in_mem[i.item()],
286
  }
287
- for i, s in zip(indices, scores)
288
  ]
289
  for chunk_idx, (indices, scores) in enumerate(
290
  zip(indices_per_input, scores_per_input)
 
1
  import itertools
2
  import json
3
+ import pickle
4
+ from random import Random
5
 
6
  from datasets import load_dataset
7
  import faiss
 
170
 
171
  class InstructionTemplateRetriever:
172
  FINETEMPLATES_REVISION = "831ab22c90f9da011bd972585afdf609f40fa54b"
173
+ RETRIEVAL_EMBEDDING_NAME = (
174
+ "fineinstructions/instruction_template_retrieval_embedding"
175
+ )
176
  RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"
177
 
178
  def __init__(
 
226
  elif torch.backends.mps.is_available():
227
  self.m = self.m.to("mps")
228
 
229
+ with open(
230
+ hf_hub_download(
231
+ "fineinstructions/finetemplates",
232
+ "faiss_index/reweighting_stats.pkl",
233
+ revision=FINETEMPLATES_REVISION,
234
+ repo_type="dataset",
235
+ ),
236
+ "rb",
237
+ ) as reweighting_stats_fp:
238
+ reweighting_stats = pickle.load(reweighting_stats_fp)
239
+ self.resampling_weights = reweighting_stats["resampling_weights"]
240
+ self.template_variable_count_mapping = reweighting_stats[
241
+ "template_variable_count_mapping"
242
+ ]
243
+
244
  def _filter_rows(self, rows, filter_string):
245
  if not rows:
246
  return []
 
252
  return rows
253
 
254
  def search(
255
+ self,
256
+ document,
257
+ filters="",
258
+ search_k=20000,
259
+ max_results=250,
260
+ deduplicate=True,
261
+ reweight=False,
262
+ reweighting_epsilon=True,
263
  ):
264
  """
265
  Given a document
 
272
  deduplicate (bool): Deduplicate results between coverage sections.
273
  """
274
 
275
def _reweight(inp, k=None):
    """Optionally resample retrieved ``(index, score)`` pairs.

    When the enclosing ``search()`` was called with ``reweight=True``,
    rows whose score lies within ``reweighting_epsilon`` of the top
    score form a "bucket", and that bucket is resampled in proportion
    to precomputed per-template weights (keyed by each template's
    variable count). Otherwise the input is passed through untouched.

    Args:
        inp: Iterable of ``(index_tensor, score_tensor)`` pairs,
            presumably sorted by descending score — the loop stops at
            the first row outside the epsilon band (TODO confirm the
            caller always supplies score-sorted rows).
        k: Number of samples to draw; ``None`` keeps the bucket size.

    Returns:
        A list of resampled ``(index, score)`` pairs, or ``inp``
        unchanged when reweighting is disabled.
    """
    if not reweight:
        # Reweighting disabled: hand the iterable back as-is.
        return inp

    # Peek at the first row without consuming the main iterator.
    inp0, inp = itertools.tee(inp)
    first_row = next(inp0, None)
    if first_row is None:
        # Empty result set — nothing to resample. (Without this guard,
        # next() raised StopIteration and r.choices([]) would raise
        # IndexError inside the caller's comprehension.)
        return []

    # Seed deterministically from the top score so repeated searches
    # over the same document reproduce the same sample.
    r = Random(first_row[1].item())
    epsilon = reweighting_epsilon
    bucket = first_row[1]
    items = []
    weights = []
    for i, s in inp:
        if abs(bucket - s.item()) <= epsilon:
            items.append((i, s))
            weights.append(
                self.resampling_weights[
                    self.template_variable_count_mapping[i.item()]
                ]
            )
        else:
            # Rows are assumed score-sorted, so the first row outside
            # the band ends the bucket.
            break
    return r.choices(
        items, weights=weights, k=(len(items) if k is None else k)
    )
299
+
300
  # Search FAISS index
301
  vecs = self.m.encode([document], normalize_embeddings=False).reshape(
302
  -1, self.m[0].auto_model.config.hidden_size
 
335
  "score": s.item(),
336
  **d_in_mem[i.item()],
337
  }
338
+ for i, s in _reweight(zip(indices, scores), k=None)
339
  ]
340
  for chunk_idx, (indices, scores) in enumerate(
341
  zip(indices_per_input, scores_per_input)