"""Multi-GPU brute-force top-k search of queries against sharded passage embeddings.

Each GPU worker encodes its slice of the queries, scores them against every
embedding shard found under ``args.encode_save_dir``, and writes a per-worker
ranking file. The parent process then merges those files, optionally computes
metrics when a qrels file is present, and removes the per-worker files.
"""
import json
import os
import glob
import tqdm
import torch

from contextlib import nullcontext
from torch.utils.data import DataLoader
from functools import partial
from collections import defaultdict
from datasets import Dataset
from typing import Dict, List, Tuple
from transformers.file_utils import PaddingStrategy
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    HfArgumentParser,
    BatchEncoding
)

from config import Arguments
from logger_config import logger
from utils import move_to_cuda, save_json_to_file
from metrics import compute_mrr, trec_eval, ScoredDoc
from data_utils import load_queries, load_qrels, load_msmarco_predictions, save_preds_to_msmarco_format
from models import BiencoderModelForInference, BiencoderOutput

# NOTE: arguments are parsed at import time, so every spawned worker process
# re-parses the same command line when it re-imports this module.
parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]
assert os.path.exists(args.encode_save_dir)


def _get_all_shards_path() -> List[str]:
    """Return all embedding shard paths under ``args.encode_save_dir``.

    Files are matched with the glob ``shard_*_*`` and sorted numerically by the
    last two ``_``-separated fields of the base name (worker index, then shard
    index). Asserts that at least one shard exists.
    """
    path_list = glob.glob('{}/shard_*_*'.format(args.encode_save_dir))
    assert len(path_list) > 0

    def _parse_worker_idx_shard_idx(p: str) -> Tuple:
        # Numeric sort key: lexical ordering would put shard 10 before shard 2.
        worker_idx, shard_idx = [int(f) for f in os.path.basename(p).split('_')[-2:]]
        return worker_idx, shard_idx

    path_list = sorted(path_list, key=lambda path: _parse_worker_idx_shard_idx(path))
    logger.info('Embeddings path list: {}'.format(path_list))
    return path_list


def _get_topk_result_save_path(worker_idx: int) -> str:
    """Return the path of the intermediate ranking file written by one worker."""
    return '{}/top{}_{}_{}.txt'.format(args.search_out_dir, args.search_topk, args.search_split, worker_idx)


