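"""Multi-GPU passage reranking for MS-MARCO-style retrieval output: each GPU
worker scores a contiguous shard of (query, passage) candidate pairs with a
cross-encoder reranker and writes a partial TSV; the shards are merged at the end.
"""
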
import os
import tqdm
import torch

from contextlib import nullcontext
from torch.utils.data import DataLoader
from functools import partial
from datasets import Dataset
from typing import Dict, List
from transformers.file_utils import PaddingStrategy
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    HfArgumentParser,
    BatchEncoding
)

from config import Arguments
from logger_config import logger
from utils import move_to_cuda
from models import RerankerForInference
from data_utils import load_msmarco_predictions, load_corpus, load_queries, \
    merge_rerank_predictions, get_rerank_shard_path

parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]


def _rerank_transform_func(tokenizer: PreTrainedTokenizerFast,
                           corpus: Dataset,
                           queries: Dict[str, str],
                           examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize a batch of (query, passage) pairs on the fly; padding is
    deferred to the data collator."""
    input_docs: List[str] = []

    # ATTENTION: this code should be consistent with RerankDataLoader
    for doc_id in examples['doc_id']:
        doc_id = int(doc_id)
        prefix = ''
        # prepend the passage title, when present, to its contents
        if corpus[doc_id].get('title', ''):
            prefix = corpus[doc_id]['title'] + ': '
        input_docs.append(prefix + corpus[doc_id]['contents'])

    input_queries = [queries[query_id] for query_id in examples['query_id']]
    batch_dict = tokenizer(input_queries,
                           text_pair=input_docs,
                           max_length=args.rerank_max_length,
                           padding=PaddingStrategy.DO_NOT_PAD,
                           truncation=True)

    return batch_dict
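
# A minimal sketch (hypothetical toy values) of how the transform above is
# exercised once attached via Dataset.set_transform:
#
#   corpus = Dataset.from_dict({'title': ['greeting'], 'contents': ['hello world']})
#   queries = {'q1': 'what is hello'}
#   pairs = Dataset.from_dict({'query_id': ['q1'], 'doc_id': ['0']})
#   pairs.set_transform(partial(_rerank_transform_func, tokenizer, corpus, queries))
#   pairs[0]  # tokenized on the fly; padding is left to the collator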


@torch.no_grad()
def _worker_compute_reranker_score(gpu_idx: int):
    """Score this GPU's shard of (query, passage) pairs and write them to a
    per-worker TSV shard."""
    preds = load_msmarco_predictions(args.rerank_in_path)
    query_ids = sorted(list(preds.keys()))
    # keep only the top-ranked candidates per query, up to args.rerank_depth
    qid_pid = []
    for query_id in tqdm.tqdm(query_ids, desc='load qid-pid', mininterval=2):
        qid_pid += [(scored_doc.qid, scored_doc.pid) for scored_doc in preds[query_id]
                    if scored_doc.rank <= args.rerank_depth]

    dataset = Dataset.from_dict({'query_id': [t[0] for t in qid_pid],
                                 'doc_id': [t[1] for t in qid_pid]})
    # contiguous sharding hands each GPU one consecutive slice of the pairs,
    # so a single query's candidates may be split across workers
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)

    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))
    torch.cuda.set_device(gpu_idx)

    query_ids, doc_ids = dataset['query_id'], dataset['doc_id']
    assert len(dataset) == len(query_ids)

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: RerankerForInference = RerankerForInference.from_pretrained(args.model_name_or_path)
    model.eval()
    model.cuda()

    corpus: Dataset = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz'))
    queries = load_queries(path='{}/{}_queries.tsv'.format(args.data_dir, args.rerank_split),
                           task_type=args.task_type)
    # tokenize lazily at __getitem__ time via the transform defined above
    dataset.set_transform(partial(_rerank_transform_func, tokenizer, corpus, queries))

    # pad each batch to a multiple of 8 under fp16 for better tensor-core utilization
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
    data_loader = DataLoader(
        dataset,
        batch_size=args.rerank_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)

    scores = []
    for batch_dict in tqdm.tqdm(data_loader, desc='passage rerank', mininterval=5):
        batch_dict = move_to_cuda(batch_dict)

        # run the forward pass under mixed precision when fp16 is enabled
        with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
            outputs: SequenceClassifierOutput = model(batch_dict)
        # logits have shape (batch_size, 1); squeeze to a 1-D score vector
        scores.append(outputs.logits.squeeze(dim=-1).cpu())
        assert len(scores[-1].shape) == 1

    all_scores = torch.cat(scores, dim=-1)
    assert all_scores.shape[0] == len(query_ids), '{} != {}'.format(all_scores.shape[0], len(query_ids))
    all_scores = all_scores.tolist()

    with open(get_rerank_shard_path(args, gpu_idx), 'w', encoding='utf-8') as writer:
        for idx in range(len(query_ids)):
            # write a dummy rank of -1: a query's candidates may be split across
            # workers, so real ranks are assigned when the shards are merged
            writer.write('{}\t{}\t{}\t{}\n'.format(query_ids[idx], doc_ids[idx], -1, round(all_scores[idx], 5)))

    logger.info('Done computing rerank score for worker {}'.format(gpu_idx))
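
# Each shard line is tab-separated, matching the write above:
#   <query_id>\t<doc_id>\t-1\t<score>
# e.g. "1048585\t7187158\t-1\t12.34567" (values illustrative)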


def _batch_compute_reranker_score():
    """Spawn one scoring worker per available GPU, then merge the resulting shards."""
    logger.info('Args={}'.format(str(args)))
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        logger.error('No GPU available')
        return

    logger.info('Using {} GPUs'.format(gpu_count))
    # torch.multiprocessing.spawn passes the process index as the first
    # argument, which the worker uses as its GPU index
    torch.multiprocessing.spawn(_worker_compute_reranker_score, args=(), nprocs=gpu_count)
    logger.info('Done batch compute rerank score')

    merge_rerank_predictions(args, gpu_count)
    logger.info('Done merge results')


if __name__ == '__main__':
    _batch_compute_reranker_score()
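
# Illustrative invocation (hypothetical: the script name and the exact flags
# are defined by the Arguments dataclass in config.py and may differ):
#
#   python rerank_main.py \
#       --model_name_or_path /path/to/reranker \
#       --data_dir ./data \
#       --rerank_in_path ./preds/dev.msmarco.txt \
#       --rerank_split dev \
#       --fp16 True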