Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files
cross_encoder_reranking_train.py
CHANGED
@@ -13,10 +13,13 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
13 |
|
14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
15 |
# Load embedder once
|
16 |
-
embedder = SentenceTransformer("
|
|
|
|
|
17 |
|
18 |
def embed_text_list(texts):
|
19 |
-
return embedder.encode(texts, convert_to_tensor=False, device=device)
|
|
|
20 |
|
21 |
def rank_by_centrality(texts):
|
22 |
embeddings = embed_text_list(texts)
|
@@ -45,9 +48,12 @@ def cluster_and_rank(texts, threshold=0.75):
|
|
45 |
return representative_texts
|
46 |
|
47 |
def process_single_patent(patent_dict):
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
51 |
|
52 |
# Cluster & rank
|
53 |
top_claims = cluster_and_rank(claims)
|
@@ -225,6 +231,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
|
|
225 |
|
226 |
def main():
|
227 |
base_directory = os.getcwd()
|
|
|
228 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
229 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
230 |
help='Path to pre-ranking JSON file')
|
@@ -252,7 +259,7 @@ def main():
|
|
252 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
253 |
help='Device to use (cuda/cpu)')
|
254 |
parser.add_argument('--base_dir', type=str,
|
255 |
-
default=f'{base_directory}/
|
256 |
help='Base directory for data files')
|
257 |
|
258 |
args = parser.parse_args()
|
|
|
13 |
|
14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
15 |
# Load embedder once
|
16 |
+
# embedder = SentenceTransformer("all-MiniLM-L6-v2").to(device)
|
17 |
+
embedder = SentenceTransformer("intfloat/e5-large-v2").to(device)
|
18 |
+
|
19 |
|
20 |
def embed_text_list(texts):
    """Embed every text in ``texts`` with the module-level ``embedder``.

    The configured embedder is an E5 model, which expects each input to
    carry a task prefix. The "query: " prefix is applied to every text
    because the embeddings are compared against each other (for example in
    ``rank_by_centrality``) rather than used for query/passage retrieval.

    Args:
        texts: Iterable of strings to embed.

    Returns:
        The result of ``embedder.encode`` with ``convert_to_tensor=False``:
        one embedding per input text, in input order.
    """
    # The earlier body encoded a hard-coded placeholder sentence and ignored
    # ``texts``, so callers received a single embedding that had nothing to
    # do with their input. Encode the real inputs instead.
    prefixed_texts = [f"query: {text}" for text in texts]
    return embedder.encode(prefixed_texts, convert_to_tensor=False, device=device)
|
23 |
|
24 |
def rank_by_centrality(texts):
|
25 |
embeddings = embed_text_list(texts)
|
|
|
48 |
return representative_texts
|
49 |
|
50 |
def process_single_patent(patent_dict):
|
51 |
+
def filter_short_texts(texts, min_tokens=5):
|
52 |
+
return [text for text in texts if len(text.split()) >= min_tokens]
|
53 |
+
|
54 |
+
claims = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("c-en")])
|
55 |
+
paragraphs = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("p")])
|
56 |
+
features = filter_short_texts([v for k, v in patent_dict.get("features", {}).items()])
|
57 |
|
58 |
# Cluster & rank
|
59 |
top_claims = cluster_and_rank(claims)
|
|
|
231 |
|
232 |
def main():
|
233 |
base_directory = os.getcwd()
|
234 |
+
base_directory += "/Patent_Retrieval"
|
235 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
236 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
237 |
help='Path to pre-ranking JSON file')
|
|
|
259 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
260 |
help='Device to use (cuda/cpu)')
|
261 |
parser.add_argument('--base_dir', type=str,
|
262 |
+
default=f'{base_directory}/datasets',
|
263 |
help='Base directory for data files')
|
264 |
|
265 |
args = parser.parse_args()
|