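"""Tokenize a raw text corpus with a HuggingFace tokenizer and binarize it
into a fairseq indexed (mmap) dataset of fixed-length token blocks.

Example invocation (the script name and paths here are placeholders):

    python binarize.py \
        --model_name CZWin32768/xlm-align \
        --input corpus.txt \
        --output data-bin

This writes data-bin/train.0.bin and data-bin/train.0.idx; the output
directory must already exist.
"""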
import argparse
import os
import torch
from fairseq.data import (FairseqDataset, PrependTokenDataset,
                          TokenBlockDataset, TruncateDataset, data_utils)
from fairseq.data.indexed_dataset import make_builder
from tqdm import tqdm
from transformers import AutoTokenizer
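

# Minimal FairseqDataset wrapper that exposes a list of token-id lists to
# fairseq's dataset transforms (TruncateDataset, TokenBlockDataset, etc.).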
class IndexDataset(FairseqDataset):
    def __init__(self, indices):
        self.indices = indices
        self._sizes = [len(i) for i in indices]

    @property
    def sizes(self):
        return self._sizes

    def size(self, index):
        item = self.__getitem__(index)
        return len(item)

    def __getitem__(self, index):
        item = self.indices[index]
        item = torch.LongTensor(item)
        return item

    def __len__(self):
        return len(self.indices)

    def collater(self, samples):
        raise NotImplementedError


def build_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    return tokenizer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="CZWin32768/xlm-align")
    parser.add_argument("--input", type=str, default="")
    parser.add_argument("--output", type=str, default="")
    parser.add_argument('--sample-break-mode', default='complete',
                        choices=['none', 'complete', 'complete_doc', 'eos'],
                        help='If omitted or "none", fills each sample with tokens-per-sample '
                             'tokens. If set to "complete", splits samples only at the end '
                             'of sentence, but may include multiple sentences per sample. '
                             '"complete_doc" is similar but respects doc boundaries. '
                             'If set to "eos", includes only one sentence per sample.')
    parser.add_argument('--tokens-per-sample', default=510, type=int,
                        help='max number of total tokens over all segments per sample')
    # note: save_items below always writes the mmap format regardless of this flag
    parser.add_argument('--dataset_impl', default="mmap", type=str)
    args = parser.parse_args()
    return args


def save_items(items, prefix, vocab_size):
    bin_fn = "%s.bin" % prefix
    idx_fn = "%s.idx" % prefix
    builder = make_builder(bin_fn, "mmap", vocab_size=vocab_size)
    print("builder: " + str(builder))
    for item in items:
        builder.add_item(item)
    builder.finalize(idx_fn)


def main(args):
    tokenizer = build_tokenizer(args)

    # Tokenize the input corpus line by line.
    indices = []
    with open(args.input) as fp:
        for line in tqdm(fp):
            line = line.strip()
            indices.append(tokenizer.encode(line))
    print("tokenize finished.")

    # Print a few examples as a sanity check.
    for i in range(5):
        print("example[%d]:" % i)
        input_ids = indices[i]
        print(input_ids)
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        print(tokens)

    # Wrap the token ids, truncate, and regroup them into fixed-size blocks.
    dataset = IndexDataset(indices)
    dataset = TruncateDataset(dataset, args.tokens_per_sample - 1)
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        args.tokens_per_sample - 1,  # one less for <s>
        pad=tokenizer.pad_token_id,
        eos=tokenizer.sep_token_id,
        break_mode=args.sample_break_mode,
    )
    print('| loaded {} blocks from: {}'.format(len(dataset), args.input), flush=True)
    dataset = PrependTokenDataset(dataset, tokenizer.cls_token_id)

    # Materialize every block and write the fairseq binary dataset.
    print("| get all items ...")
    items = [i for i in tqdm(dataset)]
    print("| writing binary file ...")
    prefix = os.path.join(args.output, "train.0")
    save_items(items, prefix, len(tokenizer))


if __name__ == "__main__":
    args = get_args()
    main(args)
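
# A minimal sketch of reading the binarized output back for inspection
# (assumes fairseq's standard load_indexed_dataset helper; not part of this
# script, and the path is a placeholder):
#
#   from fairseq.data import data_utils
#   ds = data_utils.load_indexed_dataset("data-bin/train.0", dataset_impl="mmap")
#   print(len(ds), ds[0])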