Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files
cross_encoder_reranking_train.py
CHANGED
@@ -223,6 +223,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
|
|
223 |
return [idx for idx, _ in indexed_scores]
|
224 |
|
225 |
def main():
|
|
|
226 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
227 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
228 |
help='Path to pre-ranking JSON file')
|
@@ -248,7 +249,7 @@ def main():
|
|
248 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
249 |
help='Device to use (cuda/cpu)')
|
250 |
parser.add_argument('--base_dir', type=str,
|
251 |
-
default='datasets',
|
252 |
help='Base directory for data files')
|
253 |
|
254 |
args = parser.parse_args()
|
|
|
223 |
return [idx for idx, _ in indexed_scores]
|
224 |
|
225 |
def main():
|
226 |
+
base_directory = os.getcwd()
|
227 |
parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
|
228 |
parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
|
229 |
help='Path to pre-ranking JSON file')
|
|
|
249 |
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
|
250 |
help='Device to use (cuda/cpu)')
|
251 |
parser.add_argument('--base_dir', type=str,
|
252 |
+
default=f'{base_directory}/datasets',
|
253 |
help='Base directory for data files')
|
254 |
|
255 |
args = parser.parse_args()
|