import logging

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import BertConfig, BertModel, BertPreTrainedModel, RobertaConfig

try:
    from transformers.models.bert.modeling_bert import BertOnlyMLMHead
except ImportError:  # older transformers releases expose it under transformers.modeling_bert
    from transformers.modeling_bert import BertOnlyMLMHead

logger = logging.getLogger(__name__)

LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP = {}
LAYOUTLMV1_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class Layoutlmv1Config(RobertaConfig):
    pretrained_config_archive_map = LAYOUTLMV1_PRETRAINED_CONFIG_ARCHIVE_MAP
    model_type = "bert"

    def __init__(self, max_2d_position_embeddings=1024, add_linear=False, **kwargs):
        super().__init__(**kwargs)
        self.max_2d_position_embeddings = max_2d_position_embeddings
        self.add_linear = add_linear


class Layoutlmv1Embeddings(nn.Module):
    def __init__(self, config):
        super(Layoutlmv1Embeddings, self).__init__()
        self.config = config
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=0
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        # Fall back to the original default of 1024 if the config does not define it.
        if not hasattr(config, "max_2d_position_embeddings"):
            config.max_2d_position_embeddings = 1024
        # 2D position embeddings for the bounding-box coordinates.
        self.x_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.y_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.h_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.w_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )

        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.doc_linear1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.doc_linear2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.doc_linear3 = nn.Linear(config.hidden_size, config.hidden_size)
        self.doc_linear4 = nn.Linear(config.hidden_size, config.hidden_size)
        self.relu = nn.ReLU()

    def forward(
        self,
        input_ids,
        bbox,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
        upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
        right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
        lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
        h_position_embeddings = self.h_position_embeddings(
            bbox[:, :, 3] - bbox[:, :, 1]
        )
        w_position_embeddings = self.w_position_embeddings(
            bbox[:, :, 2] - bbox[:, :, 0]
        )

        # Fuse the six spatial embeddings through a small two-layer MLP before
        # adding them to the word, 1D-position and token-type embeddings.
        temp_embeddings = self.doc_linear2(self.relu(self.doc_linear1(
            left_position_embeddings
            + upper_position_embeddings
            + right_position_embeddings
            + lower_position_embeddings
            + h_position_embeddings
            + w_position_embeddings
        )))

        embeddings = (
            words_embeddings
            + position_embeddings
            + temp_embeddings
            + token_type_embeddings
        )
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
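
# Note on the expected bbox format (an assumption inferred from the embedding sizes
# above; the original file does not document it): `bbox` should be a LongTensor of
# shape (batch_size, seq_len, 4) holding (x0, y0, x1, y1) corners already scaled into
# [0, config.max_2d_position_embeddings), i.e. the usual LayoutLM convention of page
# coordinates normalized to a 0-1000 range. Values at or above
# max_2d_position_embeddings, or negative widths/heights, make the nn.Embedding
# lookups fail, so inputs should be clamped upstream, for example:
#
#     bbox = bbox.clamp(0, config.max_2d_position_embeddings - 1)
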
class Layoutlmv1Model(BertModel):

    config_class = Layoutlmv1Config
    pretrained_model_archive_map = LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"

    def __init__(self, config):
        super(Layoutlmv1Model, self).__init__(config)
        self.embeddings = Layoutlmv1Embeddings(config)
        self.init_weights()

    def forward(
        self,
        input_ids,
        bbox,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend to and 0.0 for
        # masked positions, this operation creates a tensor which is 0.0 for
        # positions we want to attend to and -10000.0 for masked positions.
        # Since we add it to the raw scores before the softmax, this is
        # effectively the same as removing them entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=torch.float32
            # dtype=next(self.parameters()).dtype  # can trigger errors with newer torch versions
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed:
        # 1.0 in head_mask means we keep the head.
        # attention_probs has shape bsz x n_heads x N x N.
        # The input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length].
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = (
                    head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                )
                head_mask = head_mask.expand(
                    self.config.num_hidden_layers, -1, -1, -1, -1
                )
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(
            input_ids, bbox, position_ids=position_ids, token_type_ids=token_type_ids
        )
        encoder_outputs = self.encoder(
            embedding_output, extended_attention_mask, head_mask=head_mask
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
class Layoutlmv1ForTokenClassification(BertPreTrainedModel):
    config_class = Layoutlmv1Config
    pretrained_model_archive_map = LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = Layoutlmv1Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids,
        bbox,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        outputs = self.roberta(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[
            2:
        ]  # add hidden states and attentions if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep the active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


class Layoutlmv1ForMaskedLM(BertPreTrainedModel):
    config_class = Layoutlmv1Config
    pretrained_model_archive_map = LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.bert = Layoutlmv1Model(config)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()

    def get_input_embeddings(self):
        return self.bert.embeddings.word_embeddings

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(
        self,
        input_ids,
        bbox,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        lm_labels=None,
    ):
        outputs = self.bert(
            input_ids,
            bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        outputs = (prediction_scores,) + outputs[
            2:
        ]  # add hidden states and attentions if they are here

        # Although this may seem awkward, BertForMaskedLM supports two scenarios:
        # 1. If a tensor that contains the indices of masked labels is provided,
        #    the cross-entropy is the MLM cross-entropy that measures the likelihood
        #    of predictions for masked words.
        # 2. If `lm_labels` is provided, we are in a causal scenario where we
        #    try to predict the next token for each input in the decoder.
        if masked_lm_labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                masked_lm_labels.view(-1),
            )
            outputs = (masked_lm_loss,) + outputs

        return outputs  # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
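
# Usage sketch for the MLM head (illustrative assumption, not from the original repo;
# `masked_positions` and `mask_token_id` are hypothetical names): positions that should
# not contribute to the loss are set to -100, the default ignore_index of
# torch.nn.CrossEntropyLoss used above.
#
#     masked_lm_labels = input_ids.clone()
#     masked_lm_labels[~masked_positions] = -100    # only score the masked tokens
#     input_ids[masked_positions] = mask_token_id   # replace them with the mask token
#     loss, prediction_scores = model(
#         input_ids, bbox,
#         attention_mask=attention_mask,
#         masked_lm_labels=masked_lm_labels,
#     )[:2]
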
class Layoutlmv1ForQuestionAnswering(BertPreTrainedModel):
    config_class = Layoutlmv1Config
    pretrained_model_archive_map = LAYOUTLMV1_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = Layoutlmv1Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids,
        bbox,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        # inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        # output_attentions=None,
        # output_hidden_states=None,
        # return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for the position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for the position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        """
        # return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output
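

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; the sizes and hyperparameters below
    # are assumptions, not values from the original training setup).
    config = Layoutlmv1Config(
        vocab_size=100,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
        num_labels=5,
    )
    model = Layoutlmv1ForTokenClassification(config)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    # Random (x0, y0, x1, y1) boxes kept inside the 2D position range.
    x0 = torch.randint(0, 500, (batch_size, seq_len))
    y0 = torch.randint(0, 500, (batch_size, seq_len))
    bbox = torch.stack([x0, y0, x0 + 10, y0 + 10], dim=-1)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    labels = torch.randint(0, config.num_labels, (batch_size, seq_len))

    loss, logits = model(
        input_ids,
        bbox,
        attention_mask=attention_mask,
        labels=labels,
    )[:2]
    print("loss:", loss.item(), "logits shape:", tuple(logits.shape))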