darpanaswal committed · verified
Commit fd6b733 · 1 Parent(s): a0a8763

Update cross_encoder_reranking_train.py

Files changed (1):
  cross_encoder_reranking_train.py (+131 -48)
cross_encoder_reranking_train.py CHANGED
@@ -13,8 +13,7 @@ from sklearn.metrics.pairwise import cosine_similarity
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 # Load embedder once
-embedder = SentenceTransformer("all-MiniLM-L6-v2")
-embedder = embedder.to(device)
+embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").to(device)
 
 
 def embed_text_list(texts):
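The swap from all-MiniLM-L6-v2 to the larger all-mpnet-base-v2 changes the dense embedder behind embed_text_list, whose body lies outside this diff. A minimal sketch of what such a helper presumably looks like, reconstructed only from how the commit calls it:

# Hypothetical reconstruction: embed_text_list is defined outside this diff.
# It must return one vector per input string, since the new helpers index into
# its result and pass it to sklearn's cosine_similarity.
def embed_text_list(texts):
    # convert_to_numpy=True yields an ndarray compatible with cosine_similarity
    return embedder.encode(texts, convert_to_numpy=True)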
@@ -62,6 +61,28 @@ def process_single_patent(patent_dict):
         "features": rank_by_centrality(top_features),
     }
 
+def refined_process_single_patent(patent_dict, top_n=10):
+    abstract = patent_dict.get("pa01", "")
+    title = patent_dict.get("title", "")
+    context = f"{title} {abstract}"
+    context_emb = embed_text_list([context])[0]
+
+    claims = [v for k, v in patent_dict.items() if k.startswith("c-en")]
+    paragraphs = [v for k, v in patent_dict.items() if k.startswith("p")]
+    features = [v for k, v in patent_dict.get("features", {}).items()]
+
+    def semantic_rank(items, context_emb):
+        embeddings = embed_text_list(items)
+        scores = cosine_similarity([context_emb], embeddings)[0]
+        ranked_items = [item for item, _ in sorted(zip(items, scores), key=lambda x: x[1], reverse=True)]
+        return ranked_items
+
+    return {
+        "claims": semantic_rank(claims, context_emb)[:top_n],
+        "paragraphs": semantic_rank(paragraphs, context_emb)[:top_n],
+        "features": semantic_rank(features, context_emb)[:top_n],
+    }
+
 def load_json_file(file_path):
     """Load JSON data from a file"""
     with open(file_path, 'r') as f:
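A quick usage sketch of the new helper on a toy patent dict. The field names here are hypothetical but follow the key prefixes the code filters on; note that the k.startswith("p") filter also matches the abstract key pa01, so the abstract gets ranked among the paragraphs as well:

# Toy input; keys mimic the "pa01" / "c-en…" / "p…" conventions the code expects.
toy_patent = {
    "title": "Heat exchanger",
    "pa01": "A heat exchanger with finned copper tubes.",
    "c-en-0001": "1. A heat exchanger comprising finned tubes.",
    "p0001": "Embodiments use copper fins for high conductivity.",
    "features": {"f1": "finned tubes", "f2": "copper fins"},
}
ranked = refined_process_single_patent(toy_patent, top_n=5)
print(ranked["claims"])  # claims sorted by similarity to title + abstract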
@@ -153,6 +174,22 @@ def extract_text(content_dict, text_type="full"):
 
         return " ".join(all_text)
 
+    elif text_type == "smart2":
+        filtered_dict = refined_process_single_patent(content_dict)
+        all_text = []
+        # Context with title and abstract
+        if "title" in content_dict:
+            all_text.append(content_dict["title"])
+        if "pa01" in content_dict:
+            all_text.append(content_dict["pa01"])
+
+        # Add claims, paragraphs, and features
+        all_text.extend(filtered_dict["claims"])
+        all_text.extend(filtered_dict["paragraphs"])
+        all_text.extend(filtered_dict["features"])
+
+        return " ".join(all_text)
+
 
     return ""
 
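Selecting the new branch is a one-liner. The existing "smart" path (not shown in this diff) appears to rank sections via the centrality-based process_single_patent; "smart2" ranks them by embedding similarity instead, then both concatenate title, abstract, and the top-ranked sections:

# text_type "smart2" routes through the semantic ranking added above
text = extract_text(content_dict, text_type="smart2")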
@@ -166,67 +203,113 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     batch_size = last_hidden_states.shape[0]
     return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
 
+# def get_detailed_instruct(task_description: str, query: str) -> str:
+#     """Create an instruction-formatted query"""
+#     return f'Instruct: {task_description}\nQuery: {query}'
+
 def get_detailed_instruct(task_description: str, query: str) -> str:
-    """Create an instruction-formatted query"""
-    return f'Instruct: {task_description}\nQuery: {query}'
+    return (
+        f"Instruct: Evaluate the semantic and technical similarity between two patent documents."
+        f" Prioritize highly similar claims, technical implementations, and shared functionalities."
+        f"\nQuery: {query}"
+    )
 
-def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=8, max_length=2048):
-    """
-    Rerank document texts based on query text using cross-encoder model
-
-    Parameters:
-    query_text (str): The query text
-    doc_texts (list): List of document texts
-    model: The cross-encoder model
-    tokenizer: The tokenizer for the model
-    batch_size (int): Batch size for processing
-    max_length (int): Maximum sequence length
-
-    Returns:
-    list: Indices of documents sorted by relevance score (descending)
-    """
-    device = next(model.parameters()).device
-    scores = []
+def hybrid_score(cross_encoder_score, semantic_score, weight_cross=0.7, weight_semantic=0.3):
+    return (weight_cross * cross_encoder_score) + (weight_semantic * semantic_score)
+
+
+# def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=8, max_length=2048):
+#     """
+#     Rerank document texts based on query text using cross-encoder model
+#
+#     Parameters:
+#     query_text (str): The query text
+#     doc_texts (list): List of document texts
+#     model: The cross-encoder model
+#     tokenizer: The tokenizer for the model
+#     batch_size (int): Batch size for processing
+#     max_length (int): Maximum sequence length
 
-    # Format query with instruction
-    task_description = 'Re-rank a set of retrieved patents based on their relevance to a given query patent. The task aims to refine the order of patents by evaluating their semantic similarity to the query patent, ensuring that the most relevant patents appear at the top of the list.'
+#     Returns:
+#     list: Indices of documents sorted by relevance score (descending)
+#     """
+#     device = next(model.parameters()).device
+#     scores = []
+#
+#     # Format query with instruction
+#     task_description = 'Re-rank a set of retrieved patents based on their relevance to a given query patent. The task aims to refine the order of patents by evaluating their semantic similarity to the query patent, ensuring that the most relevant patents appear at the top of the list.'
 
-    instructed_query = get_detailed_instruct(task_description, query_text)
+#     instructed_query = get_detailed_instruct(task_description, query_text)
 
-    # Process in batches to avoid OOM
-    for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
-        batch_docs = doc_texts[i:i+batch_size]
+#     # Process in batches to avoid OOM
+#     for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
+#         batch_docs = doc_texts[i:i+batch_size]
 
-        # Prepare input pairs for the batch
-        input_texts = [instructed_query] + batch_docs
+#         # Prepare input pairs for the batch
+#         input_texts = [instructed_query] + batch_docs
 
-        # Tokenize
-        with torch.no_grad():
-            batch_dict = tokenizer(input_texts, max_length=max_length, padding=True,
-                                   truncation=True, return_tensors='pt').to(device)
+#         # Tokenize
+#         with torch.no_grad():
+#             batch_dict = tokenizer(input_texts, max_length=max_length, padding=True,
+#                                    truncation=True, return_tensors='pt').to(device)
 
-            # Get embeddings
-            outputs = model(**batch_dict)
-            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+#             # Get embeddings
+#             outputs = model(**batch_dict)
+#             embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
 
-            # Normalize embeddings
-            embeddings = F.normalize(embeddings, p=2, dim=1)
+#             # Normalize embeddings
+#             embeddings = F.normalize(embeddings, p=2, dim=1)
 
-            # Calculate similarity scores between query and documents
-            batch_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0) * 100
-            scores.extend(batch_scores.cpu().tolist())
+#             # Calculate similarity scores between query and documents
+#             batch_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0) * 100
+#             scores.extend(batch_scores.cpu().tolist())
 
-    # Create list of (index, score) tuples for sorting
-    indexed_scores = list(enumerate(scores))
+#     # Create list of (index, score) tuples for sorting
+#     indexed_scores = list(enumerate(scores))
 
-    # Sort by score in descending order
-    indexed_scores.sort(key=lambda x: x[1], reverse=True)
+#     # Sort by score in descending order
+#     indexed_scores.sort(key=lambda x: x[1], reverse=True)
 
-    # Return sorted indices
+#     # Return sorted indices
+#     return [idx for idx, _ in indexed_scores]
+
+def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=8, max_length=2048):
+    device = next(model.parameters()).device
+    cross_scores = []
+    query_emb = embed_text_list([query_text])[0]
+
+    instructed_query = get_detailed_instruct("", query_text)
+
+    for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
+        batch_docs = doc_texts[i:i+batch_size]
+
+        input_texts = [instructed_query] + batch_docs
+
+        with torch.no_grad():
+            batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(device)
+
+            outputs = model(**batch_dict)
+            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+            embeddings = F.normalize(embeddings, p=2, dim=1)
+
+            batch_cross_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0).cpu().numpy()
+            cross_scores.extend(batch_cross_scores)
+
+    # Semantic scores
+    doc_embeddings = embed_text_list(doc_texts)
+    semantic_scores = cosine_similarity([query_emb], doc_embeddings)[0]
+
+    # Hybrid scores
+    hybrid_scores = [hybrid_score(c, s) for c, s in zip(cross_scores, semantic_scores)]
+
+    indexed_scores = list(enumerate(hybrid_scores))
+    indexed_scores.sort(key=lambda x: x[1], reverse=True)
+
     return [idx for idx, _ in indexed_scores]
 
 def main():
     base_directory = os.getcwd()
+    base_directory += "/Patent_Retrieval"
     parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
     parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
                         help='Path to pre-ranking JSON file')
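Worth noting: the rewritten reranker drops the old *100 scaling, so the cross-encoder score and the bi-encoder semantic score are both plain cosine similarities in [-1, 1], and the 0.7/0.3 weights in hybrid_score act on comparable ranges. A worked example of the blend:

# hybrid_score(c, s) = 0.7*c + 0.3*s with the default weights
c, s = 0.82, 0.64          # cross-encoder and semantic cosine similarities
print(hybrid_score(c, s))  # 0.7*0.82 + 0.3*0.64 = 0.766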
@@ -241,7 +324,7 @@ def main():
     parser.add_argument('--queries_list', type=str, default='test_queries.json',
                         help='Path to training queries JSON file')
     parser.add_argument('--text_type', type=str, default='TA',
-                        choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart'],
+                        choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart', 'smart2'],
                         help='Type of text to use for scoring')
     parser.add_argument('--model_name', type=str, default='intfloat/e5-large-v2',
                         help='Name of the cross-encoder model')
@@ -252,7 +335,7 @@ def main():
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                         help='Device to use (cuda/cpu)')
     parser.add_argument('--base_dir', type=str,
-                        default=f'{base_directory}/Patent_Retrieval/datasets',
+                        default=f'{base_directory}/datasets',
                         help='Base directory for data files')
 
     args = parser.parse_args()
 
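For orientation, a minimal sketch of how the new pieces fit together outside main(). The AutoModel/AutoTokenizer wiring here is an assumption; main() builds the model and tokenizer from --model_name, which defaults to intfloat/e5-large-v2:

# Hypothetical wiring; the script's main() does the equivalent from CLI args.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-large-v2")
model = AutoModel.from_pretrained("intfloat/e5-large-v2").to(device)

ranking = cross_encoder_reranking(
    query_text="text of the query patent",
    doc_texts=["candidate patent A", "candidate patent B"],
    model=model,
    tokenizer=tokenizer,
)
print(ranking)  # candidate indices, best hybrid score first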