Tzktz's picture
Upload 7664 files
6fc683c verified
import json
import os
import sys
import tqdm
import argparse
sys.path.insert(0, './src')
from typing import List, Dict
from utils import save_json_to_file
from logger_config import logger
from data_utils import load_qrels, load_corpus, load_queries, load_msmarco_predictions, ScoredDoc
from metrics import get_rel_threshold
parser = argparse.ArgumentParser(description='convert ms-marco predictions to a human-readable format')
parser.add_argument('--in-path', default='', type=str, metavar='N',
help='path to predictions in msmarco output format')
parser.add_argument('--split', default='dev', type=str, metavar='N',
help='which split to use')
parser.add_argument('--data-dir', default='./data/msmarco/', type=str, metavar='N',
help='data dir')
args = parser.parse_args()
logger.info('Args={}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
def main(topk: int = 10):
predictions: Dict[str, List[ScoredDoc]] = load_msmarco_predictions(path=args.in_path)
path_qrels = '{}/{}_qrels.txt'.format(args.data_dir, args.split)
qrels = load_qrels(path=path_qrels) if os.path.exists(path_qrels) else None
queries = load_queries(path='{}/{}_queries.tsv'.format(args.data_dir, args.split))
corpus = load_corpus(path='{}/passages.jsonl.gz'.format(args.data_dir))
pred_infos = []
out_path = '{}.details.json'.format(args.in_path)
rel_threshold = get_rel_threshold(qrels) if qrels else -1
for qid in tqdm.tqdm(queries):
pred_docs = []
for scored_doc in predictions[qid][:topk]:
correct = qrels is not None and scored_doc.pid in qrels[qid] and qrels[qid][scored_doc.pid] >= rel_threshold
pred_docs.append({'id': scored_doc.pid,
'contents': corpus[int(scored_doc.pid)]['contents'],
'title': corpus[int(scored_doc.pid)]['title'],
'score': scored_doc.score})
if qrels is not None:
pred_docs[-1]['correct'] = correct
if correct: break
gold_rank, gold_score = -1, -1
for idx, scored_doc in enumerate(predictions[qid]):
if qrels is None:
break
if scored_doc.pid in qrels[qid] and qrels[qid][scored_doc.pid] >= rel_threshold:
gold_rank = idx + 1
gold_score = scored_doc.score
break
pred_info = {'query_id': qid,
'query': queries[qid],
'pred_docs': pred_docs}
if qrels is not None:
pred_info.update({
'gold_docs': [corpus[int(doc_id)] for doc_id in qrels[qid] if qrels[qid][doc_id] >= rel_threshold],
'gold_score': gold_score,
'gold_rank': gold_rank
})
pred_infos.append(pred_info)
save_json_to_file(pred_infos, out_path)
logger.info('Save prediction details to {}'.format(out_path))
if __name__ == '__main__':
main()