Tzktz's picture
Upload 7664 files
6fc683c verified
import os
import argparse
import json
import sys
sys.path.insert(0, 'src/')
from tqdm import tqdm
from typing import Dict, Any
from datasets import Dataset
from evaluate_dpr_retrieval import has_answers, SimpleTokenizer, evaluate_retrieval
from data_utils import load_query_answers, load_corpus
from utils import save_json_to_file
from logger_config import logger
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert an TREC run to DPR retrieval result json.')
parser.add_argument('--data-dir', required=True, help='data dir')
parser.add_argument('--topics', required=True, help='topic name')
parser.add_argument('--topk', type=int, nargs='+', help="topk to evaluate")
parser.add_argument('--input', required=True, help='Input TREC run file.')
parser.add_argument('--store-raw', action='store_true', help='Store raw text of passage')
parser.add_argument('--regex', action='store_true', default=False, help="regex match")
parser.add_argument('--output', required=True, help='Output DPR Retrieval json file.')
args = parser.parse_args()
qas = load_query_answers(path=args.topics)
corpus = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz'))
retrieval = {}
tokenizer = SimpleTokenizer()
predictions = []
for line in tqdm(open(args.input), mininterval=1):
question_id, doc_idx, _, score = line.strip().split('\t')[:4]
predictions.append({'question_id': question_id,
'doc_idx': int(doc_idx),
'score': score})
dataset = Dataset.from_dict({'question_id': [ex['question_id'] for ex in predictions],
'doc_idx': [ex['doc_idx'] for ex in predictions],
'score': [ex['score'] for ex in predictions]})
logger.info('Get {} predictions in total'.format(len(dataset)))
def _map_func(example: Dict[str, Any]) -> dict:
question_id, doc_idx, score = example['question_id'], example['doc_idx'], example['score']
question = qas[question_id]['query']
answers = qas[question_id]['answers']
title, text = corpus[doc_idx]['title'], corpus[doc_idx]['contents']
ctx = '{}\n{}'.format(title, text)
answer_exist = has_answers(text, answers, tokenizer, args.regex)
example['question'] = question
example['answers'] = answers
example['docid'] = doc_idx
example['has_answer'] = answer_exist
if args.store_raw:
example['text'] = ctx
return example
dataset = dataset.map(_map_func,
num_proc=min(os.cpu_count(), 16))
retrieval = {}
for ex in tqdm(dataset, mininterval=2, desc='convert to dpr format'):
question_id, question, answers = ex['question_id'], ex['question'], ex['answers']
if question_id not in retrieval:
retrieval[question_id] = {'question': question, 'answers': answers, 'contexts': []}
retrieval[question_id]['contexts'].append(
{k: ex[k] for k in ['docid', 'score', 'text', 'has_answer'] if k in ex}
)
save_json_to_file(retrieval, path=args.output)
logger.info('Convert {} to {} done'.format(args.input, args.output))
metrics = evaluate_retrieval(retrieval_file=args.output,
topk=args.topk,
regex=args.regex)
logger.info('{} recall metrics: {}'.format(
os.path.basename(args.output),
json.dumps(metrics, ensure_ascii=False, indent=4)))
base_dir, base_name = os.path.dirname(args.output), os.path.basename(args.output)
save_json_to_file(metrics, path='{}/metrics_{}'.format(base_dir, base_name))