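"""Multi-GPU brute-force retrieval over pre-computed passage embeddings.

Each spawned worker (one per GPU) encodes its shard of the queries with the
bi-encoder, scores them against every passage-embedding shard found under
args.encode_save_dir, and writes its top-k results to disk. The parent
process then merges the per-worker files into MS MARCO prediction format
and, when a qrels file is available for the split, reports TREC metrics
and MRR.
"""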
import json
import os
import glob
import tqdm
import torch

from contextlib import nullcontext
from torch.utils.data import DataLoader
from functools import partial
from collections import defaultdict
from datasets import Dataset
from typing import Dict, List, Tuple
from transformers.file_utils import PaddingStrategy
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    HfArgumentParser,
    BatchEncoding
)

from config import Arguments
from logger_config import logger
from utils import move_to_cuda, save_json_to_file
from metrics import compute_mrr, trec_eval, ScoredDoc
from data_utils import load_queries, load_qrels, load_msmarco_predictions, save_preds_to_msmarco_format
from models import BiencoderModelForInference, BiencoderOutput

parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]
assert os.path.exists(args.encode_save_dir)
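
# Passage embedding shards are expected to come from a separate encoding pass
# and to be named 'shard_{worker_idx}_{shard_idx}' (inferred from the parsing
# logic below). Sorting by (worker_idx, shard_idx) keeps the global passage
# index order stable across runs.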
def _get_all_shards_path() -> List[str]:
    path_list = glob.glob('{}/shard_*_*'.format(args.encode_save_dir))
    assert len(path_list) > 0

    def _parse_worker_idx_shard_idx(p: str) -> Tuple:
        worker_idx, shard_idx = [int(f) for f in os.path.basename(p).split('_')[-2:]]
        return worker_idx, shard_idx

    path_list = sorted(path_list, key=lambda path: _parse_worker_idx_shard_idx(path))
    logger.info('Embeddings path list: {}'.format(path_list))
    return path_list

def _get_topk_result_save_path(worker_idx: int) -> str:
    return '{}/top{}_{}_{}.txt'.format(args.search_out_dir, args.search_topk, args.search_split, worker_idx)
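
# Tokenize without padding here; DataCollatorWithPadding pads each batch
# dynamically at collation time (to a multiple of 8, which tends to suit
# tensor-core kernels better than per-example padding).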
def _query_transform_func(tokenizer: PreTrainedTokenizerFast,
                          examples: Dict[str, List]) -> BatchEncoding:
    batch_dict = tokenizer(examples['query'],
                           max_length=args.q_max_len,
                           padding=PaddingStrategy.DO_NOT_PAD,
                           truncation=True)
    return batch_dict
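
# Each worker encodes a contiguous shard of the sorted query list, so the
# per-worker outputs together cover every query exactly once. Returns
# (query embeddings, query ids, id -> text) for this worker's shard only.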
@torch.no_grad()
def _worker_encode_queries(gpu_idx: int) -> Tuple:
    # fail fast if the embedding shards do not exist yet
    _get_all_shards_path()

    query_id_to_text = load_queries(path=os.path.join(args.data_dir, '{}_queries.tsv'.format(args.search_split)),
                                    task_type=args.task_type)
    query_ids = sorted(list(query_id_to_text.keys()))
    queries = [query_id_to_text[query_id] for query_id in query_ids]
    dataset = Dataset.from_dict({'query_id': query_ids,
                                 'query': queries})
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)

    # only keep data for the current shard
    query_ids = dataset['query_id']
    query_id_to_text = {qid: query_id_to_text[qid] for qid in query_ids}
    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))

    torch.cuda.set_device(gpu_idx)
    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: BiencoderModelForInference = BiencoderModelForInference.build(args)
    model.eval()
    model.cuda()

    dataset.set_transform(partial(_query_transform_func, tokenizer))
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    data_loader = DataLoader(
        dataset,
        batch_size=512,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)

    encoded_embeds = []
    for batch_dict in tqdm.tqdm(data_loader, desc='query encoding', mininterval=5):
        batch_dict = move_to_cuda(batch_dict)
        with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
            outputs: BiencoderOutput = model(query=batch_dict, passage=None)
        encoded_embeds.append(outputs.q_reps)

    query_embeds = torch.cat(encoded_embeds, dim=0)
    logger.info('Done query encoding for worker {}'.format(gpu_idx))
    return query_embeds, query_ids, query_id_to_text
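
# Exact (brute-force) search: for each passage-embedding shard, score this
# worker's queries against the whole shard with a single matmul, take the
# shard-local top-k, and fold it into a running global top-k per query.
# The sort key (-score, index) ranks by descending score and breaks ties on
# the smaller global passage index. psg_idx_offset converts shard-local row
# indices into global passage indices, which are what get written out as
# doc ids.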
@torch.no_grad()
def _worker_batch_search(gpu_idx: int):
    embeds_path_list = _get_all_shards_path()
    query_embeds, query_ids, query_id_to_text = _worker_encode_queries(gpu_idx)
    assert query_embeds.shape[0] == len(query_ids), '{} != {}'.format(query_embeds.shape[0], len(query_ids))

    query_id_to_topk = defaultdict(list)
    psg_idx_offset = 0
    for shard_idx, shard_path in enumerate(embeds_path_list):
        shard_psg_embed = torch.load(shard_path, map_location=lambda storage, loc: storage).to(query_embeds.device)
        logger.info('Load {} passage embeddings from {}'.format(shard_psg_embed.shape[0], shard_path))

        for start in tqdm.tqdm(range(0, len(query_ids), args.search_batch_size),
                               desc="search shard {}".format(shard_idx),
                               mininterval=5):
            batch_query_embed = query_embeds[start:(start + args.search_batch_size)]
            batch_query_ids = query_ids[start:(start + args.search_batch_size)]
            batch_score = torch.mm(batch_query_embed, shard_psg_embed.t())
            batch_sorted_score, batch_sorted_indices = torch.topk(batch_score, k=args.search_topk, dim=-1, largest=True)
            for batch_idx, query_id in enumerate(batch_query_ids):
                cur_scores = batch_sorted_score[batch_idx].cpu().tolist()
                cur_indices = [idx + psg_idx_offset for idx in batch_sorted_indices[batch_idx].cpu().tolist()]
                query_id_to_topk[query_id] += list(zip(cur_scores, cur_indices))
                query_id_to_topk[query_id] = sorted(query_id_to_topk[query_id], key=lambda t: (-t[0], t[1]))
                query_id_to_topk[query_id] = query_id_to_topk[query_id][:args.search_topk]

        psg_idx_offset += shard_psg_embed.shape[0]

    out_path = _get_topk_result_save_path(worker_idx=gpu_idx)
    with open(out_path, 'w', encoding='utf-8') as writer:
        for query_id in query_id_to_text:
            for rank, (score, doc_id) in enumerate(query_id_to_topk[query_id]):
                writer.write('{}\t{}\t{}\t{}\n'.format(query_id, doc_id, rank + 1, round(score, 4)))
    logger.info('Write scores to {} done'.format(out_path))
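
# Merge the per-worker top-k files, save them in MS MARCO prediction format,
# and report TREC metrics plus MRR when a qrels file exists for this split
# (evaluation is skipped with a warning otherwise, as is typical for test
# splits that ship without relevance judgments).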
def _compute_and_save_metrics(worker_cnt: int):
    preds: Dict[str, List[ScoredDoc]] = {}
    for worker_idx in range(worker_cnt):
        path = _get_topk_result_save_path(worker_idx)
        preds.update(load_msmarco_predictions(path))
    out_path = os.path.join(args.search_out_dir, '{}.msmarco.txt'.format(args.search_split))
    save_preds_to_msmarco_format(preds, out_path)
    logger.info('Merge done: save {} predictions to {}'.format(len(preds), out_path))

    path_qrels = os.path.join(args.data_dir, '{}_qrels.txt'.format(args.search_split))
    if os.path.exists(path_qrels):
        qrels = load_qrels(path=path_qrels)
        all_metrics = trec_eval(qrels=qrels, predictions=preds)
        all_metrics['mrr'] = compute_mrr(qrels=qrels, predictions=preds)
        logger.info('{} trec metrics = {}'.format(args.search_split, json.dumps(all_metrics, ensure_ascii=False, indent=4)))
        save_json_to_file(all_metrics, os.path.join(args.search_out_dir, 'metrics_{}.json'.format(args.search_split)))
    else:
        logger.warning('No qrels found for {}'.format(args.search_split))

    # remove the per-worker top-k files now that they are merged
    for worker_idx in range(worker_cnt):
        path = _get_topk_result_save_path(worker_idx)
        os.remove(path)
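
# Entry point: spawn one search worker per visible GPU.
# torch.multiprocessing.spawn passes the process index as the first
# positional argument, which is how each worker learns its gpu_idx.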
def _batch_search_queries():
    logger.info('Args={}'.format(str(args)))
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        logger.error('No gpu available')
        return
    logger.info('Use {} gpus'.format(gpu_count))

    torch.multiprocessing.spawn(_worker_batch_search, args=(), nprocs=gpu_count)
    logger.info('Done batch search queries')

    _compute_and_save_metrics(gpu_count)
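
# Illustrative invocation (a sketch: flag names mirror the fields of the
# Arguments dataclass in config.py, which HfArgumentParser exposes as CLI
# options; the script name, paths, and values below are placeholders):
#
#   python search_main.py \
#       --model_name_or_path ./checkpoint \
#       --encode_save_dir ./encoded \
#       --search_out_dir ./search_results \
#       --data_dir ./msmarco_data \
#       --search_split dev \
#       --search_topk 200 \
#       --search_batch_size 128 \
#       --fp16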
if __name__ == '__main__':
    _batch_search_queries()