File size: 1,162 Bytes
6fc683c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import json
import sys
import argparse
sys.path.insert(0, './src')

from logger_config import logger
from metrics import compute_mrr, trec_eval
from utils import save_json_to_file
from data_utils import load_qrels, load_msmarco_predictions

# Command-line interface. Arguments are parsed at import time, so `args` is a
# module-level global that main() reads.
parser = argparse.ArgumentParser(description='compute metrics for ms-marco predictions')
# Both options take file paths, so the --help placeholder is PATH.
parser.add_argument('--in-path', default='', type=str, metavar='PATH',
                    help='path to predictions in msmarco output format')
parser.add_argument('--qrels', default='./data/msmarco/dev_qrels.txt', type=str, metavar='PATH',
                    help='path to qrels')

args = parser.parse_args()
# Log the effective arguments as pretty-printed JSON.
logger.info('Args={}'.format(json.dumps(vars(args), ensure_ascii=False, indent=4)))


def main():
    """Score the predictions against the qrels, then log and save the metrics.

    Reads ``args.qrels`` and ``args.in_path`` from the module-level ``args``.
    The result of ``trec_eval`` is extended with an ``'mrr'`` entry from
    ``compute_mrr``, logged as indented JSON, and handed to
    ``save_json_to_file`` together with the path ``<in_path>.metrics.json``.
    """
    relevance = load_qrels(path=args.qrels)
    preds = load_msmarco_predictions(args.in_path)

    # trec_eval runs first; its result is the container that receives 'mrr'.
    metrics = trec_eval(qrels=relevance, predictions=preds)
    metrics['mrr'] = compute_mrr(qrels=relevance, predictions=preds)

    report = json.dumps(metrics, ensure_ascii=False, indent=4)
    logger.info(report)

    # The output path is the predictions path plus a fixed suffix.
    out_path = '{}.metrics.json'.format(args.in_path)
    save_json_to_file(metrics, out_path)


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()