"""Prepares binarized TLM (translation language modeling) training data.

Reads a parallel corpus (one sentence per line on the source and target
sides), tokenizes both sides with a Hugging Face tokenizer, pairs the
sentences in both directions (src-trg and trg-src), and writes the result
as a fairseq mmap-indexed dataset.
"""

import argparse
import os

import torch
from fairseq.data import (
  ConcatDataset,
  FairseqDataset,
  StripTokenDataset,
  TruncateDataset,
)
from fairseq.data.indexed_dataset import make_builder
from tqdm import tqdm
from transformers import AutoTokenizer

from infoxlm.data.tlm_dataset import TLMDataset


class IndexDataset(FairseqDataset):
  """Wraps a list of token-id lists as a read-only fairseq dataset."""

  def __init__(self, indices):
    self.indices = indices
    self._sizes = [len(i) for i in indices]

  @property
  def sizes(self):
    return self._sizes

  def size(self, index):
    item = self.__getitem__(index)
    return len(item)

  def __getitem__(self, index):
    item = self.indices[index]
    item = torch.LongTensor(item)
    return item

  def __len__(self):
    return len(self.indices)

  def collater(self, samples):
    # Batching is not needed at preprocessing time.
    raise NotImplementedError


def build_tokenizer(args):
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
  return tokenizer


def get_args():
  parser = argparse.ArgumentParser()
  parser.add_argument("--model_name", type=str, default="CZWin32768/xlm-align")
  parser.add_argument("--input_src", type=str, default="")
  parser.add_argument("--input_trg", type=str, default="")
  parser.add_argument("--output", type=str, default="")
  parser.add_argument("--max_pos", type=int, default=256)
  args = parser.parse_args()
  return args


def save_items(items, prefix, vocab_size):
  """Writes the examples to <prefix>.bin / <prefix>.idx in fairseq mmap format."""
  bin_fn = "%s.bin" % prefix
  idx_fn = "%s.idx" % prefix
  builder = make_builder(bin_fn, "mmap", vocab_size=vocab_size)
  print("builder: " + str(builder))
  for item in items:
    builder.add_item(item)
  builder.finalize(idx_fn)


def get_indices(input_fn, tokenizer):
  """Tokenizes one sentence per line and returns a list of token-id lists."""
  indices = []
  with open(input_fn) as fp:
    for lid, line in tqdm(enumerate(fp)):
      # DEBUG
      # if lid > 500: break
      line = line.strip()
      indices.append(tokenizer.encode(line))
  print("tokenize finished.")
  return indices


def main(args):
  tokenizer = build_tokenizer(args)
  src_indices = get_indices(args.input_src, tokenizer)
  trg_indices = get_indices(args.input_trg, tokenizer)

  src_dataset = IndexDataset(src_indices)
  trg_dataset = IndexDataset(trg_indices)

  eos = tokenizer.sep_token_id
  bos = tokenizer.cls_token_id
  max_pos = args.max_pos

  # Strip the trailing separator and cap each side at max_pos - 2 tokens,
  # reserving room for the special tokens added when the pair is assembled.
  src_dataset = TruncateDataset(
    StripTokenDataset(src_dataset, eos), max_pos - 2)
  trg_dataset = TruncateDataset(
    StripTokenDataset(trg_dataset, eos), max_pos - 2)

  # Build TLM examples in both directions: src-trg and trg-src.
  datasets = []
  datasets.append(TLMDataset(src_dataset, trg_dataset, bos, eos))
  datasets.append(TLMDataset(trg_dataset, src_dataset, bos, eos))
  dataset = ConcatDataset(datasets)

  print("| get all items ...")
  items = [i for i in tqdm(dataset)]
  print("| writing binary file ...")
  prefix = os.path.join(args.output, "train.0")
  save_items(items, prefix, len(tokenizer))


if __name__ == "__main__":
  args = get_args()
  main(args)
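
# Example invocation (a sketch: the script file name and the data paths below
# are illustrative, while the flags and defaults match get_args() above):
#
#   python prepare_tlm_data.py \
#     --model_name CZWin32768/xlm-align \
#     --input_src data/train.src \
#     --input_trg data/train.trg \
#     --output data-bin \
#     --max_pos 256
#
# The output directory then contains train.0.bin and train.0.idx, a fairseq
# mmap-indexed dataset of paired sentences for TLM training.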