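"""Encode a passage corpus into dense embeddings with a trained biencoder.

One worker process is spawned per available GPU; each worker encodes a
contiguous shard of the corpus and saves its embeddings to disk.
"""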
import os
import tqdm
import torch

from contextlib import nullcontext
from torch.utils.data import DataLoader
from functools import partial
from datasets import load_dataset
from typing import Dict, List
from transformers.file_utils import PaddingStrategy
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    HfArgumentParser,
    BatchEncoding
)

from config import Arguments
from logger_config import logger
from utils import move_to_cuda
from models import BiencoderModelForInference, BiencoderOutput

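# command-line arguments are parsed at import time, so each spawned worker
# re-parses the same arguments when it re-imports this module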
parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]


def _psg_transform_func(tokenizer: PreTrainedTokenizerFast,
                        examples: Dict[str, List]) -> BatchEncoding:
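    # tokenize (title, contents) pairs without padding here; batches are
    # padded later by DataCollatorWithPadding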
    batch_dict = tokenizer(examples['title'],
                           text_pair=examples['contents'],
                           max_length=args.p_max_len,
                           padding=PaddingStrategy.DO_NOT_PAD,
                           truncation=True)
    # for co-Condenser reproduction purposes only
    if args.model_name_or_path.startswith('Luyu/'):
        del batch_dict['token_type_ids']

    return batch_dict


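# worker entry point: encodes this GPU's shard of the corpus and writes the
# embeddings to disk, split into sub-shards of roughly encode_shard_size docs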
@torch.no_grad()
def _worker_encode_passages(gpu_idx: int):
    def _get_out_path(shard_idx: int = 0) -> str:
        return '{}/shard_{}_{}'.format(args.encode_save_dir, gpu_idx, shard_idx)

    if os.path.exists(_get_out_path(0)):
        logger.warning('{} already exists, will skip encoding'.format(_get_out_path(0)))
        return

    dataset = load_dataset('json', data_files=args.encode_in_path)['train']
    if args.dry_run:
        dataset = dataset.select(range(4096))
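    # split the corpus evenly across GPUs; contiguous=True gives each worker
    # an unbroken slice, so shard files can later be concatenated in order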
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)

    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))
    torch.cuda.set_device(gpu_idx)

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: BiencoderModelForInference = BiencoderModelForInference.build(args)
    model.eval()
    model.cuda()

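    # tokenize lazily: the transform runs on each batch as it is fetched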
    dataset.set_transform(partial(_psg_transform_func, tokenizer))

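    # pad each batch dynamically; padding to a multiple of 8 helps fp16 tensor cores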
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
    data_loader = DataLoader(
        dataset,
        batch_size=args.encode_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)

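    # accumulate embeddings on CPU and flush them to a new shard file whenever
    # at least encode_shard_size documents have been encoded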
    num_encoded_docs, encoded_embeds, cur_shard_idx = 0, [], 0
    for batch_dict in tqdm.tqdm(data_loader, desc='passage encoding', mininterval=8):
        batch_dict = move_to_cuda(batch_dict)

        with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
            outputs: BiencoderOutput = model(query=None, passage=batch_dict)
        encoded_embeds.append(outputs.p_reps.cpu())
        num_encoded_docs += outputs.p_reps.shape[0]

        if num_encoded_docs >= args.encode_shard_size:
            out_path = _get_out_path(cur_shard_idx)
            concat_embeds = torch.cat(encoded_embeds, dim=0)
            logger.info('GPU {} saving {} embeddings to {}'.format(gpu_idx, concat_embeds.shape[0], out_path))
            torch.save(concat_embeds, out_path)

            cur_shard_idx += 1
            num_encoded_docs = 0
            encoded_embeds.clear()

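    # flush any remainder that did not fill a complete shard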
    if num_encoded_docs > 0:
        out_path = _get_out_path(cur_shard_idx)
        concat_embeds = torch.cat(encoded_embeds, dim=0)
        logger.info('GPU {} saving {} embeddings to {}'.format(gpu_idx, concat_embeds.shape[0], out_path))
        torch.save(concat_embeds, out_path)

    logger.info('Done encoding passages for worker {}'.format(gpu_idx))


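# spawn one encoding worker per visible GPU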
def _batch_encode_passages():
    logger.info('Args={}'.format(str(args)))
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        logger.error('No GPU available')
        return

    logger.info('Using {} GPUs'.format(gpu_count))
    torch.multiprocessing.spawn(_worker_encode_passages, args=(), nprocs=gpu_count)
    logger.info('Done batch encoding passages')


if __name__ == '__main__':
    _batch_encode_passages()