import numpy as np
from random import randint, shuffle, choice
from random import random as rand
import math
import logging

import torch
import torch.utils.data

logger = logging.getLogger(__name__)


def get_random_word(vocab_words):
    i = randint(0, len(vocab_words) - 1)
    return vocab_words[i]


def batch_list_to_batch_tensors(batch):
    # Collate function: transpose a list of per-example tuples into per-field
    # batch tensors. Fields that are None stay None; Tensor fields are stacked.
    batch_tensors = []
    for x in zip(*batch):
        if x[0] is None:
            batch_tensors.append(None)
        elif isinstance(x[0], torch.Tensor):
            batch_tensors.append(torch.stack(x))
        else:
            batch_tensors.append(torch.tensor(x, dtype=torch.long))
    return batch_tensors


def _get_word_split_index(tokens, st, end):
    # Indices in [st, end) where a new whole word starts ('##' marks a
    # WordPiece continuation), plus `end` as a sentinel.
    split_idx = []
    i = st
    while i < end:
        if (not tokens[i].startswith('##')) or (i == st):
            split_idx.append(i)
        i += 1
    split_idx.append(end)
    return split_idx


def _expand_whole_word(tokens, st, end):
    # Expand [st, end) so it does not cut through a WordPiece word.
    new_st, new_end = st, end
    while (new_st >= 0) and tokens[new_st].startswith('##'):
        new_st -= 1
    while (new_end < len(tokens)) and tokens[new_end].startswith('##'):
        new_end += 1
    return new_st, new_end


class Pipeline:
    """ Pre-process Pipeline Class : callable """

    def __init__(self):
        super().__init__()
        self.skipgram_prb = None
        self.skipgram_size = None
        self.pre_whole_word = None
        self.mask_whole_word = None
        self.word_subsample_prb = None
        self.sp_prob = None
        self.pieces_dir = None
        self.vocab_words = None
        self.pieces_threshold = 10
        self.call_count = 0
        self.offline_mode = False
        self.skipgram_size_geo_list = None
        self.span_same_mask = False

    def __call__(self, instance):
        raise NotImplementedError


class Preprocess4Seq2seqDecoder(Pipeline):
    """ Pre-processing steps for seq2seq decoding """

    def __init__(self, vocab_words, indexer, max_len=512, max_tgt_length=128,
                 mode="s2s", pos_shift=False, source_type_id=0, target_type_id=1,
                 cls_token='[CLS]', sep_token='[SEP]', pad_token='[PAD]',
                 layout_flag=False):
        super().__init__()
        self.max_len = max_len
        self.vocab_words = vocab_words  # vocabulary (sub)words
        self.indexer = indexer  # function from token to token index
        self._tril_matrix = torch.tril(
            torch.ones((max_len, max_len), dtype=torch.long))
        self.task_idx = 3  # relax projection layer for different tasks
        assert mode in ("s2s", "l2r")
        self.mode = mode
        self.max_tgt_length = max_tgt_length
        self.pos_shift = pos_shift

        self.layout_flag = layout_flag
        if layout_flag:
            # With layout information, each token is [text, x0, y0, x1, y1].
            self.cls_token = [cls_token, 0, 0, 0, 0]
            self.sep_token = [sep_token, 1000, 1000, 1000, 1000]
            self.pad_token = [pad_token, 0, 0, 0, 0]
        else:
            self.cls_token = cls_token
            self.sep_token = sep_token
            self.pad_token = pad_token

        self.source_type_id = source_type_id
        self.target_type_id = target_type_id

        self.cc = 0

    def __call__(self, instance):
        tokens_a, max_a_len = instance

        # NOTE: must pad to the max src length; the per-instance value is
        # overridden so that every example in a batch has the same source length.
        max_a_len = 511

        padded_tokens_a = [self.cls_token] + tokens_a + [self.sep_token]
        assert len(padded_tokens_a) <= max_a_len + 2
        if max_a_len + 2 > len(padded_tokens_a):
            padded_tokens_a += [self.pad_token] * \
                (max_a_len + 2 - len(padded_tokens_a))
        assert len(padded_tokens_a) == max_a_len + 2

        max_len_in_batch = min(self.max_tgt_length + max_a_len + 2, self.max_len)
        tokens = padded_tokens_a
        segment_ids = [self.source_type_id] * (len(padded_tokens_a)) \
            + [self.target_type_id] * (max_len_in_batch - len(padded_tokens_a))

        mask_qkv = None

        # Position ids: real source tokens get 0..len(tokens_a)+1, source pad
        # positions get 0, and target positions continue from the real source length.
        position_ids = []
        for i in range(len(tokens_a) + 2):
            position_ids.append(i)
        for i in range(len(tokens_a) + 2, max_a_len + 2):
            position_ids.append(0)
        for i in range(max_a_len + 2, max_len_in_batch):
            position_ids.append(i - (max_a_len + 2) + len(tokens_a) + 2)

        # Token Indexing
        if not self.layout_flag:
            input_ids = self.indexer(tokens)
        else:
            raw_text = [x[0] for x in tokens]
            raw_text_ids = self.indexer(raw_text)
            input_ids = [[i] + x[1:] for i, x in zip(raw_text_ids, tokens)]

        self.cc += 1
        if self.cc < 5:
            # Log the first few decoded inputs for inspection.
            if not self.layout_flag:
                logger.info("Input src = %s" % " ".join(
                    self.vocab_words[tk_id] for tk_id in input_ids))
            else:
                logger.info("Input src = %s" % " ".join(
                    self.vocab_words[tk_id[0]] for tk_id in input_ids))

        # Zero Padding of the attention mask. In "s2s" mode every position may
        # attend to the full source; target positions additionally attend to
        # previous target positions via the lower-triangular block.
        input_mask = torch.zeros(
            max_len_in_batch, max_len_in_batch, dtype=torch.long)
        if self.mode == "s2s":
            input_mask[:, :len(tokens_a) + 2].fill_(1)
        else:
            st, end = 0, len(tokens_a) + 2
            input_mask[st:end, st:end].copy_(
                self._tril_matrix[:end, :end])
            input_mask[end:, :len(tokens_a) + 2].fill_(1)
        second_st, second_end = len(padded_tokens_a), max_len_in_batch

        input_mask[second_st:second_end, second_st:second_end].copy_(
            self._tril_matrix[:second_end - second_st, :second_end - second_st])

        return input_ids, segment_ids, position_ids, input_mask, mask_qkv, self.task_idx
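

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, self-contained
# example of how Preprocess4Seq2seqDecoder and batch_list_to_batch_tensors
# might be wired together at decode time. The toy vocabulary, the dict-based
# `toy_indexer`, and the example instances below are illustrative assumptions;
# a real setup would pass a tokenizer's vocabulary and token-to-id function.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_vocab = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'hello', 'world']
    tok2id = {tok: i for i, tok in enumerate(toy_vocab)}

    def toy_indexer(tokens):
        # Map each token string to its id, falling back to [UNK].
        return [tok2id.get(t, tok2id['[UNK]']) for t in tokens]

    preprocess = Preprocess4Seq2seqDecoder(
        vocab_words=toy_vocab, indexer=toy_indexer,
        max_len=1024, max_tgt_length=32, mode="s2s")

    # Each instance is (source tokens, max source length); note that the
    # per-instance length is currently overridden to 511 inside __call__.
    instances = [(['hello', 'world'], 2), (['world'], 1)]
    batch = [preprocess(inst) for inst in instances]

    input_ids, segment_ids, position_ids, input_mask, mask_qkv, task_idx = \
        batch_list_to_batch_tensors(batch)
    # With these toy settings: source ids are [2, 513] (511 tokens + CLS/SEP)
    # and the attention mask is [2, 545, 545] (513 source + 32 target slots).
    print(input_ids.shape, input_mask.shape)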