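"""Evaluate pre-ranking vs. re-ranking quality on the training queries.

Loads a pre-ranking, a re-ranked prediction file, and the gold mapping,
restricts all of them to the training queries, and prints Recall@k, MAP,
mean inverse rank, and mean rank for both rankings.

Example invocation (the script name here is illustrative; the file names
follow the argument defaults below):

    python evaluate_train_ranking.py \
        --pre_ranking shuffled_pre_ranking.json \
        --re_ranking predictions2.json \
        --gold train_gold_mapping.json \
        --train_queries train_queries.json \
        --k_values 3,5,10,20
"""
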
import os
import json
import argparse
import sys

# Import metrics directly from the local file
from metrics import mean_recall_at_k, mean_average_precision, mean_inv_ranking, mean_ranking

def load_json_file(file_path):
    """Load and return JSON data from a file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def main():
    base_directory = os.getcwd()
    parser = argparse.ArgumentParser(description='Evaluate document ranking performance on training data')
    parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
                        help='Path to pre-ranking JSON file')
    parser.add_argument('--re_ranking', type=str, default='predictions2.json',
                        help='Path to re-ranked JSON file')
    parser.add_argument('--gold', type=str, default='train_gold_mapping.json',
                        help='Path to gold standard mapping JSON file (training only)')
    parser.add_argument('--train_queries', type=str, default='train_queries.json',
                        help='Path to training queries JSON file')
    parser.add_argument('--k_values', type=str, default='3,5,10,20',
                        help='Comma-separated list of k values for Recall@k')
    parser.add_argument('--base_dir', type=str,
                        default=os.path.join(base_directory, 'Patent_Retrieval', 'datasets'),
                        help='Base directory for data files')
    args = parser.parse_args()

    # Ensure all paths are relative to base_dir if they're not absolute
    def get_full_path(path):
        if os.path.isabs(path):
            return path
        return os.path.join(args.base_dir, path)

    # Load the training queries
    print("Loading training queries...")
    train_queries = load_json_file(get_full_path(args.train_queries))
    print(f"Loaded {len(train_queries)} training queries")

    # Load the ranking data and gold standard
    print("Loading ranking data and gold standard...")
    pre_ranking = load_json_file(get_full_path(args.pre_ranking))
    re_ranking = load_json_file(get_full_path(args.re_ranking))
    gold_mapping = load_json_file(get_full_path(args.gold))
    
    # Filter to include only training queries
    pre_ranking = {fan: docs for fan, docs in pre_ranking.items() if fan in train_queries}
    re_ranking = {fan: docs for fan, docs in re_ranking.items() if fan in train_queries}
    gold_mapping = {fan: docs for fan, docs in gold_mapping.items() if fan in train_queries}
    
    # Parse k values
    k_values = [int(k) for k in args.k_values.split(',')]
    
    # Prepare data for metrics calculation
    query_fans = set(gold_mapping.keys()) & set(pre_ranking.keys()) & set(re_ranking.keys())
    
    if not query_fans:
        print("Error: No common query FANs found across all datasets!")
        return
    
    print(f"Evaluating rankings for {len(query_fans)} training queries...")
    
    # Extract true and predicted labels for both rankings
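    # Iterating the same set object keeps gold and predicted entries aligned per query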
    true_labels = [gold_mapping[fan] for fan in query_fans]
    pre_ranking_labels = [pre_ranking[fan] for fan in query_fans]
    re_ranking_labels = [re_ranking[fan] for fan in query_fans]
    
    # Report the same set of metrics for both rankings
    def report_metrics(name, predicted_labels):
        print(f"\n{name} performance (training queries only):")
        for k in k_values:
            recall_at_k = mean_recall_at_k(true_labels, predicted_labels, k=k)
            print(f"  Recall@{k}: {recall_at_k:.4f}")
        print(f"  MAP: {mean_average_precision(true_labels, predicted_labels):.4f}")
        print(f"  Mean Inverse Rank: {mean_inv_ranking(true_labels, predicted_labels):.4f}")
        print(f"  Mean Rank: {mean_ranking(true_labels, predicted_labels):.2f}")

    report_metrics("Pre-ranking", pre_ranking_labels)
    report_metrics("Re-ranking", re_ranking_labels)

if __name__ == "__main__":
    main()