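"""Binarize a plain-text corpus (one sentence per line) into a fairseq
indexed dataset, using a HuggingFace tokenizer (default:
CZWin32768/xlm-align). Sentences are tokenized, grouped into blocks of at
most --tokens-per-sample tokens (one slot is reserved for the prepended
CLS/<s> token), and written to <output>/train.0.bin and <output>/train.0.idx.

Example invocation (the script name "binarize.py" is only illustrative):

    python binarize.py --input corpus.txt --output data-bin \
        --model_name CZWin32768/xlm-align
"""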
import argparse
import os

import torch
from fairseq.data import (FairseqDataset, PrependTokenDataset,
                          TokenBlockDataset, TruncateDataset, data_utils)
from fairseq.data.indexed_dataset import make_builder
from tqdm import tqdm
from transformers import AutoTokenizer


class IndexDataset(FairseqDataset):
  """Wraps a list of token-id lists as a FairseqDataset so it can be fed
  into fairseq dataset wrappers such as TruncateDataset and
  TokenBlockDataset."""

  def __init__(self, indices):
    self.indices = indices
    self._sizes = [len(i) for i in indices]

  @property
  def sizes(self):
    return self._sizes
  
  def size(self, index):
    item = self.__getitem__(index)
    return len(item)
  
  def __getitem__(self, index):
    item = self.indices[index]
    item = torch.LongTensor(item)
    return item
  
  def __len__(self):
    return len(self.indices)
  
  def collater(self, samples):
    # Batching is never needed here; the dataset is only iterated item by
    # item during binarization.
    raise NotImplementedError


def build_tokenizer(args):
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
  return tokenizer


def get_args():
  parser = argparse.ArgumentParser()
  parser.add_argument("--model_name", type=str, default="CZWin32768/xlm-align")
  parser.add_argument("--input", type=str, default="")
  parser.add_argument("--output", type=str, default="")
  parser.add_argument('--sample-break-mode', default='complete',
                        choices=['none', 'complete', 'complete_doc', 'eos'],
                        help='If omitted or "none", fills each sample with tokens-per-sample '
                        'tokens. If set to "complete", splits samples only at the end '
                        'of sentence, but may include multiple sentences per sample. '
                        '"complete_doc" is similar but respects doc boundaries. '
                        'If set to "eos", includes only one sentence per sample.')
  parser.add_argument('--tokens-per-sample', default=510, type=int,
                      help='max number of total tokens over all segments per sample')
  parser.add_argument('--dataset_impl', default="mmap", type=str)
  args = parser.parse_args()
  return args


def save_items(items, prefix, vocab_size, impl="mmap"):
  """Write the tensors in `items` to an indexed dataset (<prefix>.bin / <prefix>.idx)."""
  bin_fn = "%s.bin" % prefix
  idx_fn = "%s.idx" % prefix
  builder = make_builder(bin_fn, impl, vocab_size=vocab_size)
  print("builder: " + str(builder))
  for item in items:
    builder.add_item(item)
  builder.finalize(idx_fn)


def main(args):
  tokenizer = build_tokenizer(args)

  # Tokenize the input corpus, one line (sentence) at a time.
  indices = []
  with open(args.input) as fp:
    for line in tqdm(fp):
      line = line.strip()
      indices.append(tokenizer.encode(line))
  print("tokenize finished.")
  # Print a few examples as a sanity check.
  for i in range(min(5, len(indices))):
    print("example[%d]:" % i)
    input_ids = indices[i]
    print(input_ids)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    print(tokens)

  # Wrap the token-id lists as a FairseqDataset, truncate overly long
  # sentences, then group sentences into blocks of at most
  # tokens_per_sample - 1 tokens; one slot is reserved for the <s> token
  # prepended to every block below.
  dataset = IndexDataset(indices)
  dataset = TruncateDataset(dataset, args.tokens_per_sample - 1)
  dataset = TokenBlockDataset(
    dataset,
    dataset.sizes,
    args.tokens_per_sample - 1,  # one less for <s>
    pad=tokenizer.pad_token_id,
    eos=tokenizer.sep_token_id,
    break_mode=args.sample_break_mode,
  )
  print('| loaded {} blocks from: {}'.format(len(dataset), args.input), flush=True)

  dataset = PrependTokenDataset(dataset, tokenizer.cls_token_id)
  print("| get all items ...")
  items = [i for i in tqdm(dataset)]
  print("| writing binary file ...")
  prefix = os.path.join(args.output, "train.0")
  save_items(items, prefix, len(tokenizer), impl=args.dataset_impl)



if __name__ == "__main__":
  args = get_args()
  main(args)
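

# A minimal sketch (not part of the original script) of how the resulting
# .bin/.idx pair could be loaded back for inspection with fairseq, assuming
# the illustrative output directory "data-bin" from the docstring above:
#
#   from fairseq.data import data_utils
#   ds = data_utils.load_indexed_dataset("data-bin/train.0", dataset_impl="mmap")
#   print(len(ds), ds[0])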