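"""Multi-GPU passage reranking for MS MARCO-style retrieval results.

Each available GPU scores one contiguous shard of the top-ranked
(query, passage) pairs from a first-stage retriever with a cross-encoder,
writes its scores to a per-worker shard file, and the shard files are
merged into the final reranked predictions at the end.
"""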
import os
import tqdm
import torch
from contextlib import nullcontext
from torch.utils.data import DataLoader
from functools import partial
from datasets import Dataset
from typing import Dict, List
from transformers.file_utils import PaddingStrategy
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    HfArgumentParser,
    BatchEncoding,
)
from config import Arguments
from logger_config import logger
from utils import move_to_cuda
from models import RerankerForInference
from data_utils import (
    load_msmarco_predictions,
    load_corpus,
    load_queries,
    merge_rerank_predictions,
    get_rerank_shard_path,
)

parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]
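# NOTE: args is parsed at import time. Workers launched via
# torch.multiprocessing.spawn re-import this module and therefore re-parse
# the same command line, which is how each worker sees the same `args`.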


def _rerank_transform_func(tokenizer: PreTrainedTokenizerFast,
                           corpus: Dataset,
                           queries: Dict[str, str],
                           examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize a batch of (query, passage) pairs for the cross-encoder."""
    input_docs: List[str] = []
    # ATTENTION: this code should be consistent with RerankDataLoader
    for doc_id in examples['doc_id']:
        doc_id = int(doc_id)
        # Prepend the passage title, when present, to its contents
        prefix = ''
        if corpus[doc_id].get('title', ''):
            prefix = corpus[doc_id]['title'] + ': '
        input_docs.append(prefix + corpus[doc_id]['contents'])

    input_queries = [queries[query_id] for query_id in examples['query_id']]
    batch_dict = tokenizer(input_queries,
                           text_pair=input_docs,
                           max_length=args.rerank_max_length,
                           padding=PaddingStrategy.DO_NOT_PAD,
                           truncation=True)
    return batch_dict
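
# The transform above returns unpadded features on purpose:
# DataCollatorWithPadding in the worker below pads each batch to its own
# longest sequence (and to a multiple of 8 when fp16 is enabled, which is
# friendlier to tensor cores).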


@torch.no_grad()
def _worker_compute_reranker_score(gpu_idx: int):
    """Score one contiguous shard of (query, passage) pairs on a single GPU."""
    preds = load_msmarco_predictions(args.rerank_in_path)
    query_ids = sorted(list(preds.keys()))
    qid_pid = []
    for query_id in tqdm.tqdm(query_ids, desc='load qid-pid', mininterval=2):
        # Only rerank candidates within the configured depth of the first stage
        qid_pid += [(scored_doc.qid, scored_doc.pid) for scored_doc in preds[query_id]
                    if scored_doc.rank <= args.rerank_depth]

    dataset = Dataset.from_dict({'query_id': [t[0] for t in qid_pid],
                                 'doc_id': [t[1] for t in qid_pid]})
    # Each GPU worker handles one contiguous shard of the full pair list
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)
    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))
    torch.cuda.set_device(gpu_idx)

    query_ids, doc_ids = dataset['query_id'], dataset['doc_id']
    assert len(dataset) == len(query_ids)

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: RerankerForInference = RerankerForInference.from_pretrained(args.model_name_or_path)
    model.eval()
    model.cuda()

    corpus: Dataset = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz'))
    queries = load_queries(path='{}/{}_queries.tsv'.format(args.data_dir, args.rerank_split),
                           task_type=args.task_type)
    # Tokenize lazily on access; padding is left to the collator
    dataset.set_transform(partial(_rerank_transform_func, tokenizer, corpus, queries))

    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
    data_loader = DataLoader(
        dataset,
        batch_size=args.rerank_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)

    scores = []
    for batch_dict in tqdm.tqdm(data_loader, desc='passage rerank', mininterval=5):
        batch_dict = move_to_cuda(batch_dict)
        with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
            outputs: SequenceClassifierOutput = model(batch_dict)
        scores.append(outputs.logits.squeeze(dim=-1).cpu())
        assert len(scores[-1].shape) == 1

    all_scores = torch.cat(scores, dim=-1)
    assert all_scores.shape[0] == len(query_ids), '{} != {}'.format(all_scores.shape[0], len(query_ids))
    all_scores = all_scores.tolist()

    with open(get_rerank_shard_path(args, gpu_idx), 'w', encoding='utf-8') as writer:
        for idx in range(len(query_ids)):
            # dummy rank, since a query may be split across different workers
            writer.write('{}\t{}\t{}\t{}\n'.format(query_ids[idx], doc_ids[idx], -1, round(all_scores[idx], 5)))
    logger.info('Done computing rerank score for worker {}'.format(gpu_idx))


def _batch_compute_reranker_score():
    logger.info('Args={}'.format(str(args)))
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        logger.error('No gpu available')
        return

    logger.info('Use {} gpus'.format(gpu_count))
    # spawn passes the process index (used as the GPU index) as the first argument
    torch.multiprocessing.spawn(_worker_compute_reranker_score, args=(), nprocs=gpu_count)
    logger.info('Done batch compute rerank score')

    merge_rerank_predictions(args, gpu_count)
    logger.info('Done merge results')


if __name__ == '__main__':
    _batch_compute_reranker_score()
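
# Example invocation -- a sketch only. HfArgumentParser turns each field of the
# Arguments dataclass into a CLI flag, so the flags below are inferred from the
# `args.<field>` accesses in this file; the script name, paths, and values are
# placeholders. Verify flag names against the Arguments dataclass in config.py.
#
#   python rerank_main.py \
#       --model_name_or_path /path/to/cross-encoder-reranker \
#       --task_type msmarco \
#       --data_dir ./data \
#       --rerank_split dev \
#       --rerank_in_path ./preds/dev_retriever_top1000.txt \
#       --rerank_depth 100 \
#       --rerank_max_length 192 \
#       --rerank_batch_size 128 \
#       --dataloader_num_workers 4 \
#       --fp16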