import os
import tqdm
import torch

from contextlib import nullcontext
from torch.utils.data import DataLoader
from functools import partial
from datasets import Dataset
from typing import Dict, List
from transformers.file_utils import PaddingStrategy
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    HfArgumentParser,
    BatchEncoding,
)

from config import Arguments
from logger_config import logger
from utils import move_to_cuda
from models import RerankerForInference
from data_utils import load_msmarco_predictions, load_corpus, load_queries, \
    merge_rerank_predictions, get_rerank_shard_path

parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]
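

# Note: `args` is parsed at module import time, so when
# torch.multiprocessing.spawn re-imports this module in each child process,
# every worker ends up with the same parsed arguments.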
def _rerank_transform_func(tokenizer: PreTrainedTokenizerFast,
                           corpus: Dataset,
                           queries: Dict[str, str],
                           examples: Dict[str, List]) -> BatchEncoding:
    input_docs: List[str] = []

    # ATTENTION: this code should be consistent with RerankDataLoader
    for doc_id in examples['doc_id']:
        doc_id = int(doc_id)
        # Prepend the passage title (if any) to its contents
        prefix = ''
        if corpus[doc_id].get('title', ''):
            prefix = corpus[doc_id]['title'] + ': '
        input_docs.append(prefix + corpus[doc_id]['contents'])

    input_queries = [queries[query_id] for query_id in examples['query_id']]
    # Tokenize query-passage pairs; padding is deferred to the data collator
    batch_dict = tokenizer(input_queries,
                           text_pair=input_docs,
                           max_length=args.rerank_max_length,
                           padding=PaddingStrategy.DO_NOT_PAD,
                           truncation=True)
    return batch_dict
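

# Per-GPU worker: take a contiguous shard of (query_id, doc_id) candidate
# pairs, score each pair with the cross-encoder, and dump one shard file
# for merge_rerank_predictions to combine afterwards.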
def _worker_compute_reranker_score(gpu_idx: int):
    preds = load_msmarco_predictions(args.rerank_in_path)
    query_ids = sorted(list(preds.keys()))
    qid_pid = []
    for query_id in tqdm.tqdm(query_ids, desc='load qid-pid', mininterval=2):
        # Only rerank the top candidates from the first-stage retriever
        qid_pid += [(scored_doc.qid, scored_doc.pid) for scored_doc in preds[query_id]
                    if scored_doc.rank <= args.rerank_depth]

    dataset = Dataset.from_dict({'query_id': [t[0] for t in qid_pid],
                                 'doc_id': [t[1] for t in qid_pid]})
    # Contiguous sharding keeps pairs for the same query (mostly) together
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)
    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))
    torch.cuda.set_device(gpu_idx)

    query_ids, doc_ids = dataset['query_id'], dataset['doc_id']
    assert len(dataset) == len(query_ids)
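
    # Each worker loads its own copy of the tokenizer, reranker, passage
    # corpus, and query set; nothing is shared across processes.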
    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: RerankerForInference = RerankerForInference.from_pretrained(args.model_name_or_path)
    model.eval()
    model.cuda()

    corpus: Dataset = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz'))
    queries = load_queries(path='{}/{}_queries.tsv'.format(args.data_dir, args.rerank_split),
                           task_type=args.task_type)

    # Tokenization happens lazily, batch by batch, via set_transform;
    # padding to a multiple of 8 keeps fp16 shapes Tensor Core-friendly
    dataset.set_transform(partial(_rerank_transform_func, tokenizer, corpus, queries))
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
    data_loader = DataLoader(
        dataset,
        batch_size=args.rerank_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)
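
    # Inference loop: autocast gives mixed-precision matmuls when fp16 is set.
    # RerankerForInference is assumed to disable gradients internally; if it
    # does not, wrapping this loop in torch.no_grad() would avoid building
    # autograd graphs during scoring.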
    scores = []
    for batch_dict in tqdm.tqdm(data_loader, desc='passage rerank', mininterval=5):
        batch_dict = move_to_cuda(batch_dict)
        with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
            outputs: SequenceClassifierOutput = model(batch_dict)
        scores.append(outputs.logits.squeeze(dim=-1).cpu())
        assert len(scores[-1].shape) == 1

    all_scores = torch.cat(scores, dim=-1)
    assert all_scores.shape[0] == len(query_ids), '{} != {}'.format(all_scores.shape[0], len(query_ids))
    all_scores = all_scores.tolist()
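
    # Output format: "query_id\tdoc_id\trank\tscore", one line per pair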
    with open(get_rerank_shard_path(args, gpu_idx), 'w', encoding='utf-8') as writer:
        for idx in range(len(query_ids)):
            # dummy rank, since a query may be split across different workers
            writer.write('{}\t{}\t{}\t{}\n'.format(query_ids[idx], doc_ids[idx], -1, round(all_scores[idx], 5)))

    logger.info('Done computing rerank score for worker {}'.format(gpu_idx))
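

# Driver: spawns one scoring worker per visible GPU, then merges the
# per-worker shard files into the final prediction file.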
def _batch_compute_reranker_score():
    logger.info('Args={}'.format(str(args)))
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        logger.error('No gpu available')
        return
    logger.info('Use {} gpus'.format(gpu_count))

    torch.multiprocessing.spawn(_worker_compute_reranker_score, args=(), nprocs=gpu_count)
    logger.info('Done batch compute rerank score')

    merge_rerank_predictions(args, gpu_count)
    logger.info('Done merge results')


if __name__ == '__main__':
    _batch_compute_reranker_score()
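
# Example invocation (hypothetical script name, paths, and values; the
# accepted flags come from the Arguments dataclass in config.py):
#
#   python rerank_main.py \
#       --model_name_or_path /path/to/reranker \
#       --data_dir /path/to/data \
#       --rerank_in_path /path/to/first_stage_predictions.txt \
#       --rerank_split dev \
#       --rerank_depth 100 \
#       --rerank_batch_size 128 \
#       --fp16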