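"""Convert a TREC run file into the DPR retrieval-result JSON format and
report top-k retrieval accuracy.

Example invocation (the script and file names here are illustrative):

    python convert_trec_to_dpr.py \
        --data-dir data/ \
        --topics topics.jsonl \
        --input run.trec \
        --output dpr_result.json \
        --topk 1 5 20 100
"""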
import os
import argparse
import json
import sys
sys.path.insert(0, 'src/')  # make project modules under src/ importable

from tqdm import tqdm
from typing import Dict, Any
from datasets import Dataset

from evaluate_dpr_retrieval import has_answers, SimpleTokenizer, evaluate_retrieval
from data_utils import load_query_answers, load_corpus
from utils import save_json_to_file
from logger_config import logger


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert a TREC run to DPR retrieval result json.')
    parser.add_argument('--data-dir', required=True, help='directory containing passages.jsonl.gz')
    parser.add_argument('--topics', required=True, help='path to the topics (query/answer) file')
    parser.add_argument('--topk', type=int, nargs='+', help='k values to evaluate top-k accuracy at')
    parser.add_argument('--input', required=True, help='input TREC run file')
    parser.add_argument('--store-raw', action='store_true', help='store raw text of each retrieved passage')
    parser.add_argument('--regex', action='store_true', default=False, help='treat answers as regular expressions when matching')
    parser.add_argument('--output', required=True, help='output DPR retrieval json file')
    args = parser.parse_args()

    # question id -> {'query': ..., 'answers': ...}; passage idx -> {'title': ..., 'contents': ...}
    qas = load_query_answers(path=args.topics)
    corpus = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz'))

    tokenizer = SimpleTokenizer()

    # Each run-file line is tab-separated: question id, passage index, rank (ignored), score.
    predictions = []
    with open(args.input) as f:
        for line in tqdm(f, mininterval=1):
            question_id, doc_idx, _, score = line.strip().split('\t')[:4]

            predictions.append({'question_id': question_id,
                                'doc_idx': int(doc_idx),
                                'score': score})

    dataset = Dataset.from_dict({'question_id': [ex['question_id'] for ex in predictions],
                                 'doc_idx': [ex['doc_idx'] for ex in predictions],
                                 'score': [ex['score'] for ex in predictions]})
    logger.info('Loaded {} predictions in total'.format(len(dataset)))

    def _map_func(example: Dict[str, Any]) -> dict:
        """Attach the question, gold answers and answer-hit flag to one prediction."""
        question_id, doc_idx = example['question_id'], example['doc_idx']

        question = qas[question_id]['query']
        answers = qas[question_id]['answers']
        title, text = corpus[doc_idx]['title'], corpus[doc_idx]['contents']

        example['question'] = question
        example['answers'] = answers
        example['docid'] = doc_idx
        # A context counts as a hit if any gold answer occurs in the passage text.
        example['has_answer'] = has_answers(text, answers, tokenizer, args.regex)
        if args.store_raw:
            # Keep the raw context as 'title\ntext'.
            example['text'] = '{}\n{}'.format(title, text)

        return example

    # Enrich all predictions in parallel across CPU workers.
    dataset = dataset.map(_map_func,
                          num_proc=min(os.cpu_count(), 16))

    # Group per-question contexts into the DPR retrieval structure.
    retrieval = {}
    for ex in tqdm(dataset, mininterval=2, desc='convert to dpr format'):
        question_id, question, answers = ex['question_id'], ex['question'], ex['answers']
        if question_id not in retrieval:
            retrieval[question_id] = {'question': question, 'answers': answers, 'contexts': []}
        retrieval[question_id]['contexts'].append(
            {k: ex[k] for k in ['docid', 'score', 'text', 'has_answer'] if k in ex}
        )
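    # Shape of `retrieval` (a sketch; 'text' is only present with --store-raw):
    # {
    #   "<question_id>": {
    #       "question": "...",
    #       "answers": ["..."],
    #       "contexts": [
    #           {"docid": 123, "score": "...", "has_answer": true, "text": "..."},
    #           ...
    #       ]
    #   }
    # }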

    save_json_to_file(retrieval, path=args.output)
    logger.info('Converted {} to {}'.format(args.input, args.output))

    # Evaluate top-k retrieval accuracy on the converted file.
    metrics = evaluate_retrieval(retrieval_file=args.output,
                                 topk=args.topk,
                                 regex=args.regex)
    logger.info('{} recall metrics: {}'.format(
        os.path.basename(args.output),
        json.dumps(metrics, ensure_ascii=False, indent=4)))

    # Save metrics next to the output file (os.path.join also handles outputs without a directory part).
    base_dir, base_name = os.path.dirname(args.output), os.path.basename(args.output)
    save_json_to_file(metrics, path=os.path.join(base_dir, 'metrics_{}'.format(base_name)))