"""Convert MS-MARCO ranking predictions into a human-readable JSON file with
per-query top-k passages and, when qrels exist, gold-passage details."""
import argparse
import json
import os
import sys

import tqdm

sys.path.insert(0, './src')

from typing import List, Dict

from utils import save_json_to_file
from logger_config import logger
from data_utils import load_qrels, load_corpus, load_queries, load_msmarco_predictions, ScoredDoc
from metrics import get_rel_threshold

# The input file must follow the MS-MARCO ranking output format expected by
# load_msmarco_predictions.
parser = argparse.ArgumentParser(description='convert MS-MARCO predictions to a human-readable format')
parser.add_argument('--in-path', default='', type=str, metavar='N',
                    help='path to predictions in msmarco output format')
parser.add_argument('--split', default='dev', type=str, metavar='N',
                    help='which split to use')
parser.add_argument('--data-dir', default='./data/msmarco/', type=str, metavar='N',
                    help='data dir')

args = parser.parse_args()
logger.info('Args={}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))


def main(topk: int = 10):
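    """Write '<in-path>.details.json' with per-query details of the top-k predictions."""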
    predictions: Dict[str, List[ScoredDoc]] = load_msmarco_predictions(path=args.in_path)
    path_qrels = '{}/{}_qrels.txt'.format(args.data_dir, args.split)
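    # Qrels may be absent (e.g. for a blind test split); fall back to predictions only.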
    qrels = load_qrels(path=path_qrels) if os.path.exists(path_qrels) else None
    queries = load_queries(path='{}/{}_queries.tsv'.format(args.data_dir, args.split))
    corpus = load_corpus(path='{}/passages.jsonl.gz'.format(args.data_dir))

    pred_infos = []
    out_path = '{}.details.json'.format(args.in_path)
    rel_threshold = get_rel_threshold(qrels) if qrels else -1
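    # For every query, keep its top-k predictions. When qrels are available, each
    # kept prediction is flagged as correct/incorrect and the loop stops early at
    # the first relevant passage. Assumes every query id appears in `predictions`.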
    for qid in tqdm.tqdm(queries):
        pred_docs = []
        for scored_doc in predictions[qid][:topk]:
            correct = qrels is not None and scored_doc.pid in qrels[qid] and qrels[qid][scored_doc.pid] >= rel_threshold
            pred_docs.append({'id': scored_doc.pid,
                              'contents': corpus[int(scored_doc.pid)]['contents'],
                              'title': corpus[int(scored_doc.pid)]['title'],
                              'score': scored_doc.score})
            if qrels is not None:
                pred_docs[-1]['correct'] = correct
            if correct:
                break

        # Rank (1-based) and score of the first relevant passage in the full
        # ranking; -1 means none was found (or no qrels are available).
        gold_rank, gold_score = -1, -1
        for idx, scored_doc in enumerate(predictions[qid]):
            if qrels is None:
                break
            if scored_doc.pid in qrels[qid] and qrels[qid][scored_doc.pid] >= rel_threshold:
                gold_rank = idx + 1
                gold_score = scored_doc.score
                break

        pred_info = {'query_id': qid,
                     'query': queries[qid],
                     'pred_docs': pred_docs}
        if qrels is not None:
            pred_info.update({
                'gold_docs': [corpus[int(doc_id)] for doc_id in qrels[qid] if qrels[qid][doc_id] >= rel_threshold],
                'gold_score': gold_score,
                'gold_rank': gold_rank
            })

        pred_infos.append(pred_info)

    save_json_to_file(pred_infos, out_path)
    logger.info('Saved prediction details to {}'.format(out_path))


if __name__ == '__main__':
    main()
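
# Example invocation (illustrative paths; the script filename below is assumed,
# not taken from the repository):
#   python format_msmarco_predictions.py --in-path ./outputs/dev.msmarco.txt \
#       --split dev --data-dir ./data/msmarco/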