import os
import json
import torch
import argparse
import numpy as np
from tqdm import tqdm
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from sklearn.cluster import AgglomerativeClustering
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the embedder once at module level
embedder = SentenceTransformer("all-MiniLM-L6-v2").to(device)
# embedder = SentenceTransformer("AI-Growth-Lab/PatentSBERTa").to(device)


def embed_text_list(texts):
    """Embed a list of strings; returns one vector per input text."""
    return embedder.encode(texts, convert_to_tensor=False, device=device)

# Alternative for E5-style embedders:
# def embed_text_list(texts):
#     # E5 models expect a "query: " prefix for proper embedding behavior
#     formatted_texts = [f"query: {text}" for text in texts]
#     return embedder.encode(formatted_texts, convert_to_tensor=False, device=device)
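
# Illustrative usage (not part of the pipeline): all-MiniLM-L6-v2 produces one
# 384-dimensional row per input string.
#   vecs = embed_text_list(["A method for cooling a battery pack.",
#                           "An apparatus for thermal management."])
#   # vecs.shape == (2, 384)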

def rank_by_centrality(texts):
    """Order texts by centrality: mean cosine similarity to all texts in the list."""
    if not texts:
        return texts
    embeddings = embed_text_list(texts)
    similarity_matrix = cosine_similarity(embeddings)
    centrality_scores = similarity_matrix.mean(axis=1)
    ranked = sorted(zip(texts, centrality_scores), key=lambda x: x[1], reverse=True)
    return [text for text, _ in ranked]
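
# Example (illustrative): given three texts where two paraphrase each other,
# the paraphrased pair has a higher mean similarity and therefore ranks first.
#   rank_by_centrality(["A battery cooling plate.",
#                       "A cooling plate for a battery.",
#                       "A method of roasting coffee."])
#   # -> the two cooling-plate sentences come before the coffee one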

def cluster_and_rank(texts, threshold=0.75):
    """Cluster near-duplicate texts and keep only the most central text per cluster."""
    if len(texts) < 2:
        return texts
    embeddings = embed_text_list(texts)
    clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=1 - threshold,
                                         metric="cosine", linkage="average")
    labels = clustering.fit_predict(embeddings)
    clustered_texts = {}
    for label, text in zip(labels, texts):
        clustered_texts.setdefault(label, []).append(text)
    representative_texts = []
    for cluster_texts in clustered_texts.values():
        ranked = rank_by_centrality(cluster_texts)
        representative_texts.append(ranked[0])  # Choose the most central text per cluster
    return representative_texts
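
# Note: threshold is a cosine *similarity* cutoff, while AgglomerativeClustering
# works with cosine *distance*, hence distance_threshold = 1 - threshold.
# E.g. threshold=0.75 merges texts whose average-linkage cosine distance stays
# below 0.25 (i.e. similarity above 0.75).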

def process_single_patent(patent_dict):
    """Deduplicate (cluster) and centrality-rank a patent's claims, paragraphs, and features."""
    def filter_short_texts(texts, min_tokens=5):
        return [text for text in texts if len(text.split()) >= min_tokens]

    claims = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("c-en")])
    paragraphs = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("p")])
    features = filter_short_texts(list(patent_dict.get("features", {}).values()))
    # Cluster & rank
    top_claims = cluster_and_rank(claims)
    top_paragraphs = cluster_and_rank(paragraphs)
    top_features = cluster_and_rank(features)
    return {
        "claims": rank_by_centrality(top_claims),
        "paragraphs": rank_by_centrality(top_paragraphs),
        "features": rank_by_centrality(top_features),
    }
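
# The returned dict (illustrative shape) maps each field to a centrality-ranked list:
#   {
#       "claims":     ["<most central claim>", ...],
#       "paragraphs": ["<most central paragraph>", ...],
#       "features":   ["<most central feature>", ...],
#   }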

def process_single_patent2(patent_dict):
    """Variant of process_single_patent: rank claims/features directly; cluster only paragraphs."""
    def filter_short_texts(texts, min_tokens=5):
        return [text for text in texts if len(text.split()) >= min_tokens]

    # Filter out short texts
    claims = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("c-en")])
    paragraphs = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("p")])
    features = filter_short_texts(list(patent_dict.get("features", {}).values()))
    # Rank claims and features directly by centrality
    ranked_claims = rank_by_centrality(claims)
    ranked_features = rank_by_centrality(features)
    # Deduplicate (cluster + rank) only the paragraphs
    filtered_paragraphs = cluster_and_rank(paragraphs)
    ranked_paragraphs = rank_by_centrality(filtered_paragraphs)
    return {
        "claims": ranked_claims,
        "paragraphs": ranked_paragraphs,
        "features": ranked_features,
    }

def load_json_file(file_path):
    """Load JSON data from a file"""
    with open(file_path, 'r') as f:
        return json.load(f)


def save_json_file(data, file_path):
    """Save data to a JSON file"""
    with open(file_path, 'w') as f:
        json.dump(data, f, indent=2)


def load_content_data(file_path):
    """Load patent content from a JSON file, mapping each FAN to its Content"""
    with open(file_path, 'r') as f:
        data = json.load(f)
    # Create a dictionary mapping FAN to Content
    content_dict = {item['FAN']: item['Content'] for item in data}
    return content_dict
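
# Expected input shape (inferred from the keys used above, values illustrative):
#   [
#       {"FAN": "<patent family identifier>",
#        "Content": {"title": "...", "pa01": "...", "c-en-0001": "...", ...}},
#       ...
#   ]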

def extract_text(content_dict, text_type="full"):
    """Extract text from patent content based on text_type"""
    if text_type == "TA" or text_type == "title_abstract":
        # Title and abstract
        title = content_dict.get("title", "")
        abstract = content_dict.get("pa01", "")
        return f"{title} {abstract}".strip()
    elif text_type == "claims":
        # All claims (keys starting with 'c-')
        claims = []
        for key, value in content_dict.items():
            if key.startswith('c-'):
                claims.append(value)
        return " ".join(claims)
    elif text_type == "claimfeat":
        # All claims plus the extracted features
        content = []
        for key, value in content_dict.items():
            if key.startswith('c-'):
                content.append(value)
            if key == "features":
                content += list(content_dict[key].values())
        return " ".join(content)
    elif text_type == "feat":
        # Extracted features only
        content = []
        for key, value in content_dict.items():
            if key == "features":
                content += list(content_dict[key].values())
        return " ".join(content)
    elif text_type == "tac1":
        # Title, abstract, and first claim
        title = content_dict.get("title", "")
        abstract = content_dict.get("pa01", "")
        # Find the first claim safely
        first_claim = ""
        for key, value in content_dict.items():
            if key.startswith('c-'):
                first_claim = value
                break
        return f"{title} {abstract} {first_claim}".strip()
    elif text_type == "description":
        # All description paragraphs (keys starting with 'p'; note this also
        # matches the abstract key 'pa01')
        paragraphs = []
        for key, value in content_dict.items():
            if key.startswith('p'):
                paragraphs.append(value)
        return " ".join(paragraphs)
    elif text_type == "full":
        # Everything except title and abstract
        all_text = []
        # Start with title and abstract for better context at the beginning
        # if "title" in content_dict:
        #     all_text.append(content_dict["title"])
        # if "pa01" in content_dict:
        #     all_text.append(content_dict["pa01"])
        for key, value in content_dict.items():
            if key == "title" or key == "pa01":
                continue
            if isinstance(value, dict):
                # e.g. the "features" entry is a dict of feature texts
                all_text.extend(value.values())
            else:
                all_text.append(value)
        return " ".join(all_text)
    elif text_type == "smart":
        filtered_dict = process_single_patent(content_dict)
        all_text = []
        # Start with the abstract for better context at the beginning
        if "pa01" in content_dict:
            all_text.append(content_dict["pa01"])
        # Keep only the top-10 most central paragraphs
        # for claim in filtered_dict["claims"][:10]:
        #     all_text.append(claim)
        # for feature in filtered_dict["features"][:10]:
        #     all_text.append(feature)
        for paragraph in filtered_dict["paragraphs"][:10]:
            all_text.append(paragraph)
        return " ".join(all_text)
    elif text_type == "smart2":
        filtered_dict = process_single_patent2(content_dict)
        all_text = []
        # Start with the abstract for better context at the beginning
        if "pa01" in content_dict:
            all_text.append(content_dict["pa01"])
        # Keep only the top-10 most central claims and features
        for claim in filtered_dict["claims"][:10]:
            all_text.append(claim)
        for feature in filtered_dict["features"][:10]:
            all_text.append(feature)
        # for paragraph in filtered_dict["paragraphs"][:10]:
        #     all_text.append(paragraph)
        return " ".join(all_text)
    return ""

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Pool each sequence by taking its last non-padding token representation"""
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        # With left padding, the last position is always a real token
        return last_hidden_states[:, -1]
    else:
        # With right padding, index the last attended position per sequence
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Create an instruction-formatted query"""
    return f'Instruct: {task_description}\nQuery: {query}'
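
# The resulting string looks like:
#   Instruct: <task_description>
#   Query: <query>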

def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=64, max_length=2048):
    """
    Rerank document texts against a query text. Despite the function name, scoring
    is bi-encoder style: query and documents are embedded separately and compared
    by cosine similarity (the query is re-embedded with each batch).

    Parameters:
        query_text (str): The query text
        doc_texts (list): List of document texts
        model: The embedding model used for scoring
        tokenizer: The tokenizer for the model
        batch_size (int): Batch size for processing
        max_length (int): Maximum sequence length

    Returns:
        list: Indices of documents sorted by relevance score (descending)
    """
    device = next(model.parameters()).device
    scores = []
    # Format the query with a task instruction
    task_description = ('Re-rank a set of retrieved patents based on their relevance to a given query patent. '
                        'The task aims to refine the order of patents by evaluating their semantic similarity '
                        'to the query patent, ensuring that the most relevant patents appear at the top of the list.')
    instructed_query = get_detailed_instruct(task_description, query_text)
    # Process in batches to avoid running out of memory
    for i in tqdm(range(0, len(doc_texts), batch_size), desc="Scoring documents", leave=False):
        batch_docs = doc_texts[i:i + batch_size]
        # The query is embedded together with each batch of documents
        input_texts = [instructed_query] + batch_docs
        with torch.no_grad():
            batch_dict = tokenizer(input_texts, max_length=max_length, padding=True,
                                   truncation=True, return_tensors='pt').to(device)
            outputs = model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            # Normalize embeddings so the dot product equals cosine similarity
            embeddings = F.normalize(embeddings, p=2, dim=1)
            # Similarity between the query (row 0) and each document in the batch
            batch_scores = (embeddings[0].unsqueeze(0) @ embeddings[1:].T).squeeze(0) * 100
            scores.extend(batch_scores.cpu().tolist())
    # Sort document indices by score in descending order
    indexed_scores = list(enumerate(scores))
    indexed_scores.sort(key=lambda x: x[1], reverse=True)
    return [idx for idx, _ in indexed_scores]
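
# Typical use (mirrors main() below): map the returned indices back to document IDs.
#   order = cross_encoder_reranking(query_text, doc_texts, model, tokenizer)
#   reranked_ids = [doc_fans[i] for i in order]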

def main():
    base_directory = os.path.join(os.getcwd(), "Patent_Retrieval")
    parser = argparse.ArgumentParser(description='Re-rank patents with embedding-based similarity scoring for the queries in --queries_list')
    parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
                        help='Path to pre-ranking JSON file')
    parser.add_argument('--output', type=str, default='prediction2.json',
                        help='Path to output re-ranked JSON file')
    parser.add_argument('--queries_content', type=str,
                        default='./queries_content_with_features.json',
                        help='Path to queries content JSON file')
    parser.add_argument('--documents_content', type=str,
                        default='./documents_content_with_features.json',
                        help='Path to documents content JSON file')
    # Change here for test or train
    parser.add_argument('--queries_list', type=str, default='test_queries.json',
                        help='Path to queries JSON file (train or test)')
    parser.add_argument('--text_type', type=str, default='TA',
                        choices=['TA', 'claims', 'description', 'full', 'tac1', 'smart', 'smart2', 'claimfeat', 'feat'],
                        help='Type of text to use for scoring')
    parser.add_argument('--model_name', type=str, default='intfloat/e5-large-v2',
                        help='Name of the embedding model used for re-ranking')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size for scoring')
    parser.add_argument('--max_length', type=int, default=512,
                        help='Maximum sequence length')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--base_dir', type=str,
                        default=f'{base_directory}/datasets',
                        help='Base directory for data files')
    args = parser.parse_args()

    # Resolve paths relative to base_dir unless they are absolute
    def get_full_path(path):
        if os.path.isabs(path):
            return path
        return os.path.join(args.base_dir, path)

    # Load the list of query FANs
    print(f"Loading queries from {args.queries_list}...")
    queries_list = load_json_file(get_full_path(args.queries_list))
    print(f"Loaded {len(queries_list)} queries")

    # Load pre-ranking data
    print(f"Loading pre-ranking data from {args.pre_ranking}...")
    pre_ranking = load_json_file(get_full_path(args.pre_ranking))

    # Keep only the queries in the provided list
    pre_ranking = {fan: docs for fan, docs in pre_ranking.items() if fan in queries_list}
    print(f"Filtered pre-ranking to {len(pre_ranking)} queries")

    # Load content data
    print(f"Loading query content from {args.queries_content}...")
    queries_content = load_content_data(get_full_path(args.queries_content))
    print(f"Loading document content from {args.documents_content}...")
    documents_content = load_content_data(get_full_path(args.documents_content))

    # Load model and tokenizer
    print(f"Loading model {args.model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModel.from_pretrained(args.model_name).to(args.device)
    model.eval()

    # Process each query and re-rank its documents
    print("Starting re-ranking process...")
    re_ranked = {}
    missing_query_fans = []
    missing_doc_fans = {}
    for query_fan, pre_ranked_docs in tqdm(pre_ranking.items(), desc="Processing queries"):
        # Skip queries whose FAN is absent from the content data
        if query_fan not in queries_content:
            missing_query_fans.append(query_fan)
            continue
        # Extract the query text
        query_text = extract_text(queries_content[query_fan], args.text_type)
        if not query_text:
            missing_query_fans.append(query_fan)
            continue
        # Prepare document texts and keep track of their FANs
        doc_texts = []
        doc_fans = []
        missing_docs_for_query = []
        for doc_fan in pre_ranked_docs:
            if doc_fan not in documents_content:
                missing_docs_for_query.append(doc_fan)
                continue
            doc_text = extract_text(documents_content[doc_fan], args.text_type)
            if doc_text:
                doc_texts.append(doc_text)
                doc_fans.append(doc_fan)
        # Keep track of missing documents
        if missing_docs_for_query:
            missing_doc_fans[query_fan] = missing_docs_for_query
        # Skip if no valid documents remain
        if not doc_texts:
            re_ranked[query_fan] = []
            continue
        # Re-rank the documents
        print(f"\nRe-ranking {len(doc_texts)} documents for query {query_fan}")
        # Print part of the original pre-ranking order for debugging
        print(f"Original pre-ranking (first 3): {doc_fans[:3]}")
        sorted_indices = cross_encoder_reranking(
            query_text, doc_texts, model, tokenizer,
            batch_size=args.batch_size, max_length=args.max_length
        )
        re_ranked[query_fan] = [doc_fans[i] for i in sorted_indices]

    # Report any missing FANs
    if missing_query_fans:
        print(f"Warning: {len(missing_query_fans)} query FANs were not found in the content data")
    if missing_doc_fans:
        total_missing = sum(len(docs) for docs in missing_doc_fans.values())
        print(f"Warning: {total_missing} document FANs were not found in the content data")

    # Save re-ranked results
    output_path = get_full_path(args.output)
    print(f"Saving re-ranked results to {output_path}...")
    save_json_file(re_ranked, output_path)
    print("Re-ranking complete!")
    print(f"Number of queries processed: {len(re_ranked)}")

    # Optionally save information about missing FANs for debugging
    if missing_query_fans or missing_doc_fans:
        missing_info = {
            "missing_query_fans": missing_query_fans,
            "missing_doc_fans": missing_doc_fans
        }
        missing_info_path = f"{os.path.splitext(output_path)[0]}_missing_fans.json"
        save_json_file(missing_info, missing_info_path)
        print(f"Information about missing FANs saved to {missing_info_path}")


if __name__ == "__main__":
    main()
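
# Example invocation (script name is a placeholder; all flags are defined above):
#   python rerank_patents.py --queries_list test_queries.json \
#       --text_type smart2 --batch_size 4 --max_length 512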