import os
import random
import tqdm
import json

from typing import Dict, List, Any
from datasets import load_dataset, Dataset
from dataclasses import dataclass, field

from logger_config import logger
from config import Arguments
from utils import save_json_to_file


@dataclass
class ScoredDoc:
    qid: str
    pid: str
    rank: int
    score: float = field(default=-1)


def load_qrels(path: str) -> Dict[str, Dict[str, int]]:
    assert path.endswith('.txt')

    # qid -> pid -> relevance score
    qrels = {}
    for line in open(path, 'r', encoding='utf-8'):
        qid, _, pid, score = line.strip().split('\t')
        if qid not in qrels:
            qrels[qid] = {}
        qrels[qid][pid] = int(score)

    logger.info('Load {} queries {} qrels from {}'.format(
        len(qrels), sum(len(v) for v in qrels.values()), path))
    return qrels


def load_queries(path: str, task_type: str = 'ir') -> Dict[str, str]:
    assert path.endswith('.tsv')

    if task_type == 'qa':
        qid_to_query = load_query_answers(path)
        qid_to_query = {k: v['query'] for k, v in qid_to_query.items()}
    elif task_type == 'ir':
        qid_to_query = {}
        for line in open(path, 'r', encoding='utf-8'):
            qid, query = line.strip().split('\t')
            qid_to_query[qid] = query
    else:
        raise ValueError('Unknown task type: {}'.format(task_type))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def normalize_qa_text(text: str) -> str:
    # TriviaQA has some weird formats
    # For example: """What breakfast food gets its name from the German word for """"stirrup""""?"""
    while text.startswith('"') and text.endswith('"'):
        text = text[1:-1].replace('""', '"')
    return text


def get_question_key(question: str) -> str:
    # For QA datasets, we use the normalized question string as the dict key
    return question


def load_query_answers(path: str) -> Dict[str, Dict[str, Any]]:
    assert path.endswith('.tsv')

    qid_to_query = {}
    for line in open(path, 'r', encoding='utf-8'):
        query, answers = line.strip().split('\t')
        query = normalize_qa_text(query)
        answers = normalize_qa_text(answers)
        qid = get_question_key(query)
        if qid in qid_to_query:
            logger.warning('Duplicate question: {} vs {}'.format(query, qid_to_query[qid]['query']))
            continue

        qid_to_query[qid] = {}
        qid_to_query[qid]['query'] = query
        # the answers field is a string representation of a list
        qid_to_query[qid]['answers'] = list(eval(answers))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def load_corpus(path: str) -> Dataset:
    assert path.endswith('.jsonl') or path.endswith('.jsonl.gz')

    # two fields: id, contents
    corpus = load_dataset('json', data_files=path)['train']
    logger.info('Load {} documents from {} with columns {}'.format(len(corpus), path, corpus.column_names))
    logger.info('A random document: {}'.format(random.choice(corpus)))
    return corpus


def load_msmarco_predictions(path: str) -> Dict[str, List[ScoredDoc]]:
    assert path.endswith('.txt')

    qid_to_scored_doc = {}
    for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), desc='load prediction', mininterval=3):
        fs = line.strip().split('\t')
        qid, pid, rank = fs[:3]
        rank = int(rank)
        # if no score column is present, fall back to the reciprocal rank
        score = round(1 / rank, 4) if len(fs) == 3 else float(fs[3])
        if qid not in qid_to_scored_doc:
            qid_to_scored_doc[qid] = []

        scored_doc = ScoredDoc(qid=qid, pid=pid, rank=rank, score=score)
        qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {qid: sorted(scored_docs, key=lambda sd: sd.rank)
                         for qid, scored_docs in qid_to_scored_doc.items()}
    logger.info('Load {} query predictions from {}'.format(len(qid_to_scored_doc), path))
    return qid_to_scored_doc


def save_preds_to_msmarco_format(preds: Dict[str, List[ScoredDoc]], out_path: str):
    with open(out_path, 'w', encoding='utf-8') as writer:
        for qid in preds:
            for idx, scored_doc in enumerate(preds[qid]):
                writer.write('{}\t{}\t{}\t{}\n'.format(qid, scored_doc.pid, idx + 1, round(scored_doc.score, 3)))
    logger.info('Successfully saved to {}'.format(out_path))


