import os
import json
import argparse
import sys

# Import metrics directly from the local file
from metrics import mean_recall_at_k, mean_average_precision, mean_inv_ranking, mean_ranking
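# Assumed interface of the local metrics module (not shown here): each function
# takes parallel lists of gold document lists and predicted ranking lists and
# returns a single float averaged over queries:
#   mean_recall_at_k       - fraction of gold documents found in the top-k results
#   mean_average_precision - mean average precision (MAP) over the full ranking
#   mean_inv_ranking       - mean inverse (reciprocal) rank of the first relevant hit
#   mean_ranking           - mean position of the first relevant hit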
def load_json_file(file_path):
    """Load JSON data from a file."""
    with open(file_path, 'r') as f:
        return json.load(f)

def main():
    base_directory = os.getcwd()
    parser = argparse.ArgumentParser(description='Evaluate document ranking performance on training data')
    parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
                        help='Path to pre-ranking JSON file')
    parser.add_argument('--re_ranking', type=str, default='predictions2.json',
                        help='Path to re-ranked JSON file')
    parser.add_argument('--gold', type=str, default='train_gold_mapping.json',
                        help='Path to gold standard mapping JSON file (training only)')
    parser.add_argument('--train_queries', type=str, default='train_queries.json',
                        help='Path to training queries JSON file')
    parser.add_argument('--k_values', type=str, default='3,5,10,20',
                        help='Comma-separated list of k values for Recall@k')
    parser.add_argument('--base_dir', type=str,
                        default=f'{base_directory}/Patent_Retrieval/datasets',
                        help='Base directory for data files')
    args = parser.parse_args()

    # Resolve relative paths against base_dir; absolute paths are used as given
    def get_full_path(path):
        if os.path.isabs(path):
            return path
        return os.path.join(args.base_dir, path)

    # Load the training queries
    print("Loading training queries...")
    train_queries = load_json_file(get_full_path(args.train_queries))
    print(f"Loaded {len(train_queries)} training queries")

    # Load the ranking data and gold standard
    print("Loading ranking data and gold standard...")
    pre_ranking = load_json_file(get_full_path(args.pre_ranking))
    re_ranking = load_json_file(get_full_path(args.re_ranking))
    gold_mapping = load_json_file(get_full_path(args.gold))
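    # All three files are assumed to share the same shape: a JSON object mapping
    # each query FAN to an ordered list of candidate document FANs, e.g.
    # {"QUERY_FAN": ["DOC_FAN_1", "DOC_FAN_2", ...]} (the keys shown are placeholders).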
    # Filter to include only training queries
    pre_ranking = {fan: docs for fan, docs in pre_ranking.items() if fan in train_queries}
    re_ranking = {fan: docs for fan, docs in re_ranking.items() if fan in train_queries}
    gold_mapping = {fan: docs for fan, docs in gold_mapping.items() if fan in train_queries}
    # Parse k values
    k_values = [int(k) for k in args.k_values.split(',')]

    # Prepare data for metrics calculation: evaluate only queries present in all three mappings
    query_fans = set(gold_mapping.keys()) & set(pre_ranking.keys()) & set(re_ranking.keys())
    if not query_fans:
        print("Error: No common query FANs found across all datasets!")
        return

    print(f"Evaluating rankings for {len(query_fans)} training queries...")

    # Extract true and predicted labels for both rankings
    true_labels = [gold_mapping[fan] for fan in query_fans]
    pre_ranking_labels = [pre_ranking[fan] for fan in query_fans]
    re_ranking_labels = [re_ranking[fan] for fan in query_fans]
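    # Note: the three comprehensions above iterate the same unmodified set object,
    # so the gold and predicted lists are assumed to stay index-aligned per query.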
    # Calculate metrics for pre-ranking
    print("\nPre-ranking performance (training queries only):")
    for k in k_values:
        recall_at_k = mean_recall_at_k(true_labels, pre_ranking_labels, k=k)
        print(f"  Recall@{k}: {recall_at_k:.4f}")
    map_score = mean_average_precision(true_labels, pre_ranking_labels)
    print(f"  MAP: {map_score:.4f}")
    inv_rank = mean_inv_ranking(true_labels, pre_ranking_labels)
    print(f"  Mean Inverse Rank: {inv_rank:.4f}")
    rank = mean_ranking(true_labels, pre_ranking_labels)
    print(f"  Mean Rank: {rank:.2f}")

    # Calculate metrics for re-ranking
    print("\nRe-ranking performance (training queries only):")
    for k in k_values:
        recall_at_k = mean_recall_at_k(true_labels, re_ranking_labels, k=k)
        print(f"  Recall@{k}: {recall_at_k:.4f}")
    map_score = mean_average_precision(true_labels, re_ranking_labels)
    print(f"  MAP: {map_score:.4f}")
    inv_rank = mean_inv_ranking(true_labels, re_ranking_labels)
    print(f"  Mean Inverse Rank: {inv_rank:.4f}")
    rank = mean_ranking(true_labels, re_ranking_labels)
    print(f"  Mean Rank: {rank:.2f}")

if __name__ == "__main__":
    main()
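# Example invocation (the script file name and paths below are illustrative only):
#   python evaluate_train_rankings.py --base_dir ./Patent_Retrieval/datasets \
#       --pre_ranking shuffled_pre_ranking.json --re_ranking predictions2.json \
#       --gold train_gold_mapping.json --train_queries train_queries.json \
#       --k_values 3,5,10,20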