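# Prepare TLM (translation language modeling) training data: tokenize a
# parallel corpus with a HuggingFace tokenizer, pair source and target
# sentences in both directions, and write the result as a fairseq mmap
# binary dataset.
#
# Illustrative invocation (script name and file paths are placeholders):
#   python prepare_tlm_data.py --input_src train.src --input_trg train.trg \
#     --output ./tlm-bin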
import argparse
import os

import torch
from fairseq.data import (ConcatDataset, FairseqDataset, StripTokenDataset,
                          TruncateDataset)
from fairseq.data.indexed_dataset import make_builder
from tqdm import tqdm
from transformers import AutoTokenizer

from infoxlm.data.tlm_dataset import TLMDataset


class IndexDataset(FairseqDataset):
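  """Minimal in-memory FairseqDataset over a list of token-id sequences."""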
  
  def __init__(self, indices):
    self.indices = indices
    self._sizes = [len(i) for i in indices]

  @property
  def sizes(self):
    return self._sizes
  
  def size(self, index):
    item = self.__getitem__(index)
    return len(item)
  
  def __getitem__(self, index):
    item = self.indices[index]
    item = torch.LongTensor(item)
    return item
  
  def __len__(self):
    return len(self.indices)
  
  def collater(self, samples):
    raise NotImplementedError


def build_tokenizer(args):
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
  return tokenizer


def get_args():
  parser = argparse.ArgumentParser()
  parser.add_argument("--model_name", type=str, default="CZWin32768/xlm-align")
  parser.add_argument("--input_src", type=str, default="")
  parser.add_argument("--input_trg", type=str, default="")
  parser.add_argument("--output", type=str, default="")
  parser.add_argument("--max_pos", type=int, default=256)
  args = parser.parse_args()
  return args


def save_items(items, prefix, vocab_size):
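  """Write items to a fairseq mmap dataset (<prefix>.bin / <prefix>.idx)."""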
  bin_fn = "%s.bin" % prefix
  idx_fn = "%s.idx" % prefix
  builder = make_builder(bin_fn, "mmap", vocab_size=vocab_size)
  print("builder: " + str(builder))
  for item in items:
    builder.add_item(item)
  builder.finalize(idx_fn)


def get_indices(input_fn, tokenizer):
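  """Tokenize each line of input_fn and return a list of token-id lists."""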
  indices = []
  with open(input_fn) as fp:
    for lid, line in tqdm(enumerate(fp)):
      # DEBUG 
      # if lid > 500: break
      line = line.strip()
      indices.append(tokenizer.encode(line))
  print("tokenize finished.")
  return indices


def main(args):
  tokenizer = build_tokenizer(args)
  src_indices = get_indices(args.input_src, tokenizer)
  trg_indices = get_indices(args.input_trg, tokenizer)

  src_dataset = IndexDataset(src_indices)
  trg_dataset = IndexDataset(trg_indices)

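  # Use the tokenizer's cls/sep token ids as the bos/eos special tokens.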
  eos = tokenizer.sep_token_id
  bos = tokenizer.cls_token_id
  max_pos = args.max_pos

  datasets = []

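  # Strip the trailing eos token from each sentence and cap its length at
  # max_pos - 2 tokens.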
  src_dataset = TruncateDataset(
    StripTokenDataset(src_dataset, eos), max_pos - 2,)
  trg_dataset = TruncateDataset(
    StripTokenDataset(trg_dataset, eos), max_pos - 2,)

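  # Build TLM pairs in both directions: src->trg and trg->src.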
  datasets.append(
    TLMDataset(src_dataset, trg_dataset, bos, eos))
  datasets.append(
    TLMDataset(trg_dataset, src_dataset, bos, eos))

  dataset = ConcatDataset(datasets)

  print("| get all items ...")
  items = [i for i in tqdm(dataset)]
  print("| writing binary file ...")
  prefix = os.path.join(args.output, "train.0")
  save_items(items, prefix, len(tokenizer))


if __name__ == "__main__":
  args = get_args()
  main(args)