def _query_transform_func(tokenizer: PreTrainedTokenizerFast,
                          examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize the ``'query'`` field of a batch of examples.

    Sequences are truncated to ``args.q_max_len`` and left unpadded; padding is
    applied later by the data collator.
    """
    batch_dict = tokenizer(examples['query'],
                           max_length=args.q_max_len,
                           padding=PaddingStrategy.DO_NOT_PAD,
                           truncation=True)

    return batch_dict


@torch.no_grad()
def _worker_encode_queries(gpu_idx: int) -> Tuple:
    """Encode the slice of queries assigned to GPU ``gpu_idx``.

    Queries are loaded from ``<data_dir>/<search_split>_queries.tsv``, sorted by
    query id, and split into ``torch.cuda.device_count()`` contiguous shards;
    this worker keeps the shard at index ``gpu_idx``.

    Returns:
        A tuple ``(query_embeds, query_ids, query_id_to_text)`` where
        ``query_embeds`` is the concatenation of the model's ``q_reps`` over all
        batches, and the other two are restricted to this worker's queries.
    """
    # fail fast if shard does not exist
    _get_all_shards_path()

    query_id_to_text = load_queries(path=os.path.join(args.data_dir, '{}_queries.tsv'.format(args.search_split)),
                                    task_type=args.task_type)
    query_ids = sorted(list(query_id_to_text.keys()))
    queries = [query_id_to_text[query_id] for query_id in query_ids]
    dataset = Dataset.from_dict({'query_id': query_ids,
                                 'query': queries})
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)

    # only keep data for current shard
    query_ids = dataset['query_id']
    query_id_to_text = {qid: query_id_to_text[qid] for qid in query_ids}

    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))
    torch.cuda.set_device(gpu_idx)

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: BiencoderModelForInference = BiencoderModelForInference.build(args)
    model.eval()
    model.cuda()

    dataset.set_transform(partial(_query_transform_func, tokenizer))

    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    data_loader = DataLoader(
        dataset,
        batch_size=512,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)

    encoded_embeds = []
    for batch_dict in tqdm.tqdm(data_loader, desc='query encoding', mininterval=5):
        batch_dict = move_to_cuda(batch_dict)

        # Mixed precision only when requested via args.fp16.
        with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
            outputs: BiencoderOutput = model(query=batch_dict, passage=None)
        encoded_embeds.append(outputs.q_reps)

    query_embeds = torch.cat(encoded_embeds, dim=0)
    logger.info('Done query encoding for worker {}'.format(gpu_idx))

    return query_embeds, query_ids, query_id_to_text


@torch.no_grad()
def _worker_batch_search(gpu_idx: int):
    """Search this worker's queries against every embedding shard and save the top-k.

    Scores are the matrix product of query embeddings and shard passage
    embeddings. For each query, a running top-``args.search_topk`` list of
    ``(score, passage_index)`` pairs is maintained across shards, ordered by
    descending score with ties broken by ascending passage index. The passage
    index is the row index within a shard plus the number of rows in all
    previously processed shards.

    Results are written to ``_get_topk_result_save_path(gpu_idx)`` as
    tab-separated ``query_id, passage_index, rank, score`` lines, with the score
    rounded to 4 decimals.
    """
    embeds_path_list = _get_all_shards_path()

    query_embeds, query_ids, query_id_to_text = _worker_encode_queries(gpu_idx)
    assert query_embeds.shape[0] == len(query_ids), '{} != {}'.format(query_embeds.shape[0], len(query_ids))

    query_id_to_topk = defaultdict(list)
    psg_idx_offset = 0
    for shard_idx, shard_path in enumerate(embeds_path_list):
        # Load onto CPU storage first, then move to the device holding the queries.
        shard_psg_embed = torch.load(shard_path, map_location=lambda storage, loc: storage).to(query_embeds.device)
        logger.info('Load {} passage embeddings from {}'.format(shard_psg_embed.shape[0], shard_path))

        # torch.topk raises if k exceeds the size of the searched dimension, so
        # clamp k for shards holding fewer passages than the requested top-k.
        shard_topk = min(args.search_topk, shard_psg_embed.shape[0])

        for start in tqdm.tqdm(range(0, len(query_ids), args.search_batch_size),
                               desc="search shard {}".format(shard_idx),
                               mininterval=5):
            batch_query_embed = query_embeds[start:(start + args.search_batch_size)]
            batch_query_ids = query_ids[start:(start + args.search_batch_size)]
            batch_score = torch.mm(batch_query_embed, shard_psg_embed.t())
            batch_sorted_score, batch_sorted_indices = torch.topk(batch_score, k=shard_topk, dim=-1, largest=True)

            # One device-to-host transfer per batch instead of two per query.
            batch_scores_list = batch_sorted_score.cpu().tolist()
            batch_indices_list = batch_sorted_indices.cpu().tolist()

            for batch_idx, query_id in enumerate(batch_query_ids):
                cur_scores = batch_scores_list[batch_idx]
                cur_indices = [idx + psg_idx_offset for idx in batch_indices_list[batch_idx]]
                query_id_to_topk[query_id] += list(zip(cur_scores, cur_indices))
                # Merge with results from earlier shards and keep only the best top-k.
                query_id_to_topk[query_id] = sorted(query_id_to_topk[query_id], key=lambda t: (-t[0], t[1]))
                query_id_to_topk[query_id] = query_id_to_topk[query_id][:args.search_topk]

        psg_idx_offset += shard_psg_embed.shape[0]

    out_path = _get_topk_result_save_path(worker_idx=gpu_idx)
    with open(out_path, 'w', encoding='utf-8') as writer:
        for query_id in query_id_to_text:
            for rank, (score, doc_id) in enumerate(query_id_to_topk[query_id]):
                writer.write('{}\t{}\t{}\t{}\n'.format(query_id, doc_id, rank + 1, round(score, 4)))

    logger.info('Write scores to {} done'.format(out_path))


def _compute_and_save_metrics(worker_cnt: int):
    """Merge per-worker rankings, compute metrics if qrels exist, then clean up.

    The merged predictions are saved to
    ``<search_out_dir>/<search_split>.msmarco.txt``. When
    ``<data_dir>/<search_split>_qrels.txt`` exists, trec metrics and MRR are
    computed, logged, and saved to ``<search_out_dir>/metrics_<search_split>.json``;
    otherwise a warning is logged. The per-worker ranking files are removed in
    either case.
    """
    preds: Dict[str, List[ScoredDoc]] = {}
    for worker_idx in range(worker_cnt):
        path = _get_topk_result_save_path(worker_idx)
        preds.update(load_msmarco_predictions(path))
    out_path = os.path.join(args.search_out_dir, '{}.msmarco.txt'.format(args.search_split))
    save_preds_to_msmarco_format(preds, out_path)
    logger.info('Merge done: save {} predictions to {}'.format(len(preds), out_path))

    path_qrels = os.path.join(args.data_dir, '{}_qrels.txt'.format(args.search_split))
    if os.path.exists(path_qrels):
        qrels = load_qrels(path=path_qrels)
        all_metrics = trec_eval(qrels=qrels, predictions=preds)
        all_metrics['mrr'] = compute_mrr(qrels=qrels, predictions=preds)

        logger.info('{} trec metrics = {}'.format(args.search_split, json.dumps(all_metrics, ensure_ascii=False, indent=4)))
        save_json_to_file(all_metrics, os.path.join(args.search_out_dir, 'metrics_{}.json'.format(args.search_split)))
    else:
        logger.warning('No qrels found for {}'.format(args.search_split))

    # do some cleanup
    for worker_idx in range(worker_cnt):
        path = _get_topk_result_save_path(worker_idx)
        os.remove(path)


def _batch_search_queries():
    """Entry point: run one search worker per visible GPU, then merge and evaluate.

    Logs an error and returns without doing anything when no GPU is available.
    """
    logger.info('Args={}'.format(str(args)))
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        logger.error('No gpu available')
        return

    logger.info('Use {} gpus'.format(gpu_count))

    # Workers open their result files for writing inside this directory; make
    # sure it exists before spawning so they do not fail only after the search.
    if args.search_out_dir:
        os.makedirs(args.search_out_dir, exist_ok=True)

    # spawn passes the process index as the first positional argument (gpu_idx).
    torch.multiprocessing.spawn(_worker_batch_search, args=(), nprocs=gpu_count)
    logger.info('Done batch search queries')

    _compute_and_save_metrics(gpu_count)


if __name__ == '__main__':
    _batch_search_queries()