import os
import io
import gzip
import json
import random
import argparse
import ir_datasets
import numpy as np
import sys

sys.path.insert(0, 'src/')

from tqdm import tqdm
from typing import Dict, List
from datasets import Dataset

from logger_config import logger
from utils import save_json_to_file
from data_utils import load_msmarco_predictions, load_queries, load_qrels, load_corpus, \
    ScoredDoc, save_to_readable_format

parser = argparse.ArgumentParser(description='data preprocessing')
parser.add_argument('--out-dir', default='./data/msmarco/', type=str, metavar='N',
                    help='output directory')
parser.add_argument('--train-pred-path', default='./preds/official/train.msmarco.txt', type=str, metavar='N',
                    help='path to train predictions to construct negatives')
parser.add_argument('--dev-pred-path', default='./preds/official/dev.msmarco.txt', type=str, metavar='N',
                    help='path to dev predictions to construct negatives')
parser.add_argument('--num-negatives', default=210, type=int, metavar='N',
                    help='number of negative passages')
parser.add_argument('--num-random-neg', default=10, type=int, metavar='N',
                    help='number of random negatives to use')
parser.add_argument('--depth', default=200, type=int, metavar='N',
                    help='depth to choose negative passages from')
parser.add_argument('--title-path', default='./data/msmarco/para.title.txt', type=str, metavar='N',
                    help='path to titles data')
parser.add_argument('--create-train-dev-only', action='store_true',
                    help='only create train/dev jsonl data, skip writing queries/qrels/corpus')
parser.add_argument('--filter-noisy-positives', action='store_true',
                    help='filter noisy positives or not')

args = parser.parse_args()
os.makedirs(args.out_dir, exist_ok=True)
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))


def _write_corpus_to_disk():
    dataset = ir_datasets.load('msmarco-passage/train')

    # Optionally load passage titles: one tab-separated line per passage with the title
    # in the second column, in the same order as the corpus.
    titles = []
    if os.path.exists(args.title_path):
        titles = [line.strip().split('\t')[1]
                  for line in tqdm(open(args.title_path).readlines(), desc='load title')]
        logger.info('Load {} titles from {}'.format(len(titles), args.title_path))
    else:
        logger.warning('No title data found: {}'.format(args.title_path))

    title_idx = 0
    out_path = os.path.join(args.out_dir, 'passages.jsonl.gz')
    with gzip.open(out_path, 'wb') as output:
        with io.TextIOWrapper(output, encoding='utf-8') as writer:
            for doc in tqdm(dataset.docs_iter()):
                ex = {'id': doc.doc_id, 'contents': doc.text}
                if titles:
                    ex['title'] = titles[title_idx]
                    title_idx += 1
                writer.write(json.dumps(ex, ensure_ascii=False, separators=(',', ':')))
                writer.write('\n')

    if titles:
        assert title_idx == len(titles), '{} != {}'.format(title_idx, len(titles))


def _write_queries_to_disk(split: str, out_path: str):
    dataset = ir_datasets.load("msmarco-passage/{}".format(split))
    with open(out_path, 'w', encoding='utf-8') as writer:
        for query in dataset.queries_iter():
            writer.write('{}\t{}\n'.format(query.query_id, query.text))

    logger.info('Write {} queries to {}'.format(split, out_path))


def _write_qrels_to_disk(split: str, out_path: str):
    dataset = ir_datasets.load("msmarco-passage/{}".format(split))
    with open(out_path, 'w', encoding='utf-8') as writer:
        for qrel in dataset.qrels_iter():
            # query_id, iteration, doc_id, relevance
            writer.write('{}\t{}\t{}\t{}\n'
                         .format(qrel.query_id, qrel.iteration, qrel.doc_id, qrel.relevance))

    logger.info('Write {} qrels to {}'.format(split, out_path))


def _write_prepared_data_to_disk(out_path: str,
                                 corpus: Dataset,
                                 queries: Dict[str, str],
                                 qrels: Dict[str, Dict[str, int]],
                                 preds: Dict[str, List[ScoredDoc]],
                                 is_train: bool = False):
    cnt_noisy_positive = 0
    cnt_output = 0
    with open(out_path, 'w', encoding='utf-8') as writer:
        for query_id in tqdm(qrels, mininterval=2):
            positive_doc_ids: Dict = qrels.get(query_id)
            if not positive_doc_ids:
                logger.warning('No positive found for query_id={}'.format(query_id))
                continue

            # Skip "noisy positives": queries whose labeled positives never show up
            # in the retriever's predictions.
            if is_train and args.filter_noisy_positives \
                    and all(sd.pid not in positive_doc_ids for sd in preds.get(query_id, [])):
                cnt_noisy_positive += 1
                continue

            # For official triples, only use those with negative doc ids
            if not preds.get(query_id, []):
                continue

            doc_id_to_score = {scored_doc.pid: scored_doc.score for scored_doc in preds.get(query_id, [])}
            # Hard negatives: top-ranked predictions that are not labeled positives.
            negative_scored_docs = [scored_doc for scored_doc in preds.get(query_id, [])
                                    if scored_doc.pid not in positive_doc_ids][:args.depth]
            np.random.shuffle(negative_scored_docs)
            negative_scored_docs = negative_scored_docs[:(args.num_negatives - args.num_random_neg)]

            if len(negative_scored_docs) < args.num_negatives:
                if not negative_scored_docs:
                    logger.warning('No negatives found for query_id={} ({}), will use random negatives'
                                   .format(query_id, queries[query_id]))
                # Pad with random corpus passages; this relies on passage ids being
                # the consecutive integers 0..len(corpus)-1.
                while len(negative_scored_docs) < args.num_negatives:
                    sd = ScoredDoc(qid=query_id, pid=str(random.randint(0, len(corpus) - 1)), rank=args.depth)
                    if sd.pid not in positive_doc_ids and sd.pid not in doc_id_to_score:
                        negative_scored_docs.append(sd)

            np.random.shuffle(negative_scored_docs)
            example = {'query_id': query_id,
                       'query': queries[query_id],
                       'positives': {'doc_id': list(positive_doc_ids),
                                     'score': [doc_id_to_score.get(doc_id, -1.) for doc_id in positive_doc_ids]},
                       'negatives': {'doc_id': [scored_doc.pid for scored_doc in negative_scored_docs],
                                     'score': [scored_doc.score for scored_doc in negative_scored_docs]},
                       }
            writer.write(json.dumps(example, ensure_ascii=False, separators=(',', ':')))
            writer.write('\n')
            cnt_output += 1

    if is_train and args.filter_noisy_positives:
        logger.info('Filter {} noisy positives'.format(cnt_noisy_positive))
    logger.info('Write {} examples to {}'.format(cnt_output, out_path))


if __name__ == '__main__':
    if not args.create_train_dev_only:
        _write_queries_to_disk(split='dev/small', out_path=os.path.join(args.out_dir, 'dev_queries.tsv'))
        _write_queries_to_disk(split='eval/small', out_path=os.path.join(args.out_dir, 'test_queries.tsv'))
        _write_queries_to_disk(split='trec-dl-2019/judged',
                               out_path=os.path.join(args.out_dir, 'trec_dl2019_queries.tsv'))
        _write_queries_to_disk(split='trec-dl-2020/judged',
                               out_path=os.path.join(args.out_dir, 'trec_dl2020_queries.tsv'))
        _write_queries_to_disk(split='train/judged', out_path=os.path.join(args.out_dir, 'train_queries.tsv'))

        _write_qrels_to_disk(split='dev/small', out_path=os.path.join(args.out_dir, 'dev_qrels.txt'))
        _write_qrels_to_disk(split='trec-dl-2019/judged',
                             out_path=os.path.join(args.out_dir, 'trec_dl2019_qrels.txt'))
        _write_qrels_to_disk(split='trec-dl-2020/judged',
                             out_path=os.path.join(args.out_dir, 'trec_dl2020_qrels.txt'))
        _write_qrels_to_disk(split='train/judged', out_path=os.path.join(args.out_dir, 'train_qrels.txt'))

        _write_corpus_to_disk()

    corpus = load_corpus(path=os.path.join(args.out_dir, 'passages.jsonl.gz'))
    _write_prepared_data_to_disk(out_path=os.path.join(args.out_dir, 'dev.jsonl'),
                                 corpus=corpus,
                                 queries=load_queries(path=os.path.join(args.out_dir, 'dev_queries.tsv')),
                                 qrels=load_qrels(path=os.path.join(args.out_dir, 'dev_qrels.txt')),
                                 preds=load_msmarco_predictions(path=args.dev_pred_path))

    _write_prepared_data_to_disk(out_path=os.path.join(args.out_dir, 'train.jsonl'),
                                 corpus=corpus,
                                 queries=load_queries(path=os.path.join(args.out_dir, 'train_queries.tsv')),
                                 qrels=load_qrels(path=os.path.join(args.out_dir, 'train_qrels.txt')),
                                 preds=load_msmarco_predictions(path=args.train_pred_path),
                                 is_train=True)

    save_to_readable_format(in_path=os.path.join(args.out_dir, 'dev.jsonl'), corpus=corpus)
    save_to_readable_format(in_path=os.path.join(args.out_dir, 'train.jsonl'), corpus=corpus)

    save_json_to_file(args.__dict__, path=os.path.join(args.out_dir, 'train_dev_create_args.json'))

    src_path = args.dev_pred_path
    dst_path = '{}/{}'.format(args.out_dir, os.path.basename(args.dev_pred_path))
    logger.info('copy {} to {}'.format(src_path, dst_path))
    os.system('cp {} {}'.format(src_path, dst_path))

    for trec_split in ['trec_dl2019', 'trec_dl2020', 'test']:
        trec_pred_path = '{}/{}.msmarco.txt'.format(os.path.dirname(args.dev_pred_path), trec_split)
        dst_path = '{}/{}'.format(args.out_dir, os.path.basename(trec_pred_path))
        if not os.path.exists(trec_pred_path):
            logger.warning('{} does not exist'.format(trec_pred_path))
            continue
        logger.info('copy {} to {}'.format(trec_pred_path, dst_path))
        os.system('cp {} {}'.format(trec_pred_path, dst_path))
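
# Example invocation (a sketch only; the script filename is an assumption and the
# prediction paths simply echo the argparse defaults, so adjust both to your layout):
#
#   python prepare_msmarco_data.py \
#       --out-dir ./data/msmarco/ \
#       --train-pred-path ./preds/official/train.msmarco.txt \
#       --dev-pred-path ./preds/official/dev.msmarco.txt \
#       --filter-noisy-positives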