from __future__ import absolute_import, division, print_function import logging import os import json import random import glob import re import torch import tqdm import torch.utils.data logger = logging.getLogger(__name__) class Seq2seqDatasetForBert(torch.utils.data.Dataset): def __init__( self, features, max_source_len, max_target_len, vocab_size, cls_id, sep_id, pad_id, mask_id, random_prob, keep_prob, offset, num_training_instances, span_len=1, span_prob=1.0): self.features = features self.max_source_len = max_source_len self.max_target_len = max_target_len self.offset = offset if offset > 0: logger.info(" **** Set offset %d in Seq2seqDatasetForBert **** ", offset) self.cls_id = cls_id self.sep_id = sep_id self.pad_id = pad_id self.random_prob = random_prob self.keep_prob = keep_prob self.mask_id = mask_id self.vocab_size = vocab_size self.num_training_instances = num_training_instances self.span_len = span_len self.span_prob = span_prob def __len__(self): return int(self.num_training_instances) def __trunk(self, ids, max_len): if len(ids) > max_len - 1: ids = ids[:max_len - 1] ids = ids + [self.sep_id] return ids def __pad(self, ids, max_len): if len(ids) < max_len: return ids + [self.pad_id] * (max_len - len(ids)) else: assert len(ids) == max_len return ids def __getitem__(self, idx): idx = (self.offset + idx) % len(self.features) feature = self.features[idx] source_ids = self.__trunk([self.cls_id] + feature["source_ids"], self.max_source_len) target_ids = self.__trunk(feature["target_ids"], self.max_target_len) pseudo_ids = [] for tk_id in target_ids: p = random.random() if p < self.keep_prob: pseudo_ids.append(tk_id) elif p < self.keep_prob + self.random_prob: pseudo_ids.append(random.randint(0, self.vocab_size - 1)) else: pseudo_ids.append(self.mask_id) num_source_tokens = len(source_ids) num_target_tokens = len(target_ids) source_ids = self.__pad(source_ids, self.max_source_len) target_ids = self.__pad(target_ids, self.max_target_len) pseudo_ids = self.__pad(pseudo_ids, self.max_target_len) if self.span_len > 1: span_ids = [] span_id = 1 while len(span_ids) < num_target_tokens: p = random.random() if p < self.span_prob: span_len = random.randint(2, self.span_len) span_len = min(span_len, num_target_tokens - len(span_ids)) else: span_len = 1 span_ids.extend([span_id] * span_len) span_id += 1 span_ids = self.__pad(span_ids, self.max_target_len) return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids else: return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens # DONE: finish this!!! the 2D input id settings. class Seq2seqDatasetForLayoutlm(torch.utils.data.Dataset): def __init__( self, features, max_source_len, max_target_len, vocab_size, cls_id, sep_id, pad_id, mask_id, random_prob, keep_prob, offset, num_training_instances, layout_flag=True, span_len=1, span_prob=1.0): self.layout_flag = layout_flag self.features = features self.max_source_len = max_source_len self.max_target_len = max_target_len self.offset = offset if offset > 0: logger.info(" **** Set offset %d in Seq2seqDatasetForBert **** ", offset) self.cls_id = cls_id self.sep_id = sep_id self.pad_id = pad_id self.random_prob = random_prob self.keep_prob = keep_prob self.mask_id = mask_id self.vocab_size = vocab_size self.num_training_instances = num_training_instances self.span_len = span_len self.span_prob = span_prob self.index_sp_id = 0 def __len__(self): return int(self.num_training_instances) def __clip_index(self, ids): replace_value = 0 for i in range(len(ids)): if ids[i] > self.max_source_len - 1: ids[i] = replace_value return ids def __trunk(self, ids, max_len, simple=False, value=None): trunk_value = value if value is not None else self.sep_id if len(ids) > max_len - 1: ids = ids[:max_len - 1] if simple: ids = ids + [trunk_value] else: ids = ids + [[trunk_value, 1000, 1000, 1000, 1000]] return ids def __pad(self, ids, max_len, simple=False, value=None): pad_value = value if value is not None else self.pad_id if len(ids) < max_len: if simple: return ids + [pad_value] * (max_len - len(ids)) else: return ids + [[pad_value, 0, 0, 0, 0]] * (max_len - len(ids)) else: assert len(ids) == max_len return ids def __getitem__(self, idx): if self.layout_flag: return self.__getitem_layout__(idx) else: return self.__getitem_bert__(idx) def __getitem_bert__(self, idx): idx = (self.offset + idx) % len(self.features) feature = self.features[idx] source_ids = self.__trunk([self.cls_id] + feature["source_ids"], self.max_source_len, simple=True) target_ids = self.__trunk(feature["target_ids"], self.max_target_len, simple=True) target_index = self.__trunk(feature['target_index'], self.max_target_len, simple=True, value=self.index_sp_id) pseudo_ids = [] for tk_id in target_ids: p = random.random() if p < self.keep_prob: pseudo_ids.append(tk_id) elif p < self.keep_prob + self.random_prob: pseudo_ids.append(random.randint(0, self.vocab_size - 1)) else: pseudo_ids.append(self.mask_id) num_source_tokens = len(source_ids) num_target_tokens = len(target_ids) source_ids = self.__pad(source_ids, self.max_source_len, simple=True) target_ids = self.__pad(target_ids, self.max_target_len, simple=True) pseudo_ids = self.__pad(pseudo_ids, self.max_target_len, simple=True) target_index = self.__pad(target_index, self.max_target_len, simple=True, value=self.index_sp_id) target_index = self.__clip_index(target_index) if self.span_len > 1: span_ids = [] span_id = 1 while len(span_ids) < num_target_tokens: p = random.random() if p < self.span_prob: span_len = random.randint(2, self.span_len) span_len = min(span_len, num_target_tokens - len(span_ids)) else: span_len = 1 span_ids.extend([span_id] * span_len) span_id += 1 span_ids = self.__pad(span_ids, self.max_target_len) return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids, target_index else: return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, target_index def __getitem_layout__(self, idx): # TODO: how to initialize the random and masked tokens' pos emb # Simple Solution: only mask the text idx = (self.offset + idx) % len(self.features) feature = self.features[idx] source_ids = self.__trunk([[self.cls_id, 0, 0, 0, 0]] + feature["source_ids"], self.max_source_len) target_ids = self.__trunk(feature["target_ids"], self.max_target_len) target_index = self.__trunk(feature['target_index'], self.max_target_len, simple=True, value=self.index_sp_id) pseudo_ids = [] for tk_id in target_ids: p = random.random() if p < self.keep_prob: pseudo_ids.append(tk_id) elif p < self.keep_prob + self.random_prob: pseudo_ids.append([random.randint(0, self.vocab_size - 1)] + [0, 0, 0, 0]) # tk_id[1:]) else: pseudo_ids.append([self.mask_id] + [0, 0, 0, 0]) # tk_id[1:]) num_source_tokens = len(source_ids) num_target_tokens = len(target_ids) source_ids = self.__pad(source_ids, self.max_source_len) target_ids = self.__pad(target_ids, self.max_target_len) pseudo_ids = self.__pad(pseudo_ids, self.max_target_len) target_index = self.__pad(target_index, self.max_target_len, simple=True, value=self.index_sp_id) target_index = self.__clip_index(target_index) if self.span_len > 1: span_ids = [] span_id = 1 while len(span_ids) < num_target_tokens: p = random.random() if p < self.span_prob: span_len = random.randint(2, self.span_len) span_len = min(span_len, num_target_tokens - len(span_ids)) else: span_len = 1 span_ids.extend([span_id] * span_len) span_id += 1 span_ids = self.__pad(span_ids, self.max_target_len) return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, span_ids, target_index else: return source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens, target_index def batch_list_to_batch_tensors(batch): batch_tensors = [] for x in zip(*batch): if isinstance(x[0], torch.Tensor): batch_tensors.append(torch.stack(x)) else: batch_tensors.append(torch.tensor(x, dtype=torch.long)) return batch_tensors def get_max_epoch_model(output_dir): fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin")) fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin")) if (not fn_model_list) or (not fn_optim_list): return None os.path.basename(output_dir) both_set = set([int(os.path.basename(fn).split('.')[1]) for fn in fn_model_list] ) & set([int(os.path.basename(fn).split('.')[1]) for fn in fn_optim_list]) if both_set: return max(both_set) else: return None def load_and_cache_examples( example_file, tokenizer, local_rank, cached_features_file, shuffle=True): # Make sure only the first process in distributed training process the dataset, and the others will use the cache if local_rank not in [-1, 0]: torch.distributed.barrier() if cached_features_file is not None and os.path.exists(cached_features_file): logger.info("Loading features from cached file %s", cached_features_file) features = torch.load(cached_features_file) else: logger.info("Creating features from dataset file at %s", example_file) examples = [] with open(example_file, mode="r", encoding="utf-8") as reader: for i, line in enumerate(reader): if i == 100: break examples.append(json.loads(line)) features = [] for example in tqdm.tqdm(examples): if isinstance(example["src"], list): source_tokens = example["src"] target_tokens = example["tgt"] else: source_tokens = tokenizer.tokenize(example["src"]) target_tokens = tokenizer.tokenize(example["tgt"]) features.append({ "source_ids": tokenizer.convert_tokens_to_ids(source_tokens), "target_ids": tokenizer.convert_tokens_to_ids(target_tokens), }) if shuffle: random.shuffle(features) if local_rank in [-1, 0] and cached_features_file is not None: logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) # Make sure only the first process in distributed training process the dataset, and the others will use the cache if local_rank == 0: torch.distributed.barrier() return features def load_and_cache_line_order_examples( example_path, tokenizer, local_rank, cached_features_file, max_src_length=1024, layout_flag=True, shuffle=True, src_shuffle_rate=0, file_info_flag=False, ): # Make sure only the first process in distributed training process the dataset, and the others will use the cache if local_rank not in [-1, 0]: torch.distributed.barrier() if cached_features_file is not None and os.path.exists(cached_features_file) and False: logger.info("Loading features from cached file %s", cached_features_file) features = torch.load(cached_features_file) else: logger.info("Creating features from dataset at %s", example_path) examples = [] with open(example_path, 'r') as layout_reader: logger.info(f'Start loading {example_path}') for i, line in enumerate(layout_reader): examples.append(json.loads(line)) features = [] for layout in tqdm.tqdm(examples): bleu = layout['bleu'] if random.random() < src_shuffle_rate: # print('Random!!!') # DONE: the random src! here has bug! index also need shuffle src_layout = layout['src'] tgt_index = layout['tgt_index'] source_length = len(src_layout) shuffle_index = list(range(source_length)) random.shuffle(shuffle_index) shuffle_layout = ['' for _ in range(source_length)] for i, j in enumerate(shuffle_index): # NOTE: map i-th token to j-th token shuffle_layout[j] = src_layout[i] shuffle_target_index = [shuffle_index[i] for i in tgt_index] layout['tgt_index'] = shuffle_target_index layout['src'] = shuffle_layout mask = tokenizer.mask_token_id src_ids = [tokenizer.convert_tokens_to_ids([str(tmp_i)])[:1] + src_layout for tmp_i, src_layout in enumerate(layout['src'])] tgt_ids = [tokenizer.convert_tokens_to_ids([str(tmp_i)])[:1] + tgt_layout for tmp_i, tgt_layout in enumerate(layout['tgt'])] tgt_index = layout['tgt_index'] feature = { "source_ids": src_ids, "target_ids": tgt_ids, "target_index": tgt_index, 'bleu': bleu } if file_info_flag: file_info = {'original_filename': layout['filename'], 'filename': layout['filename'], 'page_idx': 0} feature['file_info'] = file_info features.append(feature) if shuffle: random.shuffle(features) if local_rank in [-1, 0] and cached_features_file is not None: logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) # Make sure only the first process in distributed training process the dataset, and the others will use the cache if local_rank == 0: torch.distributed.barrier() return features def load_and_cache_layoutlm_examples( example_path, tokenizer, local_rank, cached_features_file, max_src_length=1024, layout_flag=True, shuffle=True, src_shuffle_rate=0, file_info_flag=False ): # Make sure only the first process in distributed training process the dataset, and the others will use the cache if local_rank not in [-1, 0]: torch.distributed.barrier() if cached_features_file is not None and os.path.exists(cached_features_file): logger.info("Loading features from cached file %s", cached_features_file) features = torch.load(cached_features_file) else: logger.info("Creating features from dataset at %s", example_path) examples = [] if os.path.isdir(example_path): text_files = glob.glob(f'{example_path}/*text*.json') layout_files = [re.sub('text|txt', 'layout', x, 1) for x in text_files] else: text_files = [example_path] layout_files = [re.sub('text|txt', 'layout', example_path, 1)] for text_file, layout_file in zip(text_files, layout_files): with open(text_file, mode='r', encoding='utf-8') as text_reader, \ open(layout_file, mode='r', encoding='utf-8') as layout_reader: logger.info(f'Start loading {text_file}') for i, (text_line, layout_line) in enumerate(zip(text_reader, layout_reader)): if (i + 1) % 10000 == 0: logger.info(f'{i + 1} lines ...') examples.append((json.loads(text_line), json.loads(layout_line))) features = [] def tokenize_text_and_layout_src(_text, _layout, _layout_flag): ret = [] index_split = {} words = _text.split() # note: (OLD) the index should start from 1: 0-the cls token in src # note: (NEW) we need to remove the src embedding's CLS SEP token so we can still start from 0 # note: (NEWER) we need to at least one blank pos for ignore index in loss function (we use sep's index) # NOTE: (NEWER-ER) 1 for all padding tgt index new_token_index = 1 # first ordinary index for i, (word, box) in enumerate(zip(words, _layout)): if (not box[2] >= box[0]) or (not box[3] >= box[1]): continue tokens = tokenizer.tokenize(word) tokens = tokenizer.convert_tokens_to_ids(tokens) new_token_ids = [] for token in tokens: if _layout_flag: ret.append([token] + box) else: ret.append(token) new_token_ids.append(new_token_index) new_token_index += 1 index_split[i] = new_token_ids return ret, index_split def tokenize_text_and_layout_tgt(_text, _layout, _index, _index_split, _layout_flag): ret = [] ret_index = [] words = _text.split() for word, box, i in zip(words, _layout, _index): if (not box[2] >= box[0]) or (not box[3] >= box[1]): continue tokens = tokenizer.tokenize(word) tokens = tokenizer.convert_tokens_to_ids(tokens) for token, ii in zip(tokens, _index_split[i]): if _layout_flag: ret.append([token] + box) else: ret.append(token) ii = min(ii, max_src_length - 1) ret_index.append(ii) return ret, ret_index for text, layout in tqdm.tqdm(examples): if 'bleu' in text: bleu = text['bleu'] else: bleu = 0 if random.random() < src_shuffle_rate: # print('Random!!!') # DONE: the random src! here has bug! index also need shuffle src_text = text['src'] src_layout = layout['src'] tgt_index = text['tgt_index'] src_text = src_text.split() source_length = len(src_text) shuffle_index = list(range(source_length)) random.shuffle(shuffle_index) shuffle_text = ['' for _ in range(source_length)] shuffle_layout = ['' for _ in range(source_length)] for i, j in enumerate(shuffle_index): # NOTE: map i-th token to j-th token shuffle_text[j] = src_text[i] shuffle_layout[j] = src_layout[i] shuffle_target_index = [shuffle_index[i] for i in tgt_index] text['src'] = ' '.join(shuffle_text) text['tgt_index'] = shuffle_target_index layout['src'] = shuffle_layout src_ids, src_index_split = tokenize_text_and_layout_src(text['src'], layout['src'], _layout_flag=layout_flag) tgt_ids, tgt_index = tokenize_text_and_layout_tgt(text['tgt'], layout['tgt'], text['tgt_index'], src_index_split, _layout_flag=layout_flag) feature = { "source_ids": src_ids, "target_ids": tgt_ids, "target_index": tgt_index, 'bleu': bleu } if file_info_flag: file_info = {'original_filename': text['original_filename'], 'filename': text['filename'], 'page_idx': text['page_idx']} feature['file_info'] = file_info features.append(feature) if shuffle: random.shuffle(features) if local_rank in [-1, 0] and cached_features_file is not None: if not os.path.exists(os.path.dirname(cached_features_file)): os.makedirs(os.path.dirname(cached_features_file)) logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) # Make sure only the first process in distributed training process the dataset, and the others will use the cache if local_rank == 0: torch.distributed.barrier() return features def convert_src_layout_inputs_to_tokens(inputs, converter, max_src_length, layout_flag=True): ret = [] if not layout_flag: for line in inputs: ret.append(converter(line["source_ids"])[: max_src_length]) else: for line in inputs: raw_text_ids = [x[0] for x in line['source_ids']] raw_text = converter(raw_text_ids) new_line = [[t] + x[1:] for t, x in zip(raw_text, line['source_ids'])][: max_src_length] ret.append(new_line) return ret def convert_tgt_layout_inputs_to_tokens(inputs, converter, max_tgt_length, layout_flag=True): ret = [] if not layout_flag: for line in inputs: ret.append(converter(line["target_ids"])[: max_tgt_length]) else: for line in inputs: raw_text_ids = [x[0] for x in line['target_ids']] ret.append(converter(raw_text_ids)[: max_tgt_length]) return ret def get_tokens_from_src_and_index(src, index, modifier=None): result = [] for i in index: i = modifier(i) i = min(i, len(src) - 1) if isinstance(src[i], list): result.append(src[i][0]) else: result.append(src[i]) return result def get_layout_from_src_and_index(src, index, modifier=None): result = [] s = set() for i in index: i = modifier(i) i = min(i, len(src) - 1) layout = src[i][1:] if repr(layout) not in s: result.append(layout) s.add(repr(layout)) return result def get_everything_from_src_and_index(src, index, modifier=None): result = [] for i in index: i = modifier(i) i = min(i, len(src) - 1) result.append(src[i]) return result