def save_to_readable_format(in_path: str, corpus: Dataset):
    out_path = '{}/readable_{}'.format(os.path.dirname(in_path), os.path.basename(in_path))
    dataset: Dataset = load_dataset('json', data_files=in_path)['train']

    max_to_keep = 5

    def _create_readable_field(samples: Dict[str, List]) -> List:
        readable_ex = []
        for idx in range(min(len(samples['doc_id']), max_to_keep)):
            doc_id = samples['doc_id'][idx]
            readable_ex.append({'doc_id': doc_id,
                                'title': corpus[int(doc_id)].get('title', ''),
                                'contents': corpus[int(doc_id)]['contents'],
                                'score': samples['score'][idx]})
        return readable_ex

    def _mp_func(ex: Dict) -> Dict:
        ex['positives'] = _create_readable_field(ex['positives'])
        ex['negatives'] = _create_readable_field(ex['negatives'])
        return ex

    dataset = dataset.map(_mp_func, num_proc=8)
    dataset.to_json(out_path, force_ascii=False, lines=False, indent=4)
    logger.info('Done converting {} to readable format in {}'.format(in_path, out_path))


def get_rerank_shard_path(args: Arguments, worker_idx: int) -> str:
    return '{}_shard_{}'.format(args.rerank_out_path, worker_idx)


def merge_rerank_predictions(args: Arguments, gpu_count: int):
    from metrics import trec_eval, compute_mrr

    qid_to_scored_doc: Dict[str, List[ScoredDoc]] = {}
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), 'merge results', mininterval=3):
            fs = line.strip().split('\t')
            qid, pid, _, score = fs
            score = float(score)
            if qid not in qid_to_scored_doc:
                qid_to_scored_doc[qid] = []

            scored_doc = ScoredDoc(qid=qid, pid=pid, rank=-1, score=score)
            qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {k: sorted(v, key=lambda sd: sd.score, reverse=True)
                         for k, v in qid_to_scored_doc.items()}

    # documents beyond rerank_depth are appended after the reranked ones
    # with strictly decreasing scores so the original order is preserved
    ori_preds = load_msmarco_predictions(path=args.rerank_in_path)
    for query_id in list(qid_to_scored_doc.keys()):
        remain_scored_docs = ori_preds[query_id][args.rerank_depth:]
        for idx, sd in enumerate(remain_scored_docs):
            # make sure the order is not broken
            sd.score = qid_to_scored_doc[query_id][-1].score - idx - 1

        qid_to_scored_doc[query_id] += remain_scored_docs
        assert len(set([sd.pid for sd in qid_to_scored_doc[query_id]])) == len(qid_to_scored_doc[query_id])

    save_preds_to_msmarco_format(qid_to_scored_doc, out_path=args.rerank_out_path)

    path_qrels = '{}/{}_qrels.txt'.format(args.data_dir, args.rerank_split)
    if os.path.exists(path_qrels):
        qrels = load_qrels(path=path_qrels)
        all_metrics = trec_eval(qrels=qrels, predictions=qid_to_scored_doc)
        all_metrics['mrr'] = compute_mrr(qrels=qrels, predictions=qid_to_scored_doc)
        logger.info('{} trec metrics = {}'.format(
            args.rerank_split, json.dumps(all_metrics, ensure_ascii=False, indent=4)))
        metrics_out_path = '{}/metrics_rerank_{}.json'.format(
            os.path.dirname(args.rerank_out_path), args.rerank_split)
        save_json_to_file(all_metrics, metrics_out_path)
    else:
        logger.warning('No qrels found for {}'.format(args.rerank_split))

    # cleanup intermediate per-worker shard files
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        os.remove(path)


if __name__ == '__main__':
    load_qrels('./data/msmarco/dev_qrels.txt')
    load_queries('./data/msmarco/dev_queries.tsv')
    corpus = load_corpus('./data/msmarco/passages.jsonl.gz')
    preds = load_msmarco_predictions('./data/bm25.msmarco.txt')