darpanaswal commited on
Commit
0bc2c49
·
verified ·
1 Parent(s): ba76ed7

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +2 -1
cross_encoder_reranking_train.py CHANGED
@@ -223,6 +223,7 @@ def cross_encoder_reranking(query_text, doc_texts, model, tokenizer, batch_size=
223
  return [idx for idx, _ in indexed_scores]
224
 
225
  def main():
 
226
  parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
227
  parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
228
  help='Path to pre-ranking JSON file')
@@ -248,7 +249,7 @@ def main():
248
  parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
249
  help='Device to use (cuda/cpu)')
250
  parser.add_argument('--base_dir', type=str,
251
- default='datasets',
252
  help='Base directory for data files')
253
 
254
  args = parser.parse_args()
 
223
  return [idx for idx, _ in indexed_scores]
224
 
225
  def main():
226
+ base_directory = os.getcwd()
227
  parser = argparse.ArgumentParser(description='Re-rank patents using cross-encoder scoring (training queries only)')
228
  parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
229
  help='Path to pre-ranking JSON file')
 
249
  parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
250
  help='Device to use (cuda/cpu)')
251
  parser.add_argument('--base_dir', type=str,
252
+ default=f'{base_directory}/datasets',
253
  help='Base directory for data files')
254
 
255
  args = parser.parse_args()