Spaces:
Sleeping
Sleeping
import json | |
import sys | |
import argparse | |
sys.path.insert(0, './src') | |
from logger_config import logger | |
from metrics import compute_mrr, trec_eval | |
from utils import save_json_to_file | |
from data_utils import load_qrels, load_msmarco_predictions | |
# Command-line interface. Arguments are parsed at import time and stored in
# the module-level ``args``, which ``main`` reads.
parser = argparse.ArgumentParser(description='compute metrics for ms-marco predictions')
# The predictions file is mandatory: with the previous ``default=''`` an
# omitted flag silently fed an empty path to the loader instead of producing
# a clear usage error.
parser.add_argument('--in-path', required=True, type=str, metavar='PATH',
                    help='path to predictions in msmarco output format')
parser.add_argument('--qrels', default='./data/msmarco/dev_qrels.txt', type=str, metavar='PATH',
                    help='path to qrels')
args = parser.parse_args()
# Echo the effective configuration as indented JSON for reproducibility.
logger.info('Args={}'.format(json.dumps(vars(args), ensure_ascii=False, indent=4)))
def main():
    """Score an MS MARCO prediction file and store the resulting metrics.

    Loads the qrels from ``args.qrels`` and the predictions from
    ``args.in_path``, collects the metrics produced by ``trec_eval``, adds an
    ``'mrr'`` entry computed by ``compute_mrr``, logs everything as indented
    JSON and saves it next to the input as ``<in_path>.metrics.json``.
    """
    qrels_path = args.qrels
    predictions_path = args.in_path

    qrels = load_qrels(path=qrels_path)
    predictions = load_msmarco_predictions(predictions_path)

    metrics = trec_eval(qrels=qrels, predictions=predictions)
    mrr_value = compute_mrr(qrels=qrels, predictions=predictions)
    metrics['mrr'] = mrr_value

    report = json.dumps(metrics, ensure_ascii=False, indent=4)
    logger.info(report)

    output_path = '{}.metrics.json'.format(predictions_path)
    save_json_to_file(metrics, output_path)


# Only evaluate when executed as a script, not when imported.
if __name__ == '__main__':
    main()