import argparse
import os

import torch
from fairseq.data import (AppendTokenDataset, FairseqDataset,
                          PrependTokenDataset, StripTokenDataset,
                          TruncateDataset)
from fairseq.data.indexed_dataset import make_builder
from tqdm import tqdm
from transformers import AutoTokenizer


class IndexDataset(FairseqDataset):
  """Wraps a list of token-id lists so fairseq dataset transforms
  (StripTokenDataset, TruncateDataset, ...) can be applied to it."""

  def __init__(self, indices):
    self.indices = indices
    self._sizes = [len(i) for i in indices]

  @property
  def sizes(self):
    return self._sizes
  
  def size(self, index):
    # Length lookup without materializing the tensor in __getitem__.
    return self._sizes[index]
  
  def __getitem__(self, index):
    item = self.indices[index]
    item = torch.LongTensor(item)
    return item
  
  def __len__(self):
    return len(self.indices)
  
  def collater(self, samples):
    raise NotImplementedError


def build_tokenizer(args):
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
  return tokenizer


def get_args():
  parser = argparse.ArgumentParser()
  parser.add_argument("--model_name", type=str, default="CZWin32768/xlm-align")
  parser.add_argument("--input_src", type=str, default="")
  parser.add_argument("--input_trg", type=str, default="")
  parser.add_argument("--output", type=str, default="")
  parser.add_argument("--max_pos", type=int, default=256)
  args = parser.parse_args()
  return args


def save_items(items, prefix, vocab_size):
  bin_fn = "%s.bin" % prefix
  idx_fn = "%s.idx" % prefix
  builder = make_builder(bin_fn, "mmap", vocab_size=vocab_size)
  print("builder: " + str(builder))
  for item in items:
    builder.add_item(item)
  builder.finalize(idx_fn)
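

def load_binarized(prefix):
  # Hypothetical helper, not called by this script: a minimal sketch of how the
  # .bin/.idx pair written by save_items could be read back, assuming fairseq's
  # data_utils.load_indexed_dataset and the same "mmap" implementation used
  # above. `prefix` is the path without the .bin/.idx extension.
  from fairseq.data import data_utils
  return data_utils.load_indexed_dataset(prefix, None, dataset_impl="mmap")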


def get_indices(input_fn, tokenizer):
  indices = []
  with open(input_fn, encoding="utf-8") as fp:
    for lid, line in tqdm(enumerate(fp)):
      # DEBUG
      # if lid > 500: break
      line = line.strip()
      indices.append(tokenizer.encode(line))
  print("tokenization finished.")
  return indices



def main(args):
  tokenizer = build_tokenizer(args)
  src_indices = get_indices(args.input_src, tokenizer)
  trg_indices = get_indices(args.input_trg, tokenizer)

  src_dataset = IndexDataset(src_indices)
  trg_dataset = IndexDataset(trg_indices)

  # For XLM-R style tokenizers, cls/sep play the role of fairseq's bos/eos.
  eos = tokenizer.sep_token_id
  bos = tokenizer.cls_token_id
  max_pos = args.max_pos

  # Strip any boundary eos tokens, then truncate so there is room for the
  # bos/eos tokens added below.
  src_dataset = TruncateDataset(
    StripTokenDataset(src_dataset, eos), max_pos - 2)
  trg_dataset = TruncateDataset(
    StripTokenDataset(trg_dataset, eos), max_pos - 2)

  src_dataset = PrependTokenDataset(src_dataset, bos)
  trg_dataset = PrependTokenDataset(trg_dataset, bos)

  src_dataset = AppendTokenDataset(src_dataset, eos)
  trg_dataset = AppendTokenDataset(trg_dataset, eos)
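
  # Each example now starts with bos, ends with eos, and holds at most
  # max_pos tokens in total.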

  print("| get all items ...")
  items = []
  # Interleave source and target sentences so that consecutive items in the
  # binary dataset form a translation pair; the two input files are assumed
  # to be line-aligned (zip stops at the shorter dataset).
  for t1, t2 in tqdm(zip(src_dataset, trg_dataset)):
    items.append(t1)
    items.append(t2)

  print("| writing binary file ...")
  prefix = os.path.join(args.output, "train.0")
  save_items(items, prefix, len(tokenizer))


if __name__ == "__main__":
  args = get_args()
  main(args)
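

# Example invocation (the script and data file names below are placeholders;
# the output directory must already exist and receives train.0.bin/train.0.idx):
#
#   python prepare_para_data.py \
#     --model_name CZWin32768/xlm-align \
#     --input_src train.src --input_trg train.trg \
#     --output /path/to/data-bin --max_pos 256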