"""Patent_Retrieval/evaluate_train_rankings.py

Evaluate pre-ranking and re-ranking performance on the training queries.
"""
import os
import json
import argparse
import sys
# Import metrics directly from the local file
from metrics import mean_recall_at_k, mean_average_precision, mean_inv_ranking, mean_ranking
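
# Note on the metric helpers (inferred from how they are called in main(), not from
# metrics.py itself): each takes a list of gold label lists and a parallel list of
# predicted rankings, and mean_recall_at_k additionally accepts a keyword argument k.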


def load_json_file(file_path):
    """Load JSON data from a file."""
    with open(file_path, 'r') as f:
        return json.load(f)
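

# Expected input shape, an assumption inferred from the filtering and metric calls in
# main(): each ranking file and the gold mapping map a query FAN to an ordered list of
# document identifiers. The keys and values below are purely illustrative placeholders.
#
#   example_ranking = {
#       "QUERY_FAN_1": ["DOC_A", "DOC_B", "DOC_C"],
#       "QUERY_FAN_2": ["DOC_D", "DOC_A"],
#   }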


def main():
    base_directory = os.getcwd()

    parser = argparse.ArgumentParser(description='Evaluate document ranking performance on training data')
    parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
                        help='Path to pre-ranking JSON file')
    parser.add_argument('--re_ranking', type=str, default='predictions2.json',
                        help='Path to re-ranked JSON file')
    parser.add_argument('--gold', type=str, default='train_gold_mapping.json',
                        help='Path to gold standard mapping JSON file (training only)')
    parser.add_argument('--train_queries', type=str, default='train_queries.json',
                        help='Path to training queries JSON file')
    parser.add_argument('--k_values', type=str, default='3,5,10,20',
                        help='Comma-separated list of k values for Recall@k')
    parser.add_argument('--base_dir', type=str,
                        default=f'{base_directory}/Patent_Retrieval/datasets',
                        help='Base directory for data files')
    args = parser.parse_args()

    # Resolve paths against base_dir unless an absolute path is given
    def get_full_path(path):
        if os.path.isabs(path):
            return path
        return os.path.join(args.base_dir, path)

    # Load the training queries
    print("Loading training queries...")
    train_queries = load_json_file(get_full_path(args.train_queries))
    print(f"Loaded {len(train_queries)} training queries")

    # Load the ranking data and gold standard
    print("Loading ranking data and gold standard...")
    pre_ranking = load_json_file(get_full_path(args.pre_ranking))
    re_ranking = load_json_file(get_full_path(args.re_ranking))
    gold_mapping = load_json_file(get_full_path(args.gold))

    # Filter to include only training queries
    pre_ranking = {fan: docs for fan, docs in pre_ranking.items() if fan in train_queries}
    re_ranking = {fan: docs for fan, docs in re_ranking.items() if fan in train_queries}
    gold_mapping = {fan: docs for fan, docs in gold_mapping.items() if fan in train_queries}

    # Parse k values
    k_values = [int(k) for k in args.k_values.split(',')]

    # Prepare data for metrics calculation
    query_fans = set(gold_mapping.keys()) & set(pre_ranking.keys()) & set(re_ranking.keys())
    if not query_fans:
        print("Error: No common query FANs found across all datasets!")
        return

    print(f"Evaluating rankings for {len(query_fans)} training queries...")

    # Extract true and predicted labels for both rankings
    true_labels = [gold_mapping[fan] for fan in query_fans]
    pre_ranking_labels = [pre_ranking[fan] for fan in query_fans]
    re_ranking_labels = [re_ranking[fan] for fan in query_fans]

    # Calculate metrics for pre-ranking
    print("\nPre-ranking performance (training queries only):")
    for k in k_values:
        recall_at_k = mean_recall_at_k(true_labels, pre_ranking_labels, k=k)
        print(f" Recall@{k}: {recall_at_k:.4f}")
    map_score = mean_average_precision(true_labels, pre_ranking_labels)
    print(f" MAP: {map_score:.4f}")
    inv_rank = mean_inv_ranking(true_labels, pre_ranking_labels)
    print(f" Mean Inverse Rank: {inv_rank:.4f}")
    rank = mean_ranking(true_labels, pre_ranking_labels)
    print(f" Mean Rank: {rank:.2f}")

    # Calculate metrics for re-ranking
    print("\nRe-ranking performance (training queries only):")
    for k in k_values:
        recall_at_k = mean_recall_at_k(true_labels, re_ranking_labels, k=k)
        print(f" Recall@{k}: {recall_at_k:.4f}")
    map_score = mean_average_precision(true_labels, re_ranking_labels)
    print(f" MAP: {map_score:.4f}")
    inv_rank = mean_inv_ranking(true_labels, re_ranking_labels)
    print(f" Mean Inverse Rank: {inv_rank:.4f}")
    rank = mean_ranking(true_labels, re_ranking_labels)
    print(f" Mean Rank: {rank:.2f}")


if __name__ == "__main__":
    main()
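

# Example invocation (a sketch: the flags and default file names mirror the argparse setup
# above, and the data files are assumed to live under <cwd>/Patent_Retrieval/datasets unless
# --base_dir points elsewhere):
#
#   python evaluate_train_rankings.py \
#       --pre_ranking shuffled_pre_ranking.json \
#       --re_ranking predictions2.json \
#       --gold train_gold_mapping.json \
#       --train_queries train_queries.json \
#       --k_values 3,5,10,20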