Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files- cross_encoder_reranking_train.py +131 -48
cross_encoder_reranking_train.py
CHANGED
@@ -13,8 +13,7 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
13 |
|
14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
15 |
# Load embedder once
|
16 |
-
embedder = SentenceTransformer("all-
|
17 |
-
embedder = embedder.to(device)
|
18 |
|
19 |
|
20 |
def embed_text_list(texts):
|
@@ -62,6 +61,28 @@ def process_single_patent(patent_dict):
|
|
62 |
"features": rank_by_centrality(top_features),
|
63 |
}
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
def load_json_file(file_path):
|
66 |
"""Load JSON data from a file"""
|
67 |
with open(file_path, 'r') as f:
|
@@ -153,6 +174,22 @@ def extract_text(content_dict, text_type="full"):
|
|
153 |
|
154 |
return " ".join(all_text)
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
return ""
|
158 |
|
@@ -166,67 +203,113 @@ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tenso
|
|
166 |
batch_size = last_hidden_states.shape[0]
|
167 |
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
168 |
|
|
|
|
|
|
|
|
|
169 |
def get_detailed_instruct(task_description: str, query: str) -> str:
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
172 |
|
173 |
-
def
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
device = next(model.parameters()).device
|
189 |
-
scores = []
|
190 |
|
191 |
-
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
-
|
195 |
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
|
200 |
-
|
201 |
-
|
202 |
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
|
212 |
-
|
213 |
-
|
214 |
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
|
219 |
-
|
220 |
-
|
221 |
|
222 |
-
|
223 |
-
|
224 |
|
225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
return [idx for idx, _ in indexed_scores]
|
227 |
|
228 |
def main():
|
229 |
base_directory = os.getcwd()
|
|
|
230 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
231 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
232 |
help='Path to pre-ranking JSON file')
|
@@ -241,7 +324,7 @@ def main():
|
|
241 |
parser.add_argument('--queries_list', type=str, default='test_queries.json',
|
242 |
help='Path to training queries JSON file')
|
243 |
parser.add_argument('--text_type', type=str, default='TA',
|
244 |
-
choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart'],
|
245 |
help='Type of text to use for scoring')
|
246 |
parser.add_argument('--model_name', type=str, default='intfloat/e5-large-v2',
|
247 |
help='Name of the cross-encoder model')
|
@@ -252,7 +335,7 @@ def main():
|
|
252 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
253 |
help='Device to use (cuda/cpu)')
|
254 |
parser.add_argument('--base_dir', type=str,
|
255 |
-
default=f'{base_directory}/
|
256 |
help='Base directory for data files')
|
257 |
|
258 |
args = parser.parse_args()
|
|
|
13 |
|
14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
15 |
# Load embedder once
|
16 |
+
embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").to(device)
|
|
|
17 |
|
18 |
|
19 |
def embed_text_list(texts):
|
|
|
61 |
"features": rank_by_centrality(top_features),
|
62 |
}
|
63 |
|
64 |
+
def refined_process_single_patent(patent_dict, top_n=10):
|
65 |
+
abstract = patent_dict.get("pa01", "")
|
66 |
+
title = patent_dict.get("title", "")
|
67 |
+
context = f"{title} {abstract}"
|
68 |
+
context_emb = embed_text_list([context])[0]
|
69 |
+
|
70 |
+
claims = [v for k, v in patent_dict.items() if k.startswith("c-en")]
|
71 |
+
paragraphs = [v for k, v in patent_dict.items() if k.startswith("p")]
|
72 |
+
features = [v for k, v in patent_dict.get("features", {}).items()]
|
73 |
+
|
74 |
+
def semantic_rank(items, context_emb):
|
75 |
+
embeddings = embed_text_list(items)
|
76 |
+
scores = cosine_similarity([context_emb], embeddings)[0]
|
77 |
+
ranked_items = [item for item, _ in sorted(zip(items, scores), key=lambda x: x[1], reverse=True)]
|
78 |
+
return ranked_items
|
79 |
+
|
80 |
+
return {
|
81 |
+
"claims": semantic_rank(claims, context_emb)[:top_n],
|
82 |
+
"paragraphs": semantic_rank(paragraphs, context_emb)[:top_n],
|
83 |
+
"features": semantic_rank(features, context_emb)[:top_n],
|
84 |
+
}
|
85 |
+
|
86 |
def load_json_file(file_path):
|
87 |
"""Load JSON data from a file"""
|
88 |
with open(file_path, 'r') as f:
|
|
|
174 |
|
175 |
return " ".join(all_text)
|
176 |
|
177 |
+
elif text_type == "smart2":
|
178 |
+
filtered_dict = refined_process_single_patent(content_dict)
|
179 |
+
all_text = []
|
180 |
+
# Context with title and abstract
|
181 |
+
if "title" in content_dict:
|
182 |
+
all_text.append(content_dict["title"])
|
183 |
+
if "pa01" in content_dict:
|
184 |
+
all_text.append(content_dict["pa01"])
|
185 |
+
|
186 |
+
# Add claims, paragraphs, and features
|
187 |
+
all_text.extend(filtered_dict["claims"])
|
188 |
+
all_text.extend(filtered_dict["paragraphs"])
|
189 |
+
all_text.extend(filtered_dict["features"])
|
190 |
+
|
191 |
+
return " ".join(all_text)
|
192 |
+
|
193 |
|
194 |
return ""
|
195 |
|
|
|
203 |
batch_size = last_hidden_states.shape[0]
|
204 |
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
205 |
|
206 |
+
# def get_detailed_instruct(task_description: str, query: str) -> str:
|
207 |
+
# """Create an instruction-formatted query"""
|
208 |
+
# return f'Instruct: {task_description}\nQuery: {query}'
|
209 |
+
|
210 |
def get_detailed_instruct(task_description: str, query: str) -> str:
|
211 |
+
return (
|
212 |
+
f"Instruct: Evaluate the semantic and technical similarity between two patent documents."
|
213 |
+
f" Prioritize highly similar claims, technical implementations, and shared functionalities."
|
214 |
+
f"\nQuery: {query}"
|
215 |
+
)
|
216 |
|
217 |
+
def hybrid_score(cross_encoder_score, semantic_score, weight_cross=0.7, weight_semantic=0.3):
|
218 |
+
return (weight_cross * cross_encoder_score) + (weight_semantic * semantic_score)
|
219 |
+
|
220 |
+
|
221 |
+
# def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=8, max_length=2048):
|
222 |
+
# """
|
223 |
+
# Rerank document texts based on query text using cross-encoder model
|
224 |
+
|
225 |
+
# Parameters:
|
226 |
+
# query_text (str): The query text
|
227 |
+
# doc_texts (list): List of document texts
|
228 |
+
# model: The cross-encoder model
|
229 |
+
# tokenizer: The tokenizer for the model
|
230 |
+
# batch_size (int): Batch size for processing
|
231 |
+
# max_length (int): Maximum sequence length
|
|
|
|
|
232 |
|
233 |
+
# Returns:
|
234 |
+
# list: Indices of documents sorted by relevance score (descending)
|
235 |
+
# """
|
236 |
+
# device = next(model.parameters()).device
|
237 |
+
# scores = []
|
238 |
+
|
239 |
+
# # Format query with instruction
|
240 |
+
# task_description = 'Re-rank a set of retrieved patents based on their relevance to a given query patent. The task aims to refine the order of patents by evaluating their semantic similarity to the query patent, ensuring that the most relevant patents appear at the top of the list.'
|
241 |
|
242 |
+
# instructed_query = get_detailed_instruct(task_description, query_text)
|
243 |
|
244 |
+
# # Process in batches to avoid OOM
|
245 |
+
# for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
|
246 |
+
# batch_docs = doc_texts[i:i+batch_size]
|
247 |
|
248 |
+
# # Prepare input pairs for the batch
|
249 |
+
# input_texts = [instructed_query] + batch_docs
|
250 |
|
251 |
+
# # Tokenize
|
252 |
+
# with torch.no_grad():
|
253 |
+
# batch_dict = tokenizer(input_texts, max_length=max_length, padding=True,
|
254 |
+
# truncation=True, return_tensors='pt').to(device)
|
255 |
|
256 |
+
# # Get embeddings
|
257 |
+
# outputs = model(**batch_dict)
|
258 |
+
# embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
259 |
|
260 |
+
# # Normalize embeddings
|
261 |
+
# embeddings = F.normalize(embeddings, p=2, dim=1)
|
262 |
|
263 |
+
# # Calculate similarity scores between query and documents
|
264 |
+
# batch_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0) * 100
|
265 |
+
# scores.extend(batch_scores.cpu().tolist())
|
266 |
|
267 |
+
# # Create list of (index, score) tuples for sorting
|
268 |
+
# indexed_scores = list(enumerate(scores))
|
269 |
|
270 |
+
# # Sort by score in descending order
|
271 |
+
# indexed_scores.sort(key=lambda x: x[1], reverse=True)
|
272 |
|
273 |
+
# # Return sorted indices
|
274 |
+
# return [idx for idx, _ in indexed_scores]
|
275 |
+
|
276 |
+
def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=8, max_length=2048):
|
277 |
+
device = next(model.parameters()).device
|
278 |
+
cross_scores = []
|
279 |
+
query_emb = embed_text_list([query_text])[0]
|
280 |
+
|
281 |
+
instructed_query = get_detailed_instruct("", query_text)
|
282 |
+
|
283 |
+
for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
|
284 |
+
batch_docs = doc_texts[i:i+batch_size]
|
285 |
+
|
286 |
+
input_texts = [instructed_query] + batch_docs
|
287 |
+
|
288 |
+
with torch.no_grad():
|
289 |
+
batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt').to(device)
|
290 |
+
|
291 |
+
outputs = model(**batch_dict)
|
292 |
+
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
293 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
294 |
+
|
295 |
+
batch_cross_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0).cpu().numpy()
|
296 |
+
cross_scores.extend(batch_cross_scores)
|
297 |
+
|
298 |
+
# Semantic scores
|
299 |
+
doc_embeddings = embed_text_list(doc_texts)
|
300 |
+
semantic_scores = cosine_similarity([query_emb], doc_embeddings)[0]
|
301 |
+
|
302 |
+
# Hybrid scores
|
303 |
+
hybrid_scores = [hybrid_score(c, s) for c, s in zip(cross_scores, semantic_scores)]
|
304 |
+
|
305 |
+
indexed_scores = list(enumerate(hybrid_scores))
|
306 |
+
indexed_scores.sort(key=lambda x: x[1], reverse=True)
|
307 |
+
|
308 |
return [idx for idx, _ in indexed_scores]
|
309 |
|
310 |
def main():
|
311 |
base_directory = os.getcwd()
|
312 |
+
base_directory += "/Patent_Retrieval"
|
313 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
314 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
315 |
help='Path to pre-ranking JSON file')
|
|
|
324 |
parser.add_argument('--queries_list', type=str, default='test_queries.json',
|
325 |
help='Path to training queries JSON file')
|
326 |
parser.add_argument('--text_type', type=str, default='TA',
|
327 |
+
choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart', 'smart2'],
|
328 |
help='Type of text to use for scoring')
|
329 |
parser.add_argument('--model_name', type=str, default='intfloat/e5-large-v2',
|
330 |
help='Name of the cross-encoder model')
|
|
|
335 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
336 |
help='Device to use (cuda/cpu)')
|
337 |
parser.add_argument('--base_dir', type=str,
|
338 |
+
default=f'{base_directory}/datasets',
|
339 |
help='Base directory for data files')
|
340 |
|
341 |
args = parser.parse_args()
|