# coding=utf-8 """PyTorch BERT model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import copy import json import logging import math import os import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.nn.modules.loss import _Loss class LabelSmoothingLoss(_Loss): """ With label smoothing, KL-divergence between q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w) is minimized. """ def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, reduction='mean'): assert 0.0 < label_smoothing <= 1.0 self.ignore_index = ignore_index super(LabelSmoothingLoss, self).__init__( size_average=size_average, reduce=reduce, reduction=reduction) assert label_smoothing > 0 assert tgt_vocab_size > 0 smoothing_value = label_smoothing / (tgt_vocab_size - 2) one_hot = torch.full((tgt_vocab_size,), smoothing_value) one_hot[self.ignore_index] = 0 self.register_buffer('one_hot', one_hot.unsqueeze(0)) self.confidence = 1.0 - label_smoothing self.tgt_vocab_size = tgt_vocab_size def forward(self, output, target): """ output (FloatTensor): batch_size * num_pos * n_classes target (LongTensor): batch_size * num_pos """ assert self.tgt_vocab_size == output.size(2) batch_size, num_pos = target.size(0), target.size(1) output = output.view(-1, self.tgt_vocab_size) target = target.view(-1) model_prob = self.one_hot.repeat(target.size(0), 1) model_prob.scatter_(1, target.unsqueeze(1), self.confidence) model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) return F.kl_div(output, model_prob, reduction='none').view(batch_size, num_pos, -1).sum(2) logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 'unilm-base-cased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1-base-cased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D", 'unilm-large-cased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1-large-cased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D", 'unilm1-base-cased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1-base-cased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D", 'unilm1-large-cased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1-large-cased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D", 'unilm1.2-base-uncased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1.2-base-uncased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D" } CONFIG_NAME = 'config.json' WEIGHTS_NAME = 'pytorch_model.bin' def gelu(x): """Implementation of the gelu activation function. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) """ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) def swish(x): return x * torch.sigmoid(x) ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} class BertConfig(object): """Configuration class to store the configuration of a `BertModel`. """ def __init__(self, vocab_size_or_config_json_file, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, relax_projection=0, new_pos_ids=False, initializer_range=0.02, task_idx=None, fp32_embedding=False, ffn_type=0, label_smoothing=None, num_qkv=0, seg_emb=False, source_type_id=0, target_type_id=1, no_segment_embedding=False, **kwargs): """Constructs BertConfig. Args: vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. hidden_size: Size of the encoder layers and the pooler layer. num_hidden_layers: Number of hidden layers in the Transformer encoder. num_attention_heads: Number of attention heads for each attention layer in the Transformer encoder. intermediate_size: The size of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. hidden_act: The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu" and "swish" are supported. hidden_dropout_prob: The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. attention_probs_dropout_prob: The dropout ratio for the attention probabilities. max_position_embeddings: The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). type_vocab_size: The vocabulary size of the `token_type_ids` passed into `BertModel`. initializer_range: The sttdev of the truncated_normal_initializer for initializing all weight matrices. """ if isinstance(vocab_size_or_config_json_file, str): with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: json_config = json.loads(reader.read()) for key, value in json_config.items(): self.__dict__[key] = value elif isinstance(vocab_size_or_config_json_file, int): self.vocab_size = vocab_size_or_config_json_file self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.relax_projection = relax_projection self.new_pos_ids = new_pos_ids self.initializer_range = initializer_range self.task_idx = task_idx self.fp32_embedding = fp32_embedding self.ffn_type = ffn_type self.label_smoothing = label_smoothing self.num_qkv = num_qkv self.seg_emb = seg_emb self.no_segment_embedding = no_segment_embedding self.source_type_id = source_type_id self.target_type_id = target_type_id if type_vocab_size == 0: self.no_segment_embedding = True else: raise ValueError("First argument must be either a vocabulary size (int)" "or the path to a pretrained model config file (str)") @classmethod def from_dict(cls, json_object): """Constructs a `BertConfig` from a Python dictionary of parameters.""" config = BertConfig(vocab_size_or_config_json_file=-1) for key, value in json_object.items(): config.__dict__[key] = value return config @classmethod def from_json_file(cls, json_file, **kwargs): """Constructs a `BertConfig` from a json file of parameters.""" with open(json_file, "r", encoding='utf-8') as reader: text = reader.read() json_info = json.loads(text) for k, v in kwargs.items(): json_info[k] = v return cls.from_dict(json_info) def __repr__(self): return str(self.to_json_string()) def to_dict(self): """Serializes this instance to a Python dictionary.""" output = copy.deepcopy(self.__dict__) return output def to_json_string(self): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" try: from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm except ImportError: print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") class BertLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-5): """Construct a layernorm module in the TF style (epsilon inside the square root). """ super(BertLayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps def forward(self, x): u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) return self.weight * x + self.bias class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. """ def __init__(self, config): super(BertEmbeddings, self).__init__() self.word_embeddings = nn.Embedding( config.vocab_size, config.hidden_size) if config.no_segment_embedding: self.token_type_embeddings = None else: self.token_type_embeddings = nn.Embedding( config.type_vocab_size, config.hidden_size) if hasattr(config, 'fp32_embedding'): self.fp32_embedding = config.fp32_embedding else: self.fp32_embedding = False if hasattr(config, 'new_pos_ids') and config.new_pos_ids: self.num_pos_emb = 4 else: self.num_pos_emb = 1 self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.hidden_size * self.num_pos_emb) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None): seq_length = input_ids.size(1) if position_ids is None: position_ids = torch.arange( seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) if self.num_pos_emb > 1: num_batch = position_embeddings.size(0) num_pos = position_embeddings.size(1) position_embeddings = position_embeddings.view( num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] embeddings = words_embeddings + position_embeddings if self.token_type_embeddings is not None: embeddings = embeddings + self.token_type_embeddings(token_type_ids) if self.fp32_embedding: embeddings = embeddings.half() embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class LayoutlmEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. """ def __init__(self, config): super(LayoutlmEmbeddings, self).__init__() # self.word_embeddings = nn.Embedding( # config.vocab_size, config.hidden_size) self.only_layout = config.layoutlm_only_layout_flag if not self.only_layout: self.word_embeddings = nn.Embedding( config.vocab_size, config.hidden_size, padding_idx=0 ) else: self.word_embeddings = None self.x_position_embeddings = nn.Embedding( config.max_2d_position_embeddings, config.hidden_size ) self.y_position_embeddings = nn.Embedding( config.max_2d_position_embeddings, config.hidden_size ) self.h_position_embeddings = nn.Embedding( config.max_2d_position_embeddings, config.hidden_size ) self.w_position_embeddings = nn.Embedding( config.max_2d_position_embeddings, config.hidden_size ) if config.no_segment_embedding: self.token_type_embeddings = None else: self.token_type_embeddings = nn.Embedding( config.type_vocab_size, config.hidden_size) if hasattr(config, 'fp32_embedding'): self.fp32_embedding = config.fp32_embedding else: self.fp32_embedding = False if hasattr(config, 'new_pos_ids') and config.new_pos_ids: self.num_pos_emb = 4 else: self.num_pos_emb = 1 self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.hidden_size * self.num_pos_emb) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, bbox, token_type_ids=None, position_ids=None, task_idx=None): seq_length = input_ids.size(1) if position_ids is None: position_ids = torch.arange( seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) position_embeddings = self.position_embeddings(position_ids) if self.num_pos_emb > 1: num_batch = position_embeddings.size(0) num_pos = position_embeddings.size(1) position_embeddings = position_embeddings.view( num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) h_position_embeddings = self.h_position_embeddings( bbox[:, :, 3] - bbox[:, :, 1] ) w_position_embeddings = self.w_position_embeddings( bbox[:, :, 2] - bbox[:, :, 0] ) # token_type_embeddings = self.token_type_embeddings(token_type_ids) # words_embeddings = self.word_embeddings(input_ids) # position_embeddings = self.position_embeddings(position_ids) embeddings = ( # words_embeddings position_embeddings + left_position_embeddings + upper_position_embeddings + right_position_embeddings + lower_position_embeddings + h_position_embeddings + w_position_embeddings ) if not self.only_layout: words_embeddings = self.word_embeddings(input_ids) embeddings = embeddings + words_embeddings if self.token_type_embeddings is not None: embeddings = embeddings + self.token_type_embeddings(token_type_ids) if self.fp32_embedding: embeddings = embeddings.half() embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class BertSelfAttention(nn.Module): def __init__(self, config): super(BertSelfAttention, self).__init__() if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int( config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size if hasattr(config, 'num_qkv') and (config.num_qkv > 1): self.num_qkv = config.num_qkv else: self.num_qkv = 1 self.query = nn.Linear( config.hidden_size, self.all_head_size * self.num_qkv) self.key = nn.Linear(config.hidden_size, self.all_head_size * self.num_qkv) self.value = nn.Linear( config.hidden_size, self.all_head_size * self.num_qkv) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.uni_debug_flag = True if os.getenv( 'UNI_DEBUG_FLAG', '') else False if self.uni_debug_flag: self.register_buffer('debug_attention_probs', torch.zeros((512, 512))) if hasattr(config, 'seg_emb') and config.seg_emb: self.b_q_s = nn.Parameter(torch.zeros( 1, self.num_attention_heads, 1, self.attention_head_size)) self.seg_emb = nn.Embedding( config.type_vocab_size, self.all_head_size) else: self.b_q_s = None self.seg_emb = None def transpose_for_scores(self, x, mask_qkv=None): if self.num_qkv > 1: sz = x.size()[:-1] + (self.num_qkv, self.num_attention_heads, self.all_head_size) # (batch, pos, num_qkv, head, head_hid) x = x.view(*sz) if mask_qkv is None: x = x[:, :, 0, :, :] elif isinstance(mask_qkv, int): x = x[:, :, mask_qkv, :, :] else: # mask_qkv: (batch, pos) if mask_qkv.size(1) > sz[1]: mask_qkv = mask_qkv[:, :sz[1]] # -> x: (batch, pos, head, head_hid) x = x.gather(2, mask_qkv.view(sz[0], sz[1], 1, 1, 1).expand( sz[0], sz[1], 1, sz[3], sz[4])).squeeze(2) else: sz = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (batch, pos, head, head_hid) x = x.view(*sz) # (batch, head, pos, head_hid) return x.permute(0, 2, 1, 3) def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None, key_history=None, value_history=None, key_cache=None, value_cache=None, ): if history_states is None: mixed_query_layer = self.query(hidden_states) # possible issue: https://github.com/NVIDIA/apex/issues/131 mixed_key_layer = F.linear(hidden_states, self.key.weight) mixed_value_layer = self.value(hidden_states) else: x_states = torch.cat((history_states, hidden_states), dim=1) mixed_query_layer = self.query(hidden_states) # possible issue: https://github.com/NVIDIA/apex/issues/131 mixed_key_layer = F.linear(x_states, self.key.weight) mixed_value_layer = self.value(x_states) if key_cache is not None and isinstance(key_cache, list): key_cache.append(mixed_key_layer) mixed_key_layer = torch.cat(key_cache, dim=1) if value_cache is not None and isinstance(value_cache, list): value_cache.append(mixed_value_layer) mixed_value_layer = torch.cat(value_cache, dim=1) query_layer = self.transpose_for_scores(mixed_query_layer, mask_qkv) key_layer = self.transpose_for_scores(mixed_key_layer, mask_qkv) value_layer = self.transpose_for_scores(mixed_value_layer, mask_qkv) if key_history is not None and not isinstance(key_history, list): key_layer = torch.cat((key_history, key_layer), dim=-2) value_layer = torch.cat((value_history, value_layer), dim=-2) # Take the dot product between "query" and "key" to get the raw attention scores. # (batch, head, pos, pos) attention_scores = torch.matmul( query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) if self.seg_emb is not None: seg_rep = self.seg_emb(seg_ids) # (batch, pos, head, head_hid) seg_rep = seg_rep.view(seg_rep.size(0), seg_rep.size( 1), self.num_attention_heads, self.attention_head_size) qs = torch.einsum('bnih,bjnh->bnij', query_layer + self.b_q_s, seg_rep) attention_scores = attention_scores + qs # attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) if self.uni_debug_flag: _pos = attention_probs.size(-1) self.debug_attention_probs[:_pos, :_pos].copy_( attention_probs[0].mean(0).view(_pos, _pos)) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[ :-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) if isinstance(key_history, list): key_history.append(key_layer) if isinstance(value_history, list): value_history.append(value_layer) return context_layer class BertSelfOutput(nn.Module): def __init__(self, config): super(BertSelfOutput, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertAttention(nn.Module): def __init__(self, config): super(BertAttention, self).__init__() self.self = BertSelfAttention(config) self.output = BertSelfOutput(config) def forward(self, input_tensor, attention_mask, history_states=None, mask_qkv=None, seg_ids=None, key_history=None, value_history=None): self_output = self.self( input_tensor, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids, key_history=key_history, value_history=value_history) attention_output = self.output(self_output, input_tensor) return attention_output class BertIntermediate(nn.Module): def __init__(self, config): super(BertIntermediate, self).__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.intermediate_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertOutput(nn.Module): def __init__(self, config): super(BertOutput, self).__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class TransformerFFN(nn.Module): def __init__(self, config): super(TransformerFFN, self).__init__() self.ffn_type = config.ffn_type assert self.ffn_type in (1, 2) if self.ffn_type in (1, 2): self.wx0 = nn.Linear(config.hidden_size, config.hidden_size) if self.ffn_type in (2,): self.wx1 = nn.Linear(config.hidden_size, config.hidden_size) if self.ffn_type in (1, 2): self.output = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, x): if self.ffn_type in (1, 2): x0 = self.wx0(x) if self.ffn_type == 1: x1 = x elif self.ffn_type == 2: x1 = self.wx1(x) out = self.output(x0 * x1) out = self.dropout(out) out = self.LayerNorm(out + x) return out class BertLayer(nn.Module): def __init__(self, config): super(BertLayer, self).__init__() self.attention = BertAttention(config) self.ffn_type = config.ffn_type if self.ffn_type: self.ffn = TransformerFFN(config) else: self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None, key_history=None, value_history=None): attention_output = self.attention( hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids, key_history=key_history, value_history=value_history) if self.ffn_type: layer_output = self.ffn(attention_output) else: intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output class BertEncoder(nn.Module): def __init__(self, config): super(BertEncoder, self).__init__() layer = BertLayer(config) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, seg_ids=None, key_history=None, value_history=None): # history embedding and encoded layer must be simultanously given assert (prev_embedding is None) == (prev_encoded_layers is None) all_encoder_layers = [] if (prev_embedding is not None) and (prev_encoded_layers is not None): history_states = prev_embedding for i, layer_module in enumerate(self.layer): hidden_states = layer_module( hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) if prev_encoded_layers is not None: history_states = prev_encoded_layers[i] else: for i, layer_module in enumerate(self.layer): set_key = None if isinstance(key_history, list): set_key = key_history if len(key_history) < len(self.layer) else key_history[i] set_value = None if isinstance(value_history, list): set_value = value_history if len(key_history) < len(self.layer) else value_history[i] hidden_states = layer_module( hidden_states, attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids, key_history=set_key, value_history=set_value) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) if not output_all_encoded_layers: all_encoder_layers.append(hidden_states) return all_encoder_layers class BertPooler(nn.Module): def __init__(self, config): super(BertPooler, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super(BertPredictionHeadTransform, self).__init__() self.transform_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act hid_size = config.hidden_size if hasattr(config, 'relax_projection') and (config.relax_projection > 1): hid_size *= config.relax_projection self.dense = nn.Linear(config.hidden_size, hid_size) self.LayerNorm = BertLayerNorm(hid_size, eps=1e-5) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class LayoutlmSPLMPredictionHead(nn.Module): def __init__(self, config, src_len): super(LayoutlmSPLMPredictionHead, self).__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.bias = nn.Parameter(torch.zeros(src_len)) if hasattr(config, 'relax_projection') and (config.relax_projection > 1): self.relax_projection = config.relax_projection else: self.relax_projection = 0 self.fp32_embedding = config.fp32_embedding def convert_to_type(tensor): if self.fp32_embedding: return tensor.half() else: return tensor self.type_converter = convert_to_type self.converted = False def forward(self, hidden_states, src_emb, task_idx=None): if not self.converted: self.converted = True if self.fp32_embedding: self.transform.half() hidden_states = self.transform(self.type_converter(hidden_states)) if self.relax_projection > 1: num_batch = hidden_states.size(0) num_pos = hidden_states.size(1) # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid) hidden_states = hidden_states.view( num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] if self.fp32_embedding: hidden_states = torch.einsum('btf,bsf->bts', self.type_converter(hidden_states), self.type_converter(src_emb)) + \ self.type_converter(self.bias) # hidden_states = F.linear(self.type_converter(hidden_states), self.type_converter( # self.decoder.weight), self.type_converter(self.bias)) else: hidden_states = torch.einsum('btf,bsf->bts', hidden_states, src_emb) + self.bias return hidden_states class LayoutlmSPPreTrainingHeads(nn.Module): def __init__(self, config, src_len, num_labels=2): super(LayoutlmSPPreTrainingHeads, self).__init__() self.predictions = LayoutlmSPLMPredictionHead(config, src_len) self.seq_relationship = nn.Linear(config.hidden_size, num_labels) def forward(self, sequence_output, pooled_output, src_emb, task_idx=None): prediction_scores = self.predictions(sequence_output, src_emb, task_idx) if pooled_output is None: seq_relationship_score = None else: seq_relationship_score = self.seq_relationship(pooled_output) return prediction_scores, seq_relationship_score class PreTrainedBertModel(nn.Module): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. """ def __init__(self, config, *inputs, **kwargs): super(PreTrainedBertModel, self).__init__() if not isinstance(config, BertConfig): raise ValueError( "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " "To create a model from a Google pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ )) self.config = config def init_bert_weights(self, module): """ Initialize the weights. """ if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # module.weight.data.copy_(torch.Tensor( # truncnorm.rvs(-1, 1, size=list(module.weight.data.shape)) * self.config.initializer_range)) elif isinstance(module, BertLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() @classmethod def from_pretrained(cls, pretrained_model_name, config, state_dict=None, cache_dir=None, *inputs, **kwargs): """ Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: pretrained_model_name: either: - a str with the name of a pre-trained model to load selected in the list of: . `bert-base-uncased` . `bert-large-uncased` . `bert-base-cased` . `bert-base-multilingual` . `bert-base-chinese` - a path or url to a pretrained model archive containing: . `bert_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance cache_dir: an optional path to a folder in which the pre-trained models will be cached. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ logger.info("Model config {}".format(config)) # clean the arguments in kwargs for arg_clean in ('config_path', 'type_vocab_size', 'relax_projection', 'new_pos_ids', 'task_idx', 'max_position_embeddings', 'fp32_embedding', 'ffn_type', 'label_smoothing', 'hidden_dropout_prob', 'attention_probs_dropout_prob', 'num_qkv', 'seg_emb', 'word_emb_map', 'num_labels', 'num_rel', 'num_sentlvl_labels'): if arg_clean in kwargs: del kwargs[arg_clean] # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None: weights_path = os.path.join(pretrained_model_name, WEIGHTS_NAME) state_dict = torch.load(weights_path, map_location='cpu') old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None if 'gamma' in key: new_key = key.replace('gamma', 'weight') if 'beta' in key: new_key = key.replace('beta', 'bias') if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get( prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') load(model, prefix='' if hasattr(model, 'bert') else 'bert.') model.missing_keys = missing_keys if len(missing_keys) > 0: logger.info("Weights of {} not initialized from pretrained model: {}".format( model.__class__.__name__, missing_keys)) if len(unexpected_keys) > 0: logger.info("Weights from pretrained model not used in {}: {}".format( model.__class__.__name__, unexpected_keys)) if len(error_msgs) > 0: logger.info('\n'.join(error_msgs)) return model class BertModel(PreTrainedBertModel): """BERT model ("Bidirectional Embedding Representations from a Transformer"). Params: config: a BertConfig class instance with the configuration to build a new model Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. Outputs: Tuple of (encoded_layers, pooled_output) `encoded_layers`: controled by `output_all_encoded_layers` argument: - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding to the last attention block of shape [batch_size, sequence_length, hidden_size], `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper). ``` """ def __init__(self, config): super(BertModel, self).__init__(config) self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) self.apply(self.init_bert_weights) def rescale_some_parameters(self): for layer_id, layer in enumerate(self.encoder.layer): layer.attention.output.dense.weight.data.div_( math.sqrt(2.0 * (layer_id + 1))) layer.output.dense.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. if attention_mask.dim() == 2: extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) elif attention_mask.dim() == 3: extended_attention_mask = attention_mask.unsqueeze(1) else: raise NotImplementedError # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to( dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, mask_qkv=None, task_idx=None, key_history=None, value_history=None, position_ids=None): extended_attention_mask = self.get_extended_attention_mask( input_ids, token_type_ids, attention_mask) embedding_output = self.embeddings( input_ids, token_type_ids, task_idx=task_idx, position_ids=position_ids) encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids, key_history=key_history, value_history=value_history) sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] return encoded_layers, pooled_output class LayoutlmModel(PreTrainedBertModel): """BERT model ("Bidirectional Embedding Representations from a Transformer"). Params: config: a BertConfig class instance with the configuration to build a new model Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`) `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. Outputs: Tuple of (encoded_layers, pooled_output) `encoded_layers`: controled by `output_all_encoded_layers` argument: - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding to the last attention block of shape [batch_size, sequence_length, hidden_size], `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper). ``` """ def __init__(self, config): super(LayoutlmModel, self).__init__(config) self.embeddings = LayoutlmEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) self.apply(self.init_bert_weights) def rescale_some_parameters(self): for layer_id, layer in enumerate(self.encoder.layer): layer.attention.output.dense.weight.data.div_( math.sqrt(2.0 * (layer_id + 1))) layer.output.dense.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. if attention_mask.dim() == 2: extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) elif attention_mask.dim() == 3: extended_attention_mask = attention_mask.unsqueeze(1) else: raise NotImplementedError # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to( dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, mask_qkv=None, task_idx=None, key_history=None, value_history=None, position_ids=None): extended_attention_mask = self.get_extended_attention_mask( input_ids[:, :, 0], token_type_ids, attention_mask) embedding_output = self.embeddings( input_ids[:, :, 0], input_ids[:, :, 1:], token_type_ids, task_idx=task_idx, position_ids=position_ids) encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids, key_history=key_history, value_history=value_history) sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] return encoded_layers, pooled_output class BertModelIncr(BertModel): def __init__(self, config): super(BertModelIncr, self).__init__(config) def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, task_idx=None): extended_attention_mask = self.get_extended_attention_mask( input_ids, token_type_ids, attention_mask) embedding_output = self.embeddings( input_ids, token_type_ids, position_ids, task_idx=task_idx) encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids) sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] return embedding_output, encoded_layers, pooled_output class LayoutlmModelIncr(LayoutlmModel): def __init__(self, config): super(LayoutlmModelIncr, self).__init__(config) def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, task_idx=None): extended_attention_mask = self.get_extended_attention_mask( input_ids[:, :, 0], token_type_ids, attention_mask) embedding_output = self.embeddings( input_ids[:, :, 0], input_ids[:, :, 1:], token_type_ids, position_ids, task_idx=task_idx) encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, seg_ids=token_type_ids) sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] return embedding_output, encoded_layers, pooled_output class LayoutlmForSeq2SeqDecoder(PreTrainedBertModel): """refer to BertForPreTraining""" def __init__(self, config, mask_word_id=0, num_labels=2, num_rel=0, search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0, forbid_duplicate_ngrams=False, forbid_ignore_set=None, ngram_size=3, min_len=0, mode="s2s", pos_shift=False): super(LayoutlmForSeq2SeqDecoder, self).__init__(config) self.layout_flag = config.base_model_type == 'layoutlm' if config.base_model_type == 'layoutlm': self.bert = LayoutlmModelIncr(config) else: self.bert = BertModelIncr(config) # self.bert = BertModelIncr(config) # note: the max source length is the max src seq length during fine tuning which includes the cls and sep # NOTE: we don't remove anything. the 0 is for padding self.cls = LayoutlmSPPreTrainingHeads( config, src_len=config.max_source_length, num_labels=num_labels) self.apply(self.init_bert_weights) self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1) self.mask_word_id = mask_word_id self.num_labels = num_labels self.num_rel = num_rel self.search_beam_size = search_beam_size self.length_penalty = length_penalty self.eos_id = eos_id self.sos_id = sos_id self.forbid_duplicate_ngrams = forbid_duplicate_ngrams self.forbid_ignore_set = forbid_ignore_set self.ngram_size = ngram_size self.min_len = min_len assert mode in ("s2s", "l2r") self.mode = mode self.pos_shift = pos_shift def forward(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): if self.search_beam_size > 1: return self.beam_search(input_ids, token_type_ids, position_ids, attention_mask, task_idx=task_idx, mask_qkv=mask_qkv) input_shape = list(input_ids.size()) batch_size = input_shape[0] input_length = input_shape[1] output_shape = list(token_type_ids.size()) output_length = output_shape[1] output_ids = [] prev_embedding = None prev_encoded_layers = None curr_ids = input_ids if not self.layout_flag: mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) else: mask_ids = input_ids.new_zeros(batch_size, 1, 5) mask_ids[:, :, 0] = self.mask_word_id next_pos = input_length if self.pos_shift: if not self.layout_flag: sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) else: sos_ids = input_ids.new_zeros(batch_size, 1, 5) sos_ids[:, :, 0] = self.sos_id src_embedding = None while next_pos < output_length: curr_length = list(curr_ids.size())[1] if self.pos_shift: if next_pos == input_length: x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) start_pos = 0 else: x_input_ids = curr_ids start_pos = next_pos else: start_pos = next_pos - curr_length # if self.layout_flag: # mask_ids[:, -1, 1:] = curr_ids[:, , 1:] x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] curr_attention_mask = attention_mask[:, start_pos:next_pos + 1, :next_pos + 1] curr_position_ids = position_ids[:, start_pos:next_pos + 1] new_embedding, new_encoded_layers, _ = \ self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) if src_embedding is None: # note: cut three embedding: CLS (1st), ..., SEP (-2nd), next to pred (-1st) # note: (NEW) the sep is kept for ignore index in loss func (for padding's index) # NOTE: only remove the next to pred token src_embedding = new_embedding[:, :-1, :] last_hidden = new_encoded_layers[-1][:, -1:, :] prediction_scores, _ = self.cls(last_hidden, None, src_embedding, task_idx=task_idx) _, max_ids = torch.max(prediction_scores, dim=-1) output_ids.append(max_ids) if self.pos_shift: if prev_embedding is None: prev_embedding = new_embedding else: prev_embedding = torch.cat( (prev_embedding, new_embedding), dim=1) if prev_encoded_layers is None: prev_encoded_layers = [x for x in new_encoded_layers] else: prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( prev_encoded_layers, new_encoded_layers)] else: if prev_embedding is None: prev_embedding = new_embedding[:, :-1, :] else: prev_embedding = torch.cat( (prev_embedding, new_embedding[:, :-1, :]), dim=1) if prev_encoded_layers is None: prev_encoded_layers = [x[:, :-1, :] for x in new_encoded_layers] else: prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) for x in zip(prev_encoded_layers, new_encoded_layers)] if not self.layout_flag: index = max_ids curr_ids = torch.gather(input_ids, 1, index) else: _, _, dim = input_ids.shape index = max_ids.unsqueeze(-1) index = index.expand(index.shape[0], index.shape[1], dim) # index = index.repeat(1, 1, dim) curr_ids = torch.gather(input_ids, 1, index) # if len(input_ids.shape) == 2: # real_input_ids = input_ids[:, 1:] # index = max_ids # curr_ids = torch.gather(real_input_ids, 1, index) # else: # real_input_ids = input_ids[:, 1:, :] # _, _, dim = real_input_ids.shape # index = max_ids.unsqueeze(-1) # index = index.expand(index.shape[0], index.shape[1], dim) # curr_ids = torch.gather(real_input_ids, 1, index) # # note: real input ids only include the ids for real data (remove the cls and sep) # real_input_ids = input_ids[:, 1: -1, :] # # _, _, dim = real_input_ids.shape # index = max_ids.unsqueeze(-1) # index = index.expand(index.shape[0], index.shape[1], dim) # # curr_ids = torch.gather(real_input_ids, 1, index) # curr_ids = real_input_ids[:, max_ids, :] # curr_ids = max_ids next_pos += 1 return torch.cat(output_ids, dim=1) # TODO: do the same with beam search as forward() def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): input_shape = list(input_ids.size()) batch_size = input_shape[0] input_length = input_shape[1] output_shape = list(token_type_ids.size()) output_length = output_shape[1] output_ids = [] prev_embedding = None prev_encoded_layers = None curr_ids = input_ids # mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) if not self.layout_flag: mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) else: mask_ids = input_ids.new_zeros(batch_size, 1, 5) mask_ids[:, :, 0] = self.mask_word_id next_pos = input_length if self.pos_shift: if not self.layout_flag: sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) else: sos_ids = input_ids.new_zeros(batch_size, 1, 5) sos_ids[:, :, 0] = self.sos_id K = self.search_beam_size total_scores = [] beam_masks = [] step_ids = [] step_back_ptrs = [] partial_seqs = [] forbid_word_mask = None buf_matrix = None src_embedding = None while next_pos < output_length: curr_length = list(curr_ids.size())[1] if self.pos_shift: if next_pos == input_length: x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) start_pos = 0 else: x_input_ids = curr_ids start_pos = next_pos else: start_pos = next_pos - curr_length x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] curr_attention_mask = attention_mask[:, start_pos:next_pos + 1, :next_pos + 1] curr_position_ids = position_ids[:, start_pos:next_pos + 1] new_embedding, new_encoded_layers, _ = \ self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, output_all_encoded_layers=True, prev_embedding=prev_embedding, prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) def first_expand(x): input_shape = list(x.size()) expanded_shape = input_shape[:1] + [1] + input_shape[1:] x = torch.reshape(x, expanded_shape) repeat_count = [1, K] + [1] * (len(input_shape) - 1) x = x.repeat(*repeat_count) x = torch.reshape(x, [input_shape[0] * K] + input_shape[1:]) return x if src_embedding is None: src_embedding = new_embedding[:, :-1, :] if src_embedding.shape[0] != new_embedding.shape[0]: src_embedding = first_expand(src_embedding) last_hidden = new_encoded_layers[-1][:, -1:, :] prediction_scores, _ = self.cls(last_hidden, None, src_embedding, task_idx=task_idx) log_scores = torch.nn.functional.log_softmax( prediction_scores, dim=-1) # if forbid_word_mask is not None: # log_scores += (forbid_word_mask * -10000.0) # if self.min_len and (next_pos - input_length + 1 <= self.min_len): # log_scores[:, :, self.eos_id].fill_(-10000.0) kk_scores, kk_ids = torch.topk(log_scores, k=K) if len(total_scores) == 0: k_ids = torch.reshape(kk_ids, [batch_size, K]) back_ptrs = torch.zeros(batch_size, K, dtype=torch.long) k_scores = torch.reshape(kk_scores, [batch_size, K]) else: last_eos = torch.reshape( beam_masks[-1], [batch_size * K, 1, 1]) last_seq_scores = torch.reshape( total_scores[-1], [batch_size * K, 1, 1]) kk_scores += last_eos * (-10000.0) + last_seq_scores kk_scores = torch.reshape(kk_scores, [batch_size, K * K]) k_scores, k_ids = torch.topk(kk_scores, k=K) back_ptrs = torch.floor_divide(k_ids, K) kk_ids = torch.reshape(kk_ids, [batch_size, K * K]) k_ids = torch.gather(kk_ids, 1, k_ids) step_back_ptrs.append(back_ptrs) step_ids.append(k_ids) beam_masks.append(torch.eq(k_ids, self.eos_id).type_as(kk_scores)) total_scores.append(k_scores) # def first_expand(x): # input_shape = list(x.size()) # expanded_shape = input_shape[:1] + [1] + input_shape[1:] # x = torch.reshape(x, expanded_shape) # repeat_count = [1, K] + [1] * (len(input_shape) - 1) # x = x.repeat(*repeat_count) # x = torch.reshape(x, [input_shape[0] * K] + input_shape[1:]) # return x def select_beam_items(x, ids): id_shape = list(ids.size()) id_rank = len(id_shape) assert len(id_shape) == 2 x_shape = list(x.size()) x = torch.reshape(x, [batch_size, K] + x_shape[1:]) x_rank = len(x_shape) + 1 assert x_rank >= 2 if id_rank < x_rank: ids = torch.reshape( ids, id_shape + [1] * (x_rank - id_rank)) ids = ids.expand(id_shape + x_shape[1:]) y = torch.gather(x, 1, ids) y = torch.reshape(y, x_shape) return y is_first = (prev_embedding is None) if self.pos_shift: if prev_embedding is None: prev_embedding = first_expand(new_embedding) else: prev_embedding = torch.cat( (prev_embedding, new_embedding), dim=1) prev_embedding = select_beam_items( prev_embedding, back_ptrs) if prev_encoded_layers is None: prev_encoded_layers = [first_expand( x) for x in new_encoded_layers] else: prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( prev_encoded_layers, new_encoded_layers)] prev_encoded_layers = [select_beam_items( x, back_ptrs) for x in prev_encoded_layers] else: if prev_embedding is None: prev_embedding = first_expand(new_embedding[:, :-1, :]) else: prev_embedding = torch.cat( (prev_embedding, new_embedding[:, :-1, :]), dim=1) prev_embedding = select_beam_items( prev_embedding, back_ptrs) if prev_encoded_layers is None: prev_encoded_layers = [first_expand( x[:, :-1, :]) for x in new_encoded_layers] else: prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) for x in zip(prev_encoded_layers, new_encoded_layers)] prev_encoded_layers = [select_beam_items( x, back_ptrs) for x in prev_encoded_layers] max_ids = torch.reshape(k_ids, [batch_size * K, 1]) if len(input_ids.shape) == 2: expand_input_ids = first_expand(input_ids) index = max_ids curr_ids = torch.gather(expand_input_ids, 1, index) else: expand_input_ids = first_expand(input_ids) _, _, dim = expand_input_ids.shape index = max_ids.unsqueeze(-1) index = index.expand(index.shape[0], index.shape[1], dim) curr_ids = torch.gather(expand_input_ids, 1, index) if is_first: token_type_ids = first_expand(token_type_ids) position_ids = first_expand(position_ids) attention_mask = first_expand(attention_mask) mask_ids = first_expand(mask_ids) if mask_qkv is not None: mask_qkv = first_expand(mask_qkv) if self.forbid_duplicate_ngrams: wids = step_ids[-1].tolist() ptrs = step_back_ptrs[-1].tolist() if is_first: partial_seqs = [] for b in range(batch_size): for k in range(K): partial_seqs.append([wids[b][k]]) else: new_partial_seqs = [] for b in range(batch_size): for k in range(K): new_partial_seqs.append( partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]]) partial_seqs = new_partial_seqs def get_dup_ngram_candidates(seq, n): cands = set() if len(seq) < n: return [] tail = seq[-(n - 1):] if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail): return [] for i in range(len(seq) - (n - 1)): mismatch = False for j in range(n - 1): if tail[j] != seq[i + j]: mismatch = True break if (not mismatch) and not ( self.forbid_ignore_set and (seq[i + n - 1] in self.forbid_ignore_set)): cands.add(seq[i + n - 1]) return list(sorted(cands)) if len(partial_seqs[0]) >= self.ngram_size: dup_cands = [] for seq in partial_seqs: dup_cands.append( get_dup_ngram_candidates(seq, self.ngram_size)) if max(len(x) for x in dup_cands) > 0: if buf_matrix is None: vocab_size = list(log_scores.size())[-1] buf_matrix = np.zeros( (batch_size * K, vocab_size), dtype=float) else: buf_matrix.fill(0) for bk, cands in enumerate(dup_cands): for i, wid in enumerate(cands): buf_matrix[bk, wid] = 1.0 forbid_word_mask = torch.tensor( buf_matrix, dtype=log_scores.dtype) forbid_word_mask = torch.reshape( forbid_word_mask, [batch_size * K, 1, vocab_size]).to(input_ids.device) else: forbid_word_mask = None next_pos += 1 # [(batch, beam)] total_scores = [x.tolist() for x in total_scores] step_ids = [x.tolist() for x in step_ids] step_back_ptrs = [x.tolist() for x in step_back_ptrs] # back tracking traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []} for b in range(batch_size): # [(beam,)] scores = [x[b] for x in total_scores] wids_list = [x[b] for x in step_ids] ptrs = [x[b] for x in step_back_ptrs] traces['scores'].append(scores) traces['wids'].append(wids_list) traces['ptrs'].append(ptrs) # first we need to find the eos frame where all symbols are eos # any frames after the eos frame are invalid last_frame_id = len(scores) - 1 for i, wids in enumerate(wids_list): if all(wid == self.eos_id for wid in wids): last_frame_id = i break max_score = -math.inf frame_id = -1 pos_in_frame = -1 for fid in range(last_frame_id + 1): for i, wid in enumerate(wids_list[fid]): if wid == self.eos_id or fid == last_frame_id: s = scores[fid][i] if self.length_penalty > 0: s /= math.pow((5 + fid + 1) / 6.0, self.length_penalty) if s > max_score: max_score = s frame_id = fid pos_in_frame = i if frame_id == -1: traces['pred_seq'].append([0]) else: seq = [wids_list[frame_id][pos_in_frame]] for fid in range(frame_id, 0, -1): pos_in_frame = ptrs[fid][pos_in_frame] seq.append(wids_list[fid - 1][pos_in_frame]) seq.reverse() traces['pred_seq'].append(seq) def _pad_sequence(sequences, max_len, padding_value=0): trailing_dims = sequences[0].size()[1:] out_dims = (len(sequences), max_len) + trailing_dims out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) for i, tensor in enumerate(sequences): length = tensor.size(0) # use index notation to prevent duplicate references to the tensor out_tensor[i, :length, ...] = tensor return out_tensor # convert to tensors for DataParallel for k in ('pred_seq', 'scores', 'wids', 'ptrs'): ts_list = traces[k] if not isinstance(ts_list[0], torch.Tensor): dt = torch.float if k == 'scores' else torch.long ts_list = [torch.tensor(it, dtype=dt) for it in ts_list] traces[k] = _pad_sequence( ts_list, output_length, padding_value=0).to(input_ids.device) return traces