darpanaswal committed on
Commit
10e48ed
·
verified ·
1 Parent(s): e68549b

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +13 -6
cross_encoder_reranking_train.py CHANGED
@@ -13,10 +13,13 @@ from sklearn.metrics.pairwise import cosine_similarity
13
 
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  # Load embedder once
16
- embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").to(device)
 
 
17
 
18
  def embed_text_list(texts):
19
- return embedder.encode(texts, convert_to_tensor=False, device=device)
 
20
 
21
  def rank_by_centrality(texts):
22
  embeddings = embed_text_list(texts)
@@ -45,9 +48,12 @@ def cluster_and_rank(texts, threshold=0.75):
45
  return representative_texts
46
 
47
  def process_single_patent(patent_dict):
48
- claims = [v for k, v in patent_dict.items() if k.startswith("c-en")]
49
- paragraphs = [v for k, v in patent_dict.items() if k.startswith("p")]
50
- features = [v for k, v in patent_dict.get("features", {}).items()]
 
 
 
51
 
52
  # Cluster & rank
53
  top_claims = cluster_and_rank(claims)
@@ -225,6 +231,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
225
 
226
  def main():
227
  base_directory = os.getcwd()
 
228
  parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
229
  parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
230
  help='Path to pre-ranking JSON file')
@@ -252,7 +259,7 @@ def main():
252
  parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
253
  help='Device to use (cuda/cpu)')
254
  parser.add_argument('--base_dir', type=str,
255
- default=f'{base_directory}/Patent_Retrieval/datasets',
256
  help='Base directory for data files')
257
 
258
  args = parser.parse_args()
 
13
 
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  # Load embedder once
16
+ # embedder = SentenceTransformer("all-MiniLM-L6-v2").to(device)
17
+ embedder = SentenceTransformer("intfloat/e5-large-v2").to(device)
18
+
19
 
20
  def embed_text_list(texts):
21
+ # return embedder.encode(texts, convert_to_tensor=False, device=device)
22
+ return embedder.encode(["query: your sentence here"], convert_to_tensor=False, device=device)
23
 
24
  def rank_by_centrality(texts):
25
  embeddings = embed_text_list(texts)
 
48
  return representative_texts
49
 
50
  def process_single_patent(patent_dict):
51
+ def filter_short_texts(texts, min_tokens=5):
52
+ return [text for text in texts if len(text.split()) >= min_tokens]
53
+
54
+ claims = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("c-en")])
55
+ paragraphs = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("p")])
56
+ features = filter_short_texts([v for k, v in patent_dict.get("features", {}).items()])
57
 
58
  # Cluster & rank
59
  top_claims = cluster_and_rank(claims)
 
231
 
232
  def main():
233
  base_directory = os.getcwd()
234
+ base_directory += "/Patent_Retrieval"
235
  parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
236
  parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
237
  help='Path to pre-ranking JSON file')
 
259
  parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
260
  help='Device to use (cuda/cpu)')
261
  parser.add_argument('--base_dir', type=str,
262
+ default=f'{base_directory}/datasets',
263
  help='Base directory for data files')
264
 
265
  args = parser.parse_args()