AjayP13 committed
Commit 2eb9217 · verified · Parent: db4efbd

Upload instruction_template_retriever.py

Files changed (1)
  1. instruction_template_retriever.py (+155, -0)
instruction_template_retriever.py ADDED
@@ -0,0 +1,155 @@
import itertools

from datasets import load_dataset
import faiss
import pandas as pd
import numpy as np
import torch

from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer

# Local module (expected alongside this file in the repository)
from pooling_coverage import use_gaussian_coverage_pooling


class InstructionTemplateRetriever:
    FINETEMPLATES_REVISION = "831ab22c90f9da011bd972585afdf609f40fa54b"
    RETRIEVAL_EMBEDDING_NAME = "fineinstructions/matching_embedding"
    RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"

    def __init__(
        self,
        coverage_chunks=10,
        sigma=0.05,
        alpha=1.0,
        nprobe=150,
    ):
        """
        Finds relevant instruction templates for a document using
        Gaussian-weighted embeddings that cover different parts of the
        document.

        Args:
            coverage_chunks (int): The number of equally sized chunks/sections
                used to get coverage over the entire document.
            sigma (float): Standard deviation for the Gaussian weighting; this
                essentially controls how "wide" / "focused" each chunk is.
            alpha (float): A weighting factor that balances the representation
                of a single chunk against the representation of the entire
                document.
            nprobe (int): The number of probes to use when searching the FAISS
                index (larger is more accurate, but slower).
        """
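        # Intuition for the pooling (a sketch only; the exact math lives in
        # pooling_coverage.use_gaussian_coverage_pooling): each of the
        # coverage chunks pools token embeddings with Gaussian weights peaked
        # at that chunk's center (spread controlled by sigma), and each chunk
        # embedding is blended with the whole-document embedding via alpha.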
        # FineTemplates rows, addressed by the ids returned from FAISS
        self.d = load_dataset(
            "fineinstructions/finetemplates",
            revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
            split="full",
        )
        # Embedding model, wrapped with Gaussian coverage pooling so a single
        # encode() call returns one vector per coverage section plus one for
        # the entire document
        self.m = SentenceTransformer(
            InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_NAME,
            revision=InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_REVISION,
            device="cpu",
        )
        self.m = use_gaussian_coverage_pooling(
            self.m, coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
        )
        # Memory-mapped FAISS index over the template embeddings
        self.index = faiss.read_index(
            hf_hub_download(
                "fineinstructions/finetemplates",
                "faiss_index/finetemplates.index",
                revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
                repo_type="dataset",
            ),
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
        )
        self.index.nprobe = nprobe
        # Move the model to an accelerator if one is available
        if torch.cuda.is_available():
            self.m = self.m.to("cuda")
        elif torch.backends.mps.is_available():
            self.m = self.m.to("mps")

    def _filter_rows(self, rows, filter_string):
        if not rows:
            return []
        df = pd.DataFrame(rows)
        try:
            filtered_df = df.query(filter_string)
            return filtered_df.to_dict(orient="records")
        except Exception:
            # Fall back to the unfiltered rows if filter_string is not a
            # valid pandas.DataFrame.query() expression
            return rows

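    # Illustrative filter expressions (hypothetical values; any key of the
    # result rows built in search() below can be referenced, e.g. the
    # dataset's "template_id" column or the added "score" /
    # "coverage_section" keys):
    #   filters='score > 0.5'
    #   filters='coverage_section == "Entire Document"'
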
    def search(
        self, document, filters="", search_k=20000, max_results=250, deduplicate=True
    ):
        """
        Given a document, retrieves the most relevant instruction templates
        for it.

        Args:
            document (str): The document to retrieve relevant instruction
                templates for.
            filters (str): A query string in the format of
                pandas.DataFrame.query().
            search_k (int): The number of search results to pull when
                retrieving from FAISS.
            max_results (int): The max number of results to return.
            deduplicate (bool): Deduplicate results between coverage sections.
        """
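        # Shape sketch (illustrative, with hidden size H): encode() on one
        # document yields a single vector of length (coverage_chunks + 1) * H,
        # which the reshape below turns into (coverage_chunks + 1) rows of
        # size H, one query vector per coverage section plus one for the
        # entire document.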
        # Search the FAISS index with one query vector per coverage section
        vecs = self.m.encode([document], normalize_embeddings=False).reshape(
            -1, self.m[0].auto_model.config.hidden_size
        )
        scores_batch, indices_batch = self.index.search(np.vstack(vecs), k=search_k)

        # Pull the matching FineTemplates rows into memory
        to_select = [i.item() for i in itertools.chain.from_iterable(indices_batch)]
        d_in_mem = {
            i: row for i, row in zip(to_select, self.d.select(to_select).to_list())
        }

        # Group the results by coverage chunk (chunk 0 is the entire document)
        true_coverage_chunks = self.m[1].coverage_chunks + 1
        scores_per_input = [
            scores_batch[i : i + true_coverage_chunks]
            for i in range(0, len(scores_batch), true_coverage_chunks)
        ]
        indices_per_input = [
            indices_batch[i : i + true_coverage_chunks]
            for i in range(0, len(indices_batch), true_coverage_chunks)
        ]

        # Get the results for the first document in the batch (assuming
        # batch size = 1)
        scores_per_input, indices_per_input = scores_per_input[0], indices_per_input[0]

        # Create one list of result rows per coverage section
        rows = [
            [
                {
                    "coverage_section": f"{chunk_idx}/{self.m[1].coverage_chunks}"
                    if chunk_idx > 0
                    else "Entire Document",
                    "score": s.item(),
                    **d_in_mem[i.item()],
                }
                for i, s in zip(indices, scores)
            ]
            for chunk_idx, (indices, scores) in enumerate(
                zip(indices_per_input, scores_per_input)
            )
        ]

        # Interleave the sections round-robin and, if requested, keep only
        # the first occurrence of each template
        if deduplicate:
            seen = set()
            deduplicated = []
            for r in itertools.chain.from_iterable(zip(*rows)):
                if r["template_id"] not in seen:
                    seen.add(r["template_id"])
                    deduplicated.append(r)
            rows = deduplicated
        else:
            rows = list(itertools.chain.from_iterable(zip(*rows)))

        # Apply the user-provided filters and truncate to max_results
        rows = self._filter_rows(rows, filters)[:max_results]
        return rows
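
A minimal usage sketch (assumes the Hub assets referenced above are reachable and that pooling_coverage.py sits alongside this file; the document string and filter value are illustrative):

    retriever = InstructionTemplateRetriever()
    results = retriever.search(
        "Transformers are a neural network architecture...",
        filters='coverage_section == "Entire Document"',
        max_results=5,
    )
    for r in results:
        print(r["score"], r["template_id"])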