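"""Encode a passage corpus into dense embeddings with a biencoder.

One worker process is spawned per GPU; each worker encodes a contiguous
shard of the input corpus and saves the embeddings to
``args.encode_save_dir`` as files named ``shard_{gpu_idx}_{shard_idx}``,
each holding roughly ``args.encode_shard_size`` vectors (plus one final
partial shard).

A hypothetical invocation (flag names are generated by HfArgumentParser
from the fields of the Arguments dataclass in config.py; the script name
is a placeholder):

    python encode_passages.py \
        --model_name_or_path <model> \
        --encode_in_path corpus.jsonl \
        --encode_save_dir ./embeds \
        --encode_batch_size 128 \
        --fp16
"""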
import os
from contextlib import nullcontext
from functools import partial
from typing import Dict, List

import torch
import tqdm
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    HfArgumentParser,
    BatchEncoding,
)
from transformers.file_utils import PaddingStrategy

from config import Arguments
from logger_config import logger
from models import BiencoderModelForInference, BiencoderOutput
from utils import move_to_cuda
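
# HfArgumentParser builds a CLI from the Arguments dataclass. Parsing happens
# at import time, so workers started via torch.multiprocessing.spawn (which
# re-imports this module in each child process) see the same args.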
parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]


def _psg_transform_func(tokenizer: PreTrainedTokenizerFast,
                        examples: Dict[str, List]) -> BatchEncoding:
    batch_dict = tokenizer(examples['title'],
                           text_pair=examples['contents'],
                           max_length=args.p_max_len,
                           padding=PaddingStrategy.DO_NOT_PAD,
                           truncation=True)

    # for co-Condenser reproduction purposes only: drop token_type_ids
    # for Luyu/* checkpoints
    if args.model_name_or_path.startswith('Luyu/'):
        del batch_dict['token_type_ids']
    return batch_dict
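

# One worker per GPU; torch.no_grad() disables autograd bookkeeping since
# this worker only runs inference.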
@torch.no_grad()
def _worker_encode_passages(gpu_idx: int):
    def _get_out_path(shard_idx: int = 0) -> str:
        return '{}/shard_{}_{}'.format(args.encode_save_dir, gpu_idx, shard_idx)

    if os.path.exists(_get_out_path(0)):
        logger.error('{} already exists, will skip encoding'.format(_get_out_path(0)))
        return

    dataset = load_dataset('json', data_files=args.encode_in_path)['train']
    if args.dry_run:
        # encode only a small subset for quick sanity checks
        dataset = dataset.select(range(4096))
    # contiguous sharding preserves the on-disk passage order within each shard
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)
    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))

    torch.cuda.set_device(gpu_idx)
    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: BiencoderModelForInference = BiencoderModelForInference.build(args)
    model.eval()
    model.cuda()

    # tokenization is applied lazily at access time; padding is deferred
    # to the collator
    dataset.set_transform(partial(_psg_transform_func, tokenizer))
    # pad to a multiple of 8 under fp16 for better tensor-core utilization
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
    data_loader = DataLoader(
        dataset,
        batch_size=args.encode_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)
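
    # Stream batches through the passage encoder, accumulating embeddings on
    # the CPU and flushing them to disk whenever a shard fills up.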
    num_encoded_docs, encoded_embeds, cur_shard_idx = 0, [], 0
    for batch_dict in tqdm.tqdm(data_loader, desc='passage encoding', mininterval=8):
        batch_dict = move_to_cuda(batch_dict)

        with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
            outputs: BiencoderOutput = model(query=None, passage=batch_dict)
        encoded_embeds.append(outputs.p_reps.cpu())
        num_encoded_docs += outputs.p_reps.shape[0]

        # a shard is full: concatenate and write it out, then reset the buffer
        if num_encoded_docs >= args.encode_shard_size:
            out_path = _get_out_path(cur_shard_idx)
            concat_embeds = torch.cat(encoded_embeds, dim=0)
            logger.info('GPU {} saves {} embeds to {}'.format(gpu_idx, concat_embeds.shape[0], out_path))
            torch.save(concat_embeds, out_path)
            cur_shard_idx += 1
            num_encoded_docs = 0
            encoded_embeds.clear()

    # write out the final, possibly partial shard
    if num_encoded_docs > 0:
        out_path = _get_out_path(cur_shard_idx)
        concat_embeds = torch.cat(encoded_embeds, dim=0)
        logger.info('GPU {} saves {} embeds to {}'.format(gpu_idx, concat_embeds.shape[0], out_path))
        torch.save(concat_embeds, out_path)

    logger.info('Done encoding passages for worker {}'.format(gpu_idx))
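

# Driver: spawns one _worker_encode_passages process per visible GPU.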
def _batch_encode_passages():
    logger.info('Args={}'.format(str(args)))
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        logger.error('No GPU available')
        return

    logger.info('Using {} GPUs'.format(gpu_count))
    # spawn() passes each process its index as the first positional argument,
    # which serves as gpu_idx here
    torch.multiprocessing.spawn(_worker_encode_passages, args=(), nprocs=gpu_count)
    logger.info('Done batch encoding passages')


if __name__ == '__main__':
    _batch_encode_passages()