# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
import os
import tqdm
import torch
from contextlib import nullcontext
from torch.utils.data import DataLoader
from functools import partial
from datasets import Dataset
from typing import Dict, List
from transformers.file_utils import PaddingStrategy
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
AutoTokenizer,
PreTrainedTokenizerFast,
DataCollatorWithPadding,
HfArgumentParser,
BatchEncoding
)
from config import Arguments
from logger_config import logger
from utils import move_to_cuda
from models import RerankerForInference
from data_utils import load_msmarco_predictions, load_corpus, load_queries, \
merge_rerank_predictions, get_rerank_shard_path
# Parse the command-line flags into the project's ``Arguments`` dataclass once,
# at import time. ``args`` is then read as a module-level global by the
# functions defined in this file.
# NOTE(review): processes started through ``torch.multiprocessing.spawn``
# presumably re-import this module and re-parse ``sys.argv`` -- confirm.
parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]
def _rerank_transform_func(tokenizer: PreTrainedTokenizerFast,
                           corpus: Dataset,
                           queries: Dict[str, str],
                           examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize a batch of (query, passage) pairs for the reranker.

    Meant to be used as an on-the-fly dataset transform with the first three
    arguments already bound (e.g. through ``functools.partial``).

    Args:
        tokenizer: Tokenizer called on the query / passage text pairs.
        corpus: Passage collection, indexed with ``int(doc_id)``. Each row
            provides ``contents`` and, optionally, a ``title``.
        queries: Mapping from query id to query text.
        examples: Batch holding parallel ``query_id`` and ``doc_id`` lists.

    Returns:
        The tokenizer output for the batch, truncated to
        ``args.rerank_max_length`` and left unpadded (padding is the job of
        whoever collates the batch).
    """
    input_docs: List[str] = []
    # ATTENTION: this code should be consistent with RerankDataLoader
    for doc_id in examples['doc_id']:
        # Fetch the row a single time: every ``corpus[i]`` lookup builds the
        # whole row again, and the original code did it up to three times.
        doc = corpus[int(doc_id)]
        title = doc.get('title', '')
        # Passages with a non-empty title are rendered as "<title>: <contents>".
        prefix = title + ': ' if title else ''
        input_docs.append(prefix + doc['contents'])

    input_queries = [queries[query_id] for query_id in examples['query_id']]

    return tokenizer(input_queries,
                     text_pair=input_docs,
                     max_length=args.rerank_max_length,
                     padding=PaddingStrategy.DO_NOT_PAD,
                     truncation=True)
@torch.no_grad()
def _worker_compute_reranker_score(gpu_idx: int):
    """Score one GPU's shard of (query, passage) pairs and write it to disk.

    The candidate pairs are read from ``args.rerank_in_path``, limited to the
    top ``args.rerank_depth`` ranks per query, and split into one contiguous
    shard per visible CUDA device. This worker scores shard ``gpu_idx`` on
    device ``gpu_idx`` and writes one tab-separated line per pair
    (query id, doc id, rank placeholder ``-1``, score rounded to 5 decimals)
    to the path returned by ``get_rerank_shard_path(args, gpu_idx)``.

    Args:
        gpu_idx: Index of the CUDA device to use; also the index of the shard
            this worker is responsible for.
    """
    preds = load_msmarco_predictions(args.rerank_in_path)

    # Visit queries in sorted order so that every worker builds the exact same
    # global list of pairs before sharding it.
    qid_pid = []
    for query_id in tqdm.tqdm(sorted(preds.keys()), desc='load qid-pid', mininterval=2):
        qid_pid.extend((scored_doc.qid, scored_doc.pid)
                       for scored_doc in preds[query_id]
                       if scored_doc.rank <= args.rerank_depth)

    dataset = Dataset.from_dict({'query_id': [qid for qid, _ in qid_pid],
                                 'doc_id': [pid for _, pid in qid_pid]})
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)

    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))
    torch.cuda.set_device(gpu_idx)

    # Grab the raw id columns now, before the tokenizing transform is attached
    # to the dataset further down.
    shard_query_ids, shard_doc_ids = dataset['query_id'], dataset['doc_id']
    assert len(dataset) == len(shard_query_ids)

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: RerankerForInference = RerankerForInference.from_pretrained(args.model_name_or_path)
    model.eval()
    model.cuda()

    corpus: Dataset = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz'))
    queries = load_queries(path='{}/{}_queries.tsv'.format(args.data_dir, args.rerank_split),
                           task_type=args.task_type)
    dataset.set_transform(partial(_rerank_transform_func, tokenizer, corpus, queries))

    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
    # shuffle=False / drop_last=False: scores must line up one-to-one, in
    # order, with the id columns captured above.
    data_loader = DataLoader(
        dataset,
        batch_size=args.rerank_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)

    scores = []
    for batch_dict in tqdm.tqdm(data_loader, desc='passage rerank', mininterval=5):
        batch_dict = move_to_cuda(batch_dict)

        with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
            outputs: SequenceClassifierOutput = model(batch_dict)
        scores.append(outputs.logits.squeeze(dim=-1).cpu())
        assert len(scores[-1].shape) == 1

    # A shard can be empty when there are fewer pairs than GPUs; torch.cat()
    # rejects an empty list, so fall back to "no scores" instead of crashing.
    all_scores = torch.cat(scores, dim=-1).tolist() if scores else []
    assert len(all_scores) == len(shard_query_ids), \
        '{} != {}'.format(len(all_scores), len(shard_query_ids))

    # The shard file is written even when it is empty.
    # NOTE(review): assumes merge_rerank_predictions expects one file per
    # worker and tolerates an empty one -- confirm.
    with open(get_rerank_shard_path(args, gpu_idx), 'w', encoding='utf-8') as writer:
        for query_id, doc_id, score in zip(shard_query_ids, shard_doc_ids, all_scores):
            # dummy rank, since a query may be split across different workers
            writer.write('{}\t{}\t{}\t{}\n'.format(query_id, doc_id, -1, round(score, 5)))

    logger.info('Done computing rerank score for worker {}'.format(gpu_idx))
def _batch_compute_reranker_score():
    """Rerank on every visible GPU, then merge the per-GPU results.

    Starts one ``_worker_compute_reranker_score`` process per CUDA device and
    afterwards calls ``merge_rerank_predictions`` with the number of workers.
    When no CUDA device is visible, an error is logged and the function
    returns without doing any work.
    """
    logger.info('Args={}'.format(str(args)))

    num_gpus = torch.cuda.device_count()
    if not num_gpus:
        logger.error('No gpu available')
        return
    logger.info('Use {} gpus'.format(num_gpus))

    # spawn() hands each process its index as the first positional argument;
    # the worker uses that index as its GPU / shard index.
    torch.multiprocessing.spawn(_worker_compute_reranker_score, args=(), nprocs=num_gpus)
    logger.info('Done batch compute rerank score')

    merge_rerank_predictions(args, num_gpus)
    logger.info('Done merge results')
# Script entry point: only start the multi-GPU rerank job when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    _batch_compute_reranker_score()