# coding=utf-8 | |
"""PyTorch BERT model.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import copy | |
import json | |
import logging | |
import math | |
import os | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torch.nn.modules.loss import _Loss | |
class LabelSmoothingLoss(_Loss): | |
""" | |
With label smoothing, | |
KL-divergence between q_{smoothed ground truth prob.}(w) | |
and p_{prob. computed by model}(w) is minimized. | |
""" | |
def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, | |
reduction='mean'): | |
assert 0.0 < label_smoothing <= 1.0 | |
self.ignore_index = ignore_index | |
super(LabelSmoothingLoss, self).__init__( | |
size_average=size_average, reduce=reduce, reduction=reduction) | |
assert label_smoothing > 0 | |
assert tgt_vocab_size > 0 | |
smoothing_value = label_smoothing / (tgt_vocab_size - 2) | |
one_hot = torch.full((tgt_vocab_size,), smoothing_value) | |
one_hot[self.ignore_index] = 0 | |
self.register_buffer('one_hot', one_hot.unsqueeze(0)) | |
self.confidence = 1.0 - label_smoothing | |
self.tgt_vocab_size = tgt_vocab_size | |
def forward(self, output, target): | |
""" | |
output (FloatTensor): batch_size * num_pos * n_classes | |
target (LongTensor): batch_size * num_pos | |
""" | |
assert self.tgt_vocab_size == output.size(2) | |
batch_size, num_pos = target.size(0), target.size(1) | |
output = output.view(-1, self.tgt_vocab_size) | |
target = target.view(-1) | |
model_prob = self.one_hot.repeat(target.size(0), 1) | |
model_prob.scatter_(1, target.unsqueeze(1), self.confidence) | |
model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) | |
return F.kl_div(output, model_prob, reduction='none').view(batch_size, num_pos, -1).sum(2) | |
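# Illustrative usage sketch (not part of the original file): LabelSmoothingLoss expects
# `output` to already contain log-probabilities (e.g. from F.log_softmax), since it feeds
# them straight into F.kl_div, and it returns an unreduced loss of shape (batch_size, num_pos).
# All sizes below are placeholders chosen for demonstration only.
def _example_label_smoothing_usage():
    vocab_size, batch_size, num_pos = 10, 2, 4
    criterion = LabelSmoothingLoss(label_smoothing=0.1, tgt_vocab_size=vocab_size, ignore_index=0)
    logits = torch.randn(batch_size, num_pos, vocab_size)
    log_probs = F.log_softmax(logits, dim=-1)
    target = torch.randint(1, vocab_size, (batch_size, num_pos))
    return criterion(log_probs, target)  # shape: (batch_size, num_pos)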
logger = logging.getLogger(__name__) | |
PRETRAINED_MODEL_ARCHIVE_MAP = { | |
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |
'unilm-base-cased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1-base-cased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D", | |
'unilm-large-cased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1-large-cased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D", | |
'unilm1-base-cased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1-base-cased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D", | |
'unilm1-large-cased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1-large-cased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D", | |
'unilm1.2-base-uncased': "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/unilm1.2-base-uncased.bin?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D" | |
} | |
CONFIG_NAME = 'config.json' | |
WEIGHTS_NAME = 'pytorch_model.bin' | |
def gelu(x): | |
"""Implementation of the gelu activation function. | |
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): | |
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) | |
""" | |
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |
def swish(x): | |
return x * torch.sigmoid(x) | |
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |
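# Quick sanity sketch (not part of the original file): the exact erf-based gelu above and the
# tanh approximation quoted in its docstring agree closely; the maximum absolute difference on
# this small grid is on the order of 1e-3 or less.
def _example_gelu_vs_tanh_approx():
    x = torch.linspace(-3.0, 3.0, steps=13)
    exact = gelu(x)
    approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    return torch.max(torch.abs(exact - approx))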
class BertConfig(object): | |
"""Configuration class to store the configuration of a `BertModel`. | |
""" | |
def __init__(self, | |
vocab_size_or_config_json_file, | |
hidden_size=768, | |
num_hidden_layers=12, | |
num_attention_heads=12, | |
intermediate_size=3072, | |
hidden_act="gelu", | |
hidden_dropout_prob=0.1, | |
attention_probs_dropout_prob=0.1, | |
max_position_embeddings=512, | |
type_vocab_size=2, | |
relax_projection=0, | |
new_pos_ids=False, | |
initializer_range=0.02, | |
task_idx=None, | |
fp32_embedding=False, | |
ffn_type=0, | |
label_smoothing=None, | |
num_qkv=0, | |
seg_emb=False, | |
source_type_id=0, | |
target_type_id=1, | |
no_segment_embedding=False, **kwargs): | |
"""Constructs BertConfig. | |
Args: | |
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer. | |
num_hidden_layers: Number of hidden layers in the Transformer encoder. | |
num_attention_heads: Number of attention heads for each attention layer in | |
the Transformer encoder. | |
intermediate_size: The size of the "intermediate" (i.e., feed-forward) | |
layer in the Transformer encoder. | |
hidden_act: The non-linear activation function (function or string) in the | |
encoder and pooler. If string, "gelu", "relu" and "swish" are supported. | |
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention | |
probabilities. | |
max_position_embeddings: The maximum sequence length that this model might | |
ever be used with. Typically set this to something large just in case | |
(e.g., 512 or 1024 or 2048). | |
type_vocab_size: The vocabulary size of the `token_type_ids` passed into | |
`BertModel`. | |
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
""" | |
if isinstance(vocab_size_or_config_json_file, str): | |
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: | |
json_config = json.loads(reader.read()) | |
for key, value in json_config.items(): | |
self.__dict__[key] = value | |
elif isinstance(vocab_size_or_config_json_file, int): | |
self.vocab_size = vocab_size_or_config_json_file | |
self.hidden_size = hidden_size | |
self.num_hidden_layers = num_hidden_layers | |
self.num_attention_heads = num_attention_heads | |
self.hidden_act = hidden_act | |
self.intermediate_size = intermediate_size | |
self.hidden_dropout_prob = hidden_dropout_prob | |
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |
self.max_position_embeddings = max_position_embeddings | |
self.type_vocab_size = type_vocab_size | |
self.relax_projection = relax_projection | |
self.new_pos_ids = new_pos_ids | |
self.initializer_range = initializer_range | |
self.task_idx = task_idx | |
self.fp32_embedding = fp32_embedding | |
self.ffn_type = ffn_type | |
self.label_smoothing = label_smoothing | |
self.num_qkv = num_qkv | |
self.seg_emb = seg_emb | |
self.no_segment_embedding = no_segment_embedding | |
self.source_type_id = source_type_id | |
self.target_type_id = target_type_id | |
if type_vocab_size == 0: | |
self.no_segment_embedding = True | |
else: | |
raise ValueError("First argument must be either a vocabulary size (int)" | |
"or the path to a pretrained model config file (str)") | |
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters.""" | |
config = BertConfig(vocab_size_or_config_json_file=-1) | |
for key, value in json_object.items(): | |
config.__dict__[key] = value | |
return config | |
@classmethod
def from_json_file(cls, json_file, **kwargs):
"""Constructs a `BertConfig` from a json file of parameters.""" | |
with open(json_file, "r", encoding='utf-8') as reader: | |
text = reader.read() | |
json_info = json.loads(text) | |
for k, v in kwargs.items(): | |
json_info[k] = v | |
return cls.from_dict(json_info) | |
def __repr__(self): | |
return str(self.to_json_string()) | |
def to_dict(self): | |
"""Serializes this instance to a Python dictionary.""" | |
output = copy.deepcopy(self.__dict__) | |
return output | |
def to_json_string(self): | |
"""Serializes this instance to a JSON string.""" | |
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" | |
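# Illustrative sketch (not part of the original file): the two construction paths BertConfig
# supports, either directly from a vocabulary size plus keyword overrides, or by round-tripping
# through a plain dict; the hyperparameter values are placeholders.
def _example_bert_config_roundtrip():
    config = BertConfig(vocab_size_or_config_json_file=30522, hidden_size=768,
                        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    restored = BertConfig.from_dict(json.loads(config.to_json_string()))
    assert restored.to_dict() == config.to_dict()
    return restored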
try: | |
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm | |
except ImportError: | |
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") | |
class BertLayerNorm(nn.Module): | |
def __init__(self, hidden_size, eps=1e-5): | |
"""Construct a layernorm module in the TF style (epsilon inside the square root). | |
""" | |
super(BertLayerNorm, self).__init__() | |
self.weight = nn.Parameter(torch.ones(hidden_size)) | |
self.bias = nn.Parameter(torch.zeros(hidden_size)) | |
self.variance_epsilon = eps | |
def forward(self, x): | |
u = x.mean(-1, keepdim=True) | |
s = (x - u).pow(2).mean(-1, keepdim=True) | |
x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |
return self.weight * x + self.bias | |
class BertEmbeddings(nn.Module): | |
"""Construct the embeddings from word, position and token_type embeddings. | |
""" | |
def __init__(self, config): | |
super(BertEmbeddings, self).__init__() | |
self.word_embeddings = nn.Embedding( | |
config.vocab_size, config.hidden_size) | |
if config.no_segment_embedding: | |
self.token_type_embeddings = None | |
else: | |
self.token_type_embeddings = nn.Embedding( | |
config.type_vocab_size, config.hidden_size) | |
if hasattr(config, 'fp32_embedding'): | |
self.fp32_embedding = config.fp32_embedding | |
else: | |
self.fp32_embedding = False | |
if hasattr(config, 'new_pos_ids') and config.new_pos_ids: | |
self.num_pos_emb = 4 | |
else: | |
self.num_pos_emb = 1 | |
self.position_embeddings = nn.Embedding( | |
config.max_position_embeddings, config.hidden_size * self.num_pos_emb) | |
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |
# any TensorFlow checkpoint file | |
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) | |
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |
def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None): | |
seq_length = input_ids.size(1) | |
if position_ids is None: | |
position_ids = torch.arange( | |
seq_length, dtype=torch.long, device=input_ids.device) | |
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |
if token_type_ids is None: | |
token_type_ids = torch.zeros_like(input_ids) | |
words_embeddings = self.word_embeddings(input_ids) | |
position_embeddings = self.position_embeddings(position_ids) | |
if self.num_pos_emb > 1: | |
num_batch = position_embeddings.size(0) | |
num_pos = position_embeddings.size(1) | |
position_embeddings = position_embeddings.view( | |
num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] | |
embeddings = words_embeddings + position_embeddings | |
if self.token_type_embeddings is not None: | |
embeddings = embeddings + self.token_type_embeddings(token_type_ids) | |
if self.fp32_embedding: | |
embeddings = embeddings.half() | |
embeddings = self.LayerNorm(embeddings) | |
embeddings = self.dropout(embeddings) | |
return embeddings | |
class LayoutlmEmbeddings(nn.Module): | |
"""Construct the embeddings from word, position and token_type embeddings. | |
""" | |
def __init__(self, config): | |
super(LayoutlmEmbeddings, self).__init__() | |
# self.word_embeddings = nn.Embedding( | |
# config.vocab_size, config.hidden_size) | |
self.only_layout = config.layoutlm_only_layout_flag | |
if not self.only_layout: | |
self.word_embeddings = nn.Embedding( | |
config.vocab_size, config.hidden_size, padding_idx=0 | |
) | |
else: | |
self.word_embeddings = None | |
self.x_position_embeddings = nn.Embedding( | |
config.max_2d_position_embeddings, config.hidden_size | |
) | |
self.y_position_embeddings = nn.Embedding( | |
config.max_2d_position_embeddings, config.hidden_size | |
) | |
self.h_position_embeddings = nn.Embedding( | |
config.max_2d_position_embeddings, config.hidden_size | |
) | |
self.w_position_embeddings = nn.Embedding( | |
config.max_2d_position_embeddings, config.hidden_size | |
) | |
if config.no_segment_embedding: | |
self.token_type_embeddings = None | |
else: | |
self.token_type_embeddings = nn.Embedding( | |
config.type_vocab_size, config.hidden_size) | |
if hasattr(config, 'fp32_embedding'): | |
self.fp32_embedding = config.fp32_embedding | |
else: | |
self.fp32_embedding = False | |
if hasattr(config, 'new_pos_ids') and config.new_pos_ids: | |
self.num_pos_emb = 4 | |
else: | |
self.num_pos_emb = 1 | |
self.position_embeddings = nn.Embedding( | |
config.max_position_embeddings, config.hidden_size * self.num_pos_emb) | |
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |
# any TensorFlow checkpoint file | |
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) | |
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |
def forward(self, input_ids, bbox, token_type_ids=None, position_ids=None, task_idx=None): | |
seq_length = input_ids.size(1) | |
if position_ids is None: | |
position_ids = torch.arange( | |
seq_length, dtype=torch.long, device=input_ids.device) | |
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |
if token_type_ids is None: | |
token_type_ids = torch.zeros_like(input_ids) | |
position_embeddings = self.position_embeddings(position_ids) | |
if self.num_pos_emb > 1: | |
num_batch = position_embeddings.size(0) | |
num_pos = position_embeddings.size(1) | |
position_embeddings = position_embeddings.view( | |
num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] | |
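# bbox is expected to carry one box per token as (x0, y0, x1, y1) in its last dimension,
# with coordinates already discretized to [0, config.max_2d_position_embeddings); the
# width/height embeddings below are looked up from the box extents (x1 - x0, y1 - y0).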
left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) | |
upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) | |
right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) | |
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) | |
h_position_embeddings = self.h_position_embeddings( | |
bbox[:, :, 3] - bbox[:, :, 1] | |
) | |
w_position_embeddings = self.w_position_embeddings( | |
bbox[:, :, 2] - bbox[:, :, 0] | |
) | |
# token_type_embeddings = self.token_type_embeddings(token_type_ids) | |
# words_embeddings = self.word_embeddings(input_ids) | |
# position_embeddings = self.position_embeddings(position_ids) | |
embeddings = ( | |
# words_embeddings | |
position_embeddings | |
+ left_position_embeddings | |
+ upper_position_embeddings | |
+ right_position_embeddings | |
+ lower_position_embeddings | |
+ h_position_embeddings | |
+ w_position_embeddings | |
) | |
if not self.only_layout: | |
words_embeddings = self.word_embeddings(input_ids) | |
embeddings = embeddings + words_embeddings | |
if self.token_type_embeddings is not None: | |
embeddings = embeddings + self.token_type_embeddings(token_type_ids) | |
if self.fp32_embedding: | |
embeddings = embeddings.half() | |
embeddings = self.LayerNorm(embeddings) | |
embeddings = self.dropout(embeddings) | |
return embeddings | |
class BertSelfAttention(nn.Module): | |
def __init__(self, config): | |
super(BertSelfAttention, self).__init__() | |
if config.hidden_size % config.num_attention_heads != 0: | |
raise ValueError( | |
"The hidden size (%d) is not a multiple of the number of attention " | |
"heads (%d)" % (config.hidden_size, config.num_attention_heads)) | |
self.num_attention_heads = config.num_attention_heads | |
self.attention_head_size = int( | |
config.hidden_size / config.num_attention_heads) | |
self.all_head_size = self.num_attention_heads * self.attention_head_size | |
if hasattr(config, 'num_qkv') and (config.num_qkv > 1): | |
self.num_qkv = config.num_qkv | |
else: | |
self.num_qkv = 1 | |
self.query = nn.Linear( | |
config.hidden_size, self.all_head_size * self.num_qkv) | |
self.key = nn.Linear(config.hidden_size, | |
self.all_head_size * self.num_qkv) | |
self.value = nn.Linear( | |
config.hidden_size, self.all_head_size * self.num_qkv) | |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | |
self.uni_debug_flag = bool(os.getenv('UNI_DEBUG_FLAG', ''))
if self.uni_debug_flag: | |
self.register_buffer('debug_attention_probs', | |
torch.zeros((512, 512))) | |
if hasattr(config, 'seg_emb') and config.seg_emb: | |
self.b_q_s = nn.Parameter(torch.zeros( | |
1, self.num_attention_heads, 1, self.attention_head_size)) | |
self.seg_emb = nn.Embedding( | |
config.type_vocab_size, self.all_head_size) | |
else: | |
self.b_q_s = None | |
self.seg_emb = None | |
def transpose_for_scores(self, x, mask_qkv=None): | |
if self.num_qkv > 1: | |
sz = x.size()[:-1] + (self.num_qkv, | |
self.num_attention_heads, self.all_head_size) | |
# (batch, pos, num_qkv, head, head_hid) | |
x = x.view(*sz) | |
if mask_qkv is None: | |
x = x[:, :, 0, :, :] | |
elif isinstance(mask_qkv, int): | |
x = x[:, :, mask_qkv, :, :] | |
else: | |
# mask_qkv: (batch, pos) | |
if mask_qkv.size(1) > sz[1]: | |
mask_qkv = mask_qkv[:, :sz[1]] | |
# -> x: (batch, pos, head, head_hid) | |
x = x.gather(2, mask_qkv.view(sz[0], sz[1], 1, 1, 1).expand( | |
sz[0], sz[1], 1, sz[3], sz[4])).squeeze(2) | |
else: | |
sz = x.size()[:-1] + (self.num_attention_heads, | |
self.attention_head_size) | |
# (batch, pos, head, head_hid) | |
x = x.view(*sz) | |
# (batch, head, pos, head_hid) | |
return x.permute(0, 2, 1, 3) | |
def forward(self, hidden_states, attention_mask, history_states=None, | |
mask_qkv=None, seg_ids=None, key_history=None, value_history=None, | |
key_cache=None, value_cache=None, | |
): | |
if history_states is None: | |
mixed_query_layer = self.query(hidden_states) | |
# possible issue: https://github.com/NVIDIA/apex/issues/131 | |
mixed_key_layer = F.linear(hidden_states, self.key.weight) | |
mixed_value_layer = self.value(hidden_states) | |
else: | |
x_states = torch.cat((history_states, hidden_states), dim=1) | |
mixed_query_layer = self.query(hidden_states) | |
# possible issue: https://github.com/NVIDIA/apex/issues/131 | |
mixed_key_layer = F.linear(x_states, self.key.weight) | |
mixed_value_layer = self.value(x_states) | |
if key_cache is not None and isinstance(key_cache, list): | |
key_cache.append(mixed_key_layer) | |
mixed_key_layer = torch.cat(key_cache, dim=1) | |
if value_cache is not None and isinstance(value_cache, list): | |
value_cache.append(mixed_value_layer) | |
mixed_value_layer = torch.cat(value_cache, dim=1) | |
query_layer = self.transpose_for_scores(mixed_query_layer, mask_qkv) | |
key_layer = self.transpose_for_scores(mixed_key_layer, mask_qkv) | |
value_layer = self.transpose_for_scores(mixed_value_layer, mask_qkv) | |
if key_history is not None and not isinstance(key_history, list): | |
key_layer = torch.cat((key_history, key_layer), dim=-2) | |
value_layer = torch.cat((value_history, value_layer), dim=-2) | |
# Take the dot product between "query" and "key" to get the raw attention scores. | |
# (batch, head, pos, pos) | |
attention_scores = torch.matmul( | |
query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) | |
if self.seg_emb is not None: | |
seg_rep = self.seg_emb(seg_ids) | |
# (batch, pos, head, head_hid) | |
seg_rep = seg_rep.view(seg_rep.size(0), seg_rep.size( | |
1), self.num_attention_heads, self.attention_head_size) | |
qs = torch.einsum('bnih,bjnh->bnij', | |
query_layer + self.b_q_s, seg_rep) | |
attention_scores = attention_scores + qs | |
# attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |
# Apply the attention mask (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask | |
# Normalize the attention scores to probabilities. | |
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |
if self.uni_debug_flag: | |
_pos = attention_probs.size(-1) | |
self.debug_attention_probs[:_pos, :_pos].copy_( | |
attention_probs[0].mean(0).view(_pos, _pos)) | |
# This is actually dropping out entire tokens to attend to, which might | |
# seem a bit unusual, but is taken from the original Transformer paper. | |
attention_probs = self.dropout(attention_probs) | |
context_layer = torch.matmul(attention_probs, value_layer) | |
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |
new_context_layer_shape = context_layer.size()[ | |
:-2] + (self.all_head_size,) | |
context_layer = context_layer.view(*new_context_layer_shape) | |
if isinstance(key_history, list): | |
key_history.append(key_layer) | |
if isinstance(value_history, list): | |
value_history.append(value_layer) | |
return context_layer | |
class BertSelfOutput(nn.Module): | |
def __init__(self, config): | |
super(BertSelfOutput, self).__init__() | |
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) | |
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |
def forward(self, hidden_states, input_tensor): | |
hidden_states = self.dense(hidden_states) | |
hidden_states = self.dropout(hidden_states) | |
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |
return hidden_states | |
class BertAttention(nn.Module): | |
def __init__(self, config): | |
super(BertAttention, self).__init__() | |
self.self = BertSelfAttention(config) | |
self.output = BertSelfOutput(config) | |
def forward(self, input_tensor, attention_mask, history_states=None, | |
mask_qkv=None, seg_ids=None, key_history=None, value_history=None): | |
self_output = self.self( | |
input_tensor, attention_mask, history_states=history_states, | |
mask_qkv=mask_qkv, seg_ids=seg_ids, key_history=key_history, value_history=value_history) | |
attention_output = self.output(self_output, input_tensor) | |
return attention_output | |
class BertIntermediate(nn.Module): | |
def __init__(self, config): | |
super(BertIntermediate, self).__init__() | |
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) | |
self.intermediate_act_fn = ACT2FN[config.hidden_act] \ | |
if isinstance(config.hidden_act, str) else config.hidden_act | |
def forward(self, hidden_states): | |
hidden_states = self.dense(hidden_states) | |
hidden_states = self.intermediate_act_fn(hidden_states) | |
return hidden_states | |
class BertOutput(nn.Module): | |
def __init__(self, config): | |
super(BertOutput, self).__init__() | |
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) | |
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) | |
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |
def forward(self, hidden_states, input_tensor): | |
hidden_states = self.dense(hidden_states) | |
hidden_states = self.dropout(hidden_states) | |
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |
return hidden_states | |
class TransformerFFN(nn.Module): | |
def __init__(self, config): | |
super(TransformerFFN, self).__init__() | |
self.ffn_type = config.ffn_type | |
assert self.ffn_type in (1, 2) | |
if self.ffn_type in (1, 2): | |
self.wx0 = nn.Linear(config.hidden_size, config.hidden_size) | |
if self.ffn_type in (2,): | |
self.wx1 = nn.Linear(config.hidden_size, config.hidden_size) | |
if self.ffn_type in (1, 2): | |
self.output = nn.Linear(config.hidden_size, config.hidden_size) | |
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) | |
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |
def forward(self, x): | |
if self.ffn_type in (1, 2): | |
x0 = self.wx0(x) | |
if self.ffn_type == 1: | |
x1 = x | |
elif self.ffn_type == 2: | |
x1 = self.wx1(x) | |
out = self.output(x0 * x1) | |
out = self.dropout(out) | |
out = self.LayerNorm(out + x) | |
return out | |
class BertLayer(nn.Module): | |
def __init__(self, config): | |
super(BertLayer, self).__init__() | |
self.attention = BertAttention(config) | |
self.ffn_type = config.ffn_type | |
if self.ffn_type: | |
self.ffn = TransformerFFN(config) | |
else: | |
self.intermediate = BertIntermediate(config) | |
self.output = BertOutput(config) | |
def forward(self, hidden_states, attention_mask, history_states=None, | |
mask_qkv=None, seg_ids=None, key_history=None, value_history=None): | |
attention_output = self.attention( | |
hidden_states, attention_mask, history_states=history_states, | |
mask_qkv=mask_qkv, seg_ids=seg_ids, key_history=key_history, value_history=value_history) | |
if self.ffn_type: | |
layer_output = self.ffn(attention_output) | |
else: | |
intermediate_output = self.intermediate(attention_output) | |
layer_output = self.output(intermediate_output, attention_output) | |
return layer_output | |
class BertEncoder(nn.Module): | |
def __init__(self, config): | |
super(BertEncoder, self).__init__() | |
layer = BertLayer(config) | |
self.layer = nn.ModuleList([copy.deepcopy(layer) | |
for _ in range(config.num_hidden_layers)]) | |
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None, | |
prev_encoded_layers=None, mask_qkv=None, seg_ids=None, key_history=None, value_history=None): | |
# history embedding and encoded layers must be given simultaneously
assert (prev_embedding is None) == (prev_encoded_layers is None) | |
all_encoder_layers = [] | |
if (prev_embedding is not None) and (prev_encoded_layers is not None): | |
history_states = prev_embedding | |
for i, layer_module in enumerate(self.layer): | |
hidden_states = layer_module( | |
hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids) | |
if output_all_encoded_layers: | |
all_encoder_layers.append(hidden_states) | |
if prev_encoded_layers is not None: | |
history_states = prev_encoded_layers[i] | |
else: | |
for i, layer_module in enumerate(self.layer): | |
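# While key_history/value_history are still shorter than the number of layers (first
# decoding step), the whole list is handed down so each layer can append its own
# key/value tensors; on later steps the cached tensor for layer i is reused instead.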
set_key = None | |
if isinstance(key_history, list): | |
set_key = key_history if len(key_history) < len(self.layer) else key_history[i] | |
set_value = None | |
if isinstance(value_history, list): | |
set_value = value_history if len(value_history) < len(self.layer) else value_history[i]
hidden_states = layer_module( | |
hidden_states, attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids, | |
key_history=set_key, value_history=set_value) | |
if output_all_encoded_layers: | |
all_encoder_layers.append(hidden_states) | |
if not output_all_encoded_layers: | |
all_encoder_layers.append(hidden_states) | |
return all_encoder_layers | |
class BertPooler(nn.Module): | |
def __init__(self, config): | |
super(BertPooler, self).__init__() | |
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |
self.activation = nn.Tanh() | |
def forward(self, hidden_states): | |
# We "pool" the model by simply taking the hidden state corresponding | |
# to the first token. | |
first_token_tensor = hidden_states[:, 0] | |
pooled_output = self.dense(first_token_tensor) | |
pooled_output = self.activation(pooled_output) | |
return pooled_output | |
class BertPredictionHeadTransform(nn.Module): | |
def __init__(self, config): | |
super(BertPredictionHeadTransform, self).__init__() | |
self.transform_act_fn = ACT2FN[config.hidden_act] \ | |
if isinstance(config.hidden_act, str) else config.hidden_act | |
hid_size = config.hidden_size | |
if hasattr(config, 'relax_projection') and (config.relax_projection > 1): | |
hid_size *= config.relax_projection | |
self.dense = nn.Linear(config.hidden_size, hid_size) | |
self.LayerNorm = BertLayerNorm(hid_size, eps=1e-5) | |
def forward(self, hidden_states): | |
hidden_states = self.dense(hidden_states) | |
hidden_states = self.transform_act_fn(hidden_states) | |
hidden_states = self.LayerNorm(hidden_states) | |
return hidden_states | |
class LayoutlmSPLMPredictionHead(nn.Module): | |
def __init__(self, config, src_len): | |
super(LayoutlmSPLMPredictionHead, self).__init__() | |
self.transform = BertPredictionHeadTransform(config) | |
# The head scores hidden states against the source-token embeddings (pointer-style),
# with an output-only bias for each source position.
self.bias = nn.Parameter(torch.zeros(src_len)) | |
if hasattr(config, 'relax_projection') and (config.relax_projection > 1): | |
self.relax_projection = config.relax_projection | |
else: | |
self.relax_projection = 0 | |
self.fp32_embedding = config.fp32_embedding | |
def convert_to_type(tensor): | |
if self.fp32_embedding: | |
return tensor.half() | |
else: | |
return tensor | |
self.type_converter = convert_to_type | |
self.converted = False | |
def forward(self, hidden_states, src_emb, task_idx=None): | |
if not self.converted: | |
self.converted = True | |
if self.fp32_embedding: | |
self.transform.half() | |
hidden_states = self.transform(self.type_converter(hidden_states)) | |
if self.relax_projection > 1: | |
num_batch = hidden_states.size(0) | |
num_pos = hidden_states.size(1) | |
# (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid) | |
hidden_states = hidden_states.view( | |
num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] | |
if self.fp32_embedding: | |
hidden_states = torch.einsum('btf,bsf->bts', | |
self.type_converter(hidden_states), self.type_converter(src_emb)) + \ | |
self.type_converter(self.bias) | |
# hidden_states = F.linear(self.type_converter(hidden_states), self.type_converter( | |
# self.decoder.weight), self.type_converter(self.bias)) | |
else: | |
hidden_states = torch.einsum('btf,bsf->bts', hidden_states, src_emb) + self.bias | |
return hidden_states | |
class LayoutlmSPPreTrainingHeads(nn.Module): | |
def __init__(self, config, src_len, num_labels=2): | |
super(LayoutlmSPPreTrainingHeads, self).__init__() | |
self.predictions = LayoutlmSPLMPredictionHead(config, src_len) | |
self.seq_relationship = nn.Linear(config.hidden_size, num_labels) | |
def forward(self, sequence_output, pooled_output, src_emb, task_idx=None): | |
prediction_scores = self.predictions(sequence_output, src_emb, task_idx) | |
if pooled_output is None: | |
seq_relationship_score = None | |
else: | |
seq_relationship_score = self.seq_relationship(pooled_output) | |
return prediction_scores, seq_relationship_score | |
class PreTrainedBertModel(nn.Module): | |
""" An abstract class to handle weights initialization and | |
a simple interface for downloading and loading pretrained models.
""" | |
def __init__(self, config, *inputs, **kwargs): | |
super(PreTrainedBertModel, self).__init__() | |
if not isinstance(config, BertConfig): | |
raise ValueError( | |
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. " | |
"To create a model from a Google pretrained model use " | |
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( | |
self.__class__.__name__, self.__class__.__name__ | |
)) | |
self.config = config | |
def init_bert_weights(self, module): | |
""" Initialize the weights. | |
""" | |
if isinstance(module, (nn.Linear, nn.Embedding)): | |
# Slightly different from the TF version which uses truncated_normal for initialization | |
# cf https://github.com/pytorch/pytorch/pull/5617 | |
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) | |
# module.weight.data.copy_(torch.Tensor( | |
# truncnorm.rvs(-1, 1, size=list(module.weight.data.shape)) * self.config.initializer_range)) | |
elif isinstance(module, BertLayerNorm): | |
module.bias.data.zero_() | |
module.weight.data.fill_(1.0) | |
if isinstance(module, nn.Linear) and module.bias is not None: | |
module.bias.data.zero_() | |
@classmethod
def from_pretrained(cls, pretrained_model_name, config, state_dict=None, cache_dir=None, *inputs, **kwargs):
""" | |
Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. | |
Download and cache the pre-trained model file if needed. | |
Params: | |
pretrained_model_name: either: | |
- a str with the name of a pre-trained model to load selected in the list of: | |
. `bert-base-uncased` | |
. `bert-large-uncased` | |
. `bert-base-cased` | |
. `bert-base-multilingual` | |
. `bert-base-chinese` | |
- a path or url to a pretrained model archive containing: | |
. `bert_config.json` a configuration file for the model | |
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance | |
cache_dir: an optional path to a folder in which the pre-trained models will be cached. | |
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class | |
(ex: num_labels for BertForSequenceClassification) | |
""" | |
logger.info("Model config {}".format(config)) | |
# clean the arguments in kwargs | |
for arg_clean in ('config_path', 'type_vocab_size', 'relax_projection', 'new_pos_ids', 'task_idx', | |
'max_position_embeddings', 'fp32_embedding', 'ffn_type', 'label_smoothing', | |
'hidden_dropout_prob', 'attention_probs_dropout_prob', 'num_qkv', 'seg_emb', | |
'word_emb_map', 'num_labels', 'num_rel', 'num_sentlvl_labels'): | |
if arg_clean in kwargs: | |
del kwargs[arg_clean] | |
# Instantiate model. | |
model = cls(config, *inputs, **kwargs) | |
if state_dict is None: | |
weights_path = os.path.join(pretrained_model_name, WEIGHTS_NAME) | |
state_dict = torch.load(weights_path, map_location='cpu') | |
old_keys = [] | |
new_keys = [] | |
for key in state_dict.keys(): | |
new_key = None | |
if 'gamma' in key: | |
new_key = key.replace('gamma', 'weight') | |
if 'beta' in key: | |
new_key = key.replace('beta', 'bias') | |
if new_key: | |
old_keys.append(key) | |
new_keys.append(new_key) | |
for old_key, new_key in zip(old_keys, new_keys): | |
state_dict[new_key] = state_dict.pop(old_key) | |
missing_keys = [] | |
unexpected_keys = [] | |
error_msgs = [] | |
# copy state_dict so _load_from_state_dict can modify it | |
metadata = getattr(state_dict, '_metadata', None) | |
state_dict = state_dict.copy() | |
if metadata is not None: | |
state_dict._metadata = metadata | |
def load(module, prefix=''): | |
local_metadata = {} if metadata is None else metadata.get( | |
prefix[:-1], {}) | |
module._load_from_state_dict( | |
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |
for name, child in module._modules.items(): | |
if child is not None: | |
load(child, prefix + name + '.') | |
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |
model.missing_keys = missing_keys | |
if len(missing_keys) > 0: | |
logger.info("Weights of {} not initialized from pretrained model: {}".format( | |
model.__class__.__name__, missing_keys)) | |
if len(unexpected_keys) > 0: | |
logger.info("Weights from pretrained model not used in {}: {}".format( | |
model.__class__.__name__, unexpected_keys)) | |
if len(error_msgs) > 0: | |
logger.info('\n'.join(error_msgs)) | |
return model | |
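# Illustrative sketch (not part of the original file): loading a fine-tuned checkpoint from a
# local directory; the paths below are placeholders, and the directory is assumed to contain
# a `pytorch_model.bin` file (WEIGHTS_NAME) plus its config JSON.
def _example_from_pretrained():
    config = BertConfig.from_json_file("/path/to/checkpoint/config.json")  # placeholder path
    model = BertModel.from_pretrained("/path/to/checkpoint", config)  # placeholder path
    return model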
class BertModel(PreTrainedBertModel): | |
"""BERT model ("Bidirectional Embedding Representations from a Transformer"). | |
Params: | |
config: a BertConfig class instance with the configuration to build a new model | |
Inputs: | |
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] | |
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`) | |
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token | |
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to | |
a `sentence B` token (see BERT paper for more details). | |
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices | |
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max | |
input sequence length in the current batch. It's the mask that we typically use for attention when | |
a batch has varying length sentences. | |
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. | |
Outputs: Tuple of (encoded_layers, pooled_output) | |
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end | |
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each | |
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], | |
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding | |
to the last attention block of shape [batch_size, sequence_length, hidden_size], | |
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated with the first token of the
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
""" | |
def __init__(self, config): | |
super(BertModel, self).__init__(config) | |
self.embeddings = BertEmbeddings(config) | |
self.encoder = BertEncoder(config) | |
self.pooler = BertPooler(config) | |
self.apply(self.init_bert_weights) | |
def rescale_some_parameters(self): | |
for layer_id, layer in enumerate(self.encoder.layer): | |
layer.attention.output.dense.weight.data.div_( | |
math.sqrt(2.0 * (layer_id + 1))) | |
layer.output.dense.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) | |
def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask): | |
if attention_mask is None: | |
attention_mask = torch.ones_like(input_ids) | |
if token_type_ids is None: | |
token_type_ids = torch.zeros_like(input_ids) | |
# We create a 3D attention mask from a 2D tensor mask. | |
# Sizes are [batch_size, 1, 1, to_seq_length] | |
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |
# this attention mask is simpler than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |
if attention_mask.dim() == 2: | |
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |
elif attention_mask.dim() == 3: | |
extended_attention_mask = attention_mask.unsqueeze(1) | |
else: | |
raise NotImplementedError | |
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |
# masked positions, this operation will create a tensor which is 0.0 for | |
# positions we want to attend and -10000.0 for masked positions. | |
# Since we are adding it to the raw scores before the softmax, this is | |
# effectively the same as removing these entirely. | |
extended_attention_mask = extended_attention_mask.to( | |
dtype=next(self.parameters()).dtype) # fp16 compatibility | |
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |
return extended_attention_mask | |
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, | |
mask_qkv=None, task_idx=None, key_history=None, value_history=None, position_ids=None): | |
extended_attention_mask = self.get_extended_attention_mask( | |
input_ids, token_type_ids, attention_mask) | |
embedding_output = self.embeddings( | |
input_ids, token_type_ids, task_idx=task_idx, position_ids=position_ids) | |
encoded_layers = self.encoder(embedding_output, extended_attention_mask, | |
output_all_encoded_layers=output_all_encoded_layers, | |
mask_qkv=mask_qkv, seg_ids=token_type_ids, | |
key_history=key_history, value_history=value_history) | |
sequence_output = encoded_layers[-1] | |
pooled_output = self.pooler(sequence_output) | |
if not output_all_encoded_layers: | |
encoded_layers = encoded_layers[-1] | |
return encoded_layers, pooled_output | |
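# Illustrative usage sketch (not part of the original file); the vocabulary size and token ids
# below are placeholders chosen only to show the expected output shapes.
def _example_bert_model_usage():
    config = BertConfig(vocab_size_or_config_json_file=30522)
    model = BertModel(config)
    input_ids = torch.LongTensor([[101, 7592, 2088, 102], [101, 2023, 2003, 102]])
    token_type_ids = torch.zeros_like(input_ids)
    attention_mask = torch.ones_like(input_ids)
    last_layer, pooled_output = model(input_ids, token_type_ids, attention_mask,
                                      output_all_encoded_layers=False)
    return last_layer.shape, pooled_output.shape  # (2, 4, 768) and (2, 768)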
class LayoutlmModel(PreTrainedBertModel): | |
"""BERT model ("Bidirectional Embedding Representations from a Transformer"). | |
Params: | |
config: a BertConfig class instance with the configuration to build a new model | |
Inputs: | |
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] | |
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`) | |
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token | |
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to | |
a `sentence B` token (see BERT paper for more details). | |
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices | |
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max | |
input sequence length in the current batch. It's the mask that we typically use for attention when | |
a batch has varying length sentences. | |
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. | |
Outputs: Tuple of (encoded_layers, pooled_output) | |
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end | |
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each | |
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], | |
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding | |
to the last attention block of shape [batch_size, sequence_length, hidden_size], | |
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated with the first token of the
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
""" | |
def __init__(self, config): | |
super(LayoutlmModel, self).__init__(config) | |
self.embeddings = LayoutlmEmbeddings(config) | |
self.encoder = BertEncoder(config) | |
self.pooler = BertPooler(config) | |
self.apply(self.init_bert_weights) | |
def rescale_some_parameters(self): | |
for layer_id, layer in enumerate(self.encoder.layer): | |
layer.attention.output.dense.weight.data.div_( | |
math.sqrt(2.0 * (layer_id + 1))) | |
layer.output.dense.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) | |
def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask): | |
if attention_mask is None: | |
attention_mask = torch.ones_like(input_ids) | |
if token_type_ids is None: | |
token_type_ids = torch.zeros_like(input_ids) | |
# We create a 3D attention mask from a 2D tensor mask. | |
# Sizes are [batch_size, 1, 1, to_seq_length] | |
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |
# this attention mask is simpler than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |
if attention_mask.dim() == 2: | |
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |
elif attention_mask.dim() == 3: | |
extended_attention_mask = attention_mask.unsqueeze(1) | |
else: | |
raise NotImplementedError | |
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |
# masked positions, this operation will create a tensor which is 0.0 for | |
# positions we want to attend and -10000.0 for masked positions. | |
# Since we are adding it to the raw scores before the softmax, this is | |
# effectively the same as removing these entirely. | |
extended_attention_mask = extended_attention_mask.to( | |
dtype=next(self.parameters()).dtype) # fp16 compatibility | |
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |
return extended_attention_mask | |
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, | |
mask_qkv=None, task_idx=None, key_history=None, value_history=None, position_ids=None): | |
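# For the layout variant, the last dimension of input_ids packs the token id together with
# its bounding box: input_ids[:, :, 0] is the word id and input_ids[:, :, 1:] is the
# (x0, y0, x1, y1) box handed to LayoutlmEmbeddings as bbox.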
extended_attention_mask = self.get_extended_attention_mask( | |
input_ids[:, :, 0], token_type_ids, attention_mask) | |
embedding_output = self.embeddings( | |
input_ids[:, :, 0], input_ids[:, :, 1:], token_type_ids, task_idx=task_idx, position_ids=position_ids) | |
encoded_layers = self.encoder(embedding_output, extended_attention_mask, | |
output_all_encoded_layers=output_all_encoded_layers, | |
mask_qkv=mask_qkv, seg_ids=token_type_ids, | |
key_history=key_history, value_history=value_history) | |
sequence_output = encoded_layers[-1] | |
pooled_output = self.pooler(sequence_output) | |
if not output_all_encoded_layers: | |
encoded_layers = encoded_layers[-1] | |
return encoded_layers, pooled_output | |
class BertModelIncr(BertModel): | |
def __init__(self, config): | |
super(BertModelIncr, self).__init__(config) | |
def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, | |
prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, task_idx=None): | |
extended_attention_mask = self.get_extended_attention_mask( | |
input_ids, token_type_ids, attention_mask) | |
embedding_output = self.embeddings( | |
input_ids, token_type_ids, position_ids, task_idx=task_idx) | |
encoded_layers = self.encoder(embedding_output, | |
extended_attention_mask, | |
output_all_encoded_layers=output_all_encoded_layers, | |
prev_embedding=prev_embedding, | |
prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, | |
seg_ids=token_type_ids) | |
sequence_output = encoded_layers[-1] | |
pooled_output = self.pooler(sequence_output) | |
if not output_all_encoded_layers: | |
encoded_layers = encoded_layers[-1] | |
return embedding_output, encoded_layers, pooled_output | |
class LayoutlmModelIncr(LayoutlmModel): | |
def __init__(self, config): | |
super(LayoutlmModelIncr, self).__init__(config) | |
def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, | |
prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, task_idx=None): | |
extended_attention_mask = self.get_extended_attention_mask( | |
input_ids[:, :, 0], token_type_ids, attention_mask) | |
embedding_output = self.embeddings( | |
input_ids[:, :, 0], input_ids[:, :, 1:], token_type_ids, position_ids, task_idx=task_idx) | |
encoded_layers = self.encoder(embedding_output, | |
extended_attention_mask, | |
output_all_encoded_layers=output_all_encoded_layers, | |
prev_embedding=prev_embedding, | |
prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, | |
seg_ids=token_type_ids) | |
sequence_output = encoded_layers[-1] | |
pooled_output = self.pooler(sequence_output) | |
if not output_all_encoded_layers: | |
encoded_layers = encoded_layers[-1] | |
return embedding_output, encoded_layers, pooled_output | |
class LayoutlmForSeq2SeqDecoder(PreTrainedBertModel): | |
"""refer to BertForPreTraining""" | |
def __init__(self, config, mask_word_id=0, num_labels=2, num_rel=0, | |
search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0, | |
forbid_duplicate_ngrams=False, forbid_ignore_set=None, ngram_size=3, min_len=0, mode="s2s", | |
pos_shift=False): | |
super(LayoutlmForSeq2SeqDecoder, self).__init__(config) | |
self.layout_flag = config.base_model_type == 'layoutlm' | |
if config.base_model_type == 'layoutlm': | |
self.bert = LayoutlmModelIncr(config) | |
else: | |
self.bert = BertModelIncr(config) | |
# self.bert = BertModelIncr(config) | |
# note: max_source_length is the maximum source sequence length used during fine-tuning,
# including the CLS and SEP tokens; nothing is removed here and index 0 is reserved for padding
self.cls = LayoutlmSPPreTrainingHeads( | |
config, src_len=config.max_source_length, num_labels=num_labels) | |
self.apply(self.init_bert_weights) | |
self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') | |
self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1) | |
self.mask_word_id = mask_word_id | |
self.num_labels = num_labels | |
self.num_rel = num_rel | |
self.search_beam_size = search_beam_size | |
self.length_penalty = length_penalty | |
self.eos_id = eos_id | |
self.sos_id = sos_id | |
self.forbid_duplicate_ngrams = forbid_duplicate_ngrams | |
self.forbid_ignore_set = forbid_ignore_set | |
self.ngram_size = ngram_size | |
self.min_len = min_len | |
assert mode in ("s2s", "l2r") | |
self.mode = mode | |
self.pos_shift = pos_shift | |
def forward(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): | |
if self.search_beam_size > 1: | |
return self.beam_search(input_ids, token_type_ids, position_ids, attention_mask, task_idx=task_idx, | |
mask_qkv=mask_qkv) | |
input_shape = list(input_ids.size()) | |
batch_size = input_shape[0] | |
input_length = input_shape[1] | |
output_shape = list(token_type_ids.size()) | |
output_length = output_shape[1] | |
output_ids = [] | |
prev_embedding = None | |
prev_encoded_layers = None | |
curr_ids = input_ids | |
if not self.layout_flag: | |
mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) | |
else: | |
mask_ids = input_ids.new_zeros(batch_size, 1, 5) | |
mask_ids[:, :, 0] = self.mask_word_id | |
next_pos = input_length | |
if self.pos_shift: | |
if not self.layout_flag: | |
sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) | |
else: | |
sos_ids = input_ids.new_zeros(batch_size, 1, 5) | |
sos_ids[:, :, 0] = self.sos_id | |
src_embedding = None | |
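# Decoding is pointer-style: the prediction head scores the current hidden state against the
# source-token embeddings (src_embedding), so each predicted id is an index into the source
# sequence and the next input token is gathered from input_ids with that index.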
while next_pos < output_length: | |
curr_length = list(curr_ids.size())[1] | |
if self.pos_shift: | |
if next_pos == input_length: | |
x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) | |
start_pos = 0 | |
else: | |
x_input_ids = curr_ids | |
start_pos = next_pos | |
else: | |
start_pos = next_pos - curr_length | |
# if self.layout_flag: | |
# mask_ids[:, -1, 1:] = curr_ids[:, , 1:] | |
x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) | |
curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] | |
curr_attention_mask = attention_mask[:, | |
start_pos:next_pos + 1, :next_pos + 1] | |
curr_position_ids = position_ids[:, start_pos:next_pos + 1] | |
new_embedding, new_encoded_layers, _ = \ | |
self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, | |
output_all_encoded_layers=True, prev_embedding=prev_embedding, | |
prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) | |
if src_embedding is None: | |
# note: the first-step source embedding contains CLS (first position), SEP (second-to-last)
# and the next-to-predict token (last position); only the next-to-predict token is dropped
# here, and SEP is kept so padding positions still map to a valid index in the loss
src_embedding = new_embedding[:, :-1, :] | |
last_hidden = new_encoded_layers[-1][:, -1:, :] | |
prediction_scores, _ = self.cls(last_hidden, None, src_embedding, task_idx=task_idx) | |
_, max_ids = torch.max(prediction_scores, dim=-1) | |
output_ids.append(max_ids) | |
if self.pos_shift: | |
if prev_embedding is None: | |
prev_embedding = new_embedding | |
else: | |
prev_embedding = torch.cat( | |
(prev_embedding, new_embedding), dim=1) | |
if prev_encoded_layers is None: | |
prev_encoded_layers = [x for x in new_encoded_layers] | |
else: | |
prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( | |
prev_encoded_layers, new_encoded_layers)] | |
else: | |
if prev_embedding is None: | |
prev_embedding = new_embedding[:, :-1, :] | |
else: | |
prev_embedding = torch.cat( | |
(prev_embedding, new_embedding[:, :-1, :]), dim=1) | |
if prev_encoded_layers is None: | |
prev_encoded_layers = [x[:, :-1, :] | |
for x in new_encoded_layers] | |
else: | |
prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) | |
for x in zip(prev_encoded_layers, new_encoded_layers)] | |
if not self.layout_flag: | |
index = max_ids | |
curr_ids = torch.gather(input_ids, 1, index) | |
else: | |
_, _, dim = input_ids.shape | |
index = max_ids.unsqueeze(-1) | |
index = index.expand(index.shape[0], index.shape[1], dim) | |
# index = index.repeat(1, 1, dim) | |
curr_ids = torch.gather(input_ids, 1, index) | |
# if len(input_ids.shape) == 2: | |
# real_input_ids = input_ids[:, 1:] | |
# index = max_ids | |
# curr_ids = torch.gather(real_input_ids, 1, index) | |
# else: | |
# real_input_ids = input_ids[:, 1:, :] | |
# _, _, dim = real_input_ids.shape | |
# index = max_ids.unsqueeze(-1) | |
# index = index.expand(index.shape[0], index.shape[1], dim) | |
# curr_ids = torch.gather(real_input_ids, 1, index) | |
# # note: real input ids only include the ids for real data (remove the cls and sep) | |
# real_input_ids = input_ids[:, 1: -1, :] | |
# | |
# _, _, dim = real_input_ids.shape | |
# index = max_ids.unsqueeze(-1) | |
# index = index.expand(index.shape[0], index.shape[1], dim) | |
# | |
# curr_ids = torch.gather(real_input_ids, 1, index) | |
# curr_ids = real_input_ids[:, max_ids, :] | |
# curr_ids = max_ids | |
next_pos += 1 | |
return torch.cat(output_ids, dim=1) | |
# TODO: do the same with beam search as forward() | |
def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): | |
input_shape = list(input_ids.size()) | |
batch_size = input_shape[0] | |
input_length = input_shape[1] | |
output_shape = list(token_type_ids.size()) | |
output_length = output_shape[1] | |
output_ids = [] | |
prev_embedding = None | |
prev_encoded_layers = None | |
curr_ids = input_ids | |
# mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) | |
if not self.layout_flag: | |
mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) | |
else: | |
mask_ids = input_ids.new_zeros(batch_size, 1, 5) | |
mask_ids[:, :, 0] = self.mask_word_id | |
next_pos = input_length | |
if self.pos_shift: | |
if not self.layout_flag: | |
sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id) | |
else: | |
sos_ids = input_ids.new_zeros(batch_size, 1, 5) | |
sos_ids[:, :, 0] = self.sos_id | |
K = self.search_beam_size | |
total_scores = [] | |
beam_masks = [] | |
step_ids = [] | |
step_back_ptrs = [] | |
partial_seqs = [] | |
forbid_word_mask = None | |
buf_matrix = None | |
src_embedding = None | |
while next_pos < output_length: | |
curr_length = list(curr_ids.size())[1] | |
if self.pos_shift: | |
if next_pos == input_length: | |
x_input_ids = torch.cat((curr_ids, sos_ids), dim=1) | |
start_pos = 0 | |
else: | |
x_input_ids = curr_ids | |
start_pos = next_pos | |
else: | |
start_pos = next_pos - curr_length | |
x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) | |
curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] | |
curr_attention_mask = attention_mask[:, | |
start_pos:next_pos + 1, :next_pos + 1] | |
curr_position_ids = position_ids[:, start_pos:next_pos + 1] | |
new_embedding, new_encoded_layers, _ = \ | |
self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, | |
output_all_encoded_layers=True, prev_embedding=prev_embedding, | |
prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv) | |
def first_expand(x): | |
input_shape = list(x.size()) | |
expanded_shape = input_shape[:1] + [1] + input_shape[1:] | |
x = torch.reshape(x, expanded_shape) | |
repeat_count = [1, K] + [1] * (len(input_shape) - 1) | |
x = x.repeat(*repeat_count) | |
x = torch.reshape(x, [input_shape[0] * K] + input_shape[1:]) | |
return x | |
if src_embedding is None: | |
src_embedding = new_embedding[:, :-1, :] | |
if src_embedding.shape[0] != new_embedding.shape[0]: | |
src_embedding = first_expand(src_embedding) | |
last_hidden = new_encoded_layers[-1][:, -1:, :] | |
prediction_scores, _ = self.cls(last_hidden, None, src_embedding, task_idx=task_idx) | |
log_scores = torch.nn.functional.log_softmax( | |
prediction_scores, dim=-1) | |
# if forbid_word_mask is not None: | |
# log_scores += (forbid_word_mask * -10000.0) | |
# if self.min_len and (next_pos - input_length + 1 <= self.min_len): | |
# log_scores[:, :, self.eos_id].fill_(-10000.0) | |
            kk_scores, kk_ids = torch.topk(log_scores, k=K)
            if len(total_scores) == 0:
                k_ids = torch.reshape(kk_ids, [batch_size, K])
                back_ptrs = torch.zeros(batch_size, K, dtype=torch.long)
                k_scores = torch.reshape(kk_scores, [batch_size, K])
            else:
                last_eos = torch.reshape(
                    beam_masks[-1], [batch_size * K, 1, 1])
                last_seq_scores = torch.reshape(
                    total_scores[-1], [batch_size * K, 1, 1])
                kk_scores += last_eos * (-10000.0) + last_seq_scores
                kk_scores = torch.reshape(kk_scores, [batch_size, K * K])
                k_scores, k_ids = torch.topk(kk_scores, k=K)
                back_ptrs = torch.floor_divide(k_ids, K)
                kk_ids = torch.reshape(kk_ids, [batch_size, K * K])
                k_ids = torch.gather(kk_ids, 1, k_ids)
            step_back_ptrs.append(back_ptrs)
            step_ids.append(k_ids)
            beam_masks.append(torch.eq(k_ids, self.eos_id).type_as(kk_scores))
            total_scores.append(k_scores)
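            # select_beam_items reorders a cached tensor whose leading dimension is
            # batch * K according to the chosen back-pointers, so every cached state
            # keeps following the beam it belongs to.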
            def select_beam_items(x, ids):
                id_shape = list(ids.size())
                id_rank = len(id_shape)
                assert len(id_shape) == 2
                x_shape = list(x.size())
                x = torch.reshape(x, [batch_size, K] + x_shape[1:])
                x_rank = len(x_shape) + 1
                assert x_rank >= 2
                if id_rank < x_rank:
                    ids = torch.reshape(
                        ids, id_shape + [1] * (x_rank - id_rank))
                    ids = ids.expand(id_shape + x_shape[1:])
                y = torch.gather(x, 1, ids)
                y = torch.reshape(y, x_shape)
                return y
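            # Update the cached states.  On the first step the per-example tensors
            # are tiled to batch * K; afterwards the new step is appended and the
            # cache is reordered with the back-pointers.  Without pos_shift the last
            # position is the [MASK] placeholder, which is dropped before caching.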
            is_first = (prev_embedding is None)

            if self.pos_shift:
                if prev_embedding is None:
                    prev_embedding = first_expand(new_embedding)
                else:
                    prev_embedding = torch.cat(
                        (prev_embedding, new_embedding), dim=1)
                    prev_embedding = select_beam_items(
                        prev_embedding, back_ptrs)
                if prev_encoded_layers is None:
                    prev_encoded_layers = [first_expand(
                        x) for x in new_encoded_layers]
                else:
                    prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip(
                        prev_encoded_layers, new_encoded_layers)]
                    prev_encoded_layers = [select_beam_items(
                        x, back_ptrs) for x in prev_encoded_layers]
            else:
                if prev_embedding is None:
                    prev_embedding = first_expand(new_embedding[:, :-1, :])
                else:
                    prev_embedding = torch.cat(
                        (prev_embedding, new_embedding[:, :-1, :]), dim=1)
                    prev_embedding = select_beam_items(
                        prev_embedding, back_ptrs)
                if prev_encoded_layers is None:
                    prev_encoded_layers = [first_expand(
                        x[:, :-1, :]) for x in new_encoded_layers]
                else:
                    prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1)
                                           for x in zip(prev_encoded_layers, new_encoded_layers)]
                    prev_encoded_layers = [select_beam_items(
                        x, back_ptrs) for x in prev_encoded_layers]
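            # The selected ids index into the (beam-expanded) source sequence: the
            # next decoder input is gathered from input_ids at that position, either
            # a single token id or, in the layout case, the full 5-dim row.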
            max_ids = torch.reshape(k_ids, [batch_size * K, 1])
            if len(input_ids.shape) == 2:
                expand_input_ids = first_expand(input_ids)
                index = max_ids
                curr_ids = torch.gather(expand_input_ids, 1, index)
            else:
                expand_input_ids = first_expand(input_ids)
                _, _, dim = expand_input_ids.shape
                index = max_ids.unsqueeze(-1)
                index = index.expand(index.shape[0], index.shape[1], dim)
                curr_ids = torch.gather(expand_input_ids, 1, index)

            if is_first:
                token_type_ids = first_expand(token_type_ids)
                position_ids = first_expand(position_ids)
                attention_mask = first_expand(attention_mask)
                mask_ids = first_expand(mask_ids)
                if mask_qkv is not None:
                    mask_qkv = first_expand(mask_qkv)
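            # Optional n-gram blocking: keep each beam's partial hypothesis as a
            # Python list and collect the ids that would complete an already-seen
            # n-gram.  Note that the resulting forbid_word_mask is built below but
            # never applied in this version, because the lines that would add it to
            # log_scores earlier in the loop are commented out.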
            if self.forbid_duplicate_ngrams:
                wids = step_ids[-1].tolist()
                ptrs = step_back_ptrs[-1].tolist()
                if is_first:
                    partial_seqs = []
                    for b in range(batch_size):
                        for k in range(K):
                            partial_seqs.append([wids[b][k]])
                else:
                    new_partial_seqs = []
                    for b in range(batch_size):
                        for k in range(K):
                            new_partial_seqs.append(
                                partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]])
                    partial_seqs = new_partial_seqs
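                # get_dup_ngram_candidates returns, for one hypothesis, the ids that
                # would turn its last n-1 ids into an n-gram already present in the
                # sequence; ids in self.forbid_ignore_set are exempt.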
                def get_dup_ngram_candidates(seq, n):
                    cands = set()
                    if len(seq) < n:
                        return []
                    tail = seq[-(n - 1):]
                    if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail):
                        return []
                    for i in range(len(seq) - (n - 1)):
                        mismatch = False
                        for j in range(n - 1):
                            if tail[j] != seq[i + j]:
                                mismatch = True
                                break
                        if (not mismatch) and not (
                                self.forbid_ignore_set and (seq[i + n - 1] in self.forbid_ignore_set)):
                            cands.add(seq[i + n - 1])
                    return sorted(cands)
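                # Turn the per-beam candidate lists into a 0/1 mask of shape
                # (batch * K, 1, num_classes); buf_matrix is reused across steps to
                # avoid re-allocating it.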
                if len(partial_seqs[0]) >= self.ngram_size:
                    dup_cands = []
                    for seq in partial_seqs:
                        dup_cands.append(
                            get_dup_ngram_candidates(seq, self.ngram_size))
                    if max(len(x) for x in dup_cands) > 0:
                        if buf_matrix is None:
                            vocab_size = list(log_scores.size())[-1]
                            buf_matrix = np.zeros(
                                (batch_size * K, vocab_size), dtype=float)
                        else:
                            buf_matrix.fill(0)
                        for bk, cands in enumerate(dup_cands):
                            for wid in cands:
                                buf_matrix[bk, wid] = 1.0
                        forbid_word_mask = torch.tensor(
                            buf_matrix, dtype=log_scores.dtype)
                        forbid_word_mask = torch.reshape(
                            forbid_word_mask, [batch_size * K, 1, vocab_size]).to(input_ids.device)
                    else:
                        forbid_word_mask = None
            next_pos += 1
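        # Back-tracking: scores/ids/back-pointers are moved to Python lists, then
        # for each example the best finished hypothesis is located (an eos frame,
        # or the last frame if none finished), its score optionally divided by the
        # GNMT-style length penalty ((5 + len) / 6) ** length_penalty, and the
        # back-pointers are followed to reconstruct the id sequence.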
        # [(batch, beam)]
        total_scores = [x.tolist() for x in total_scores]
        step_ids = [x.tolist() for x in step_ids]
        step_back_ptrs = [x.tolist() for x in step_back_ptrs]

        # back tracking
        traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []}
        for b in range(batch_size):
            # [(beam,)]
            scores = [x[b] for x in total_scores]
            wids_list = [x[b] for x in step_ids]
            ptrs = [x[b] for x in step_back_ptrs]
            traces['scores'].append(scores)
            traces['wids'].append(wids_list)
            traces['ptrs'].append(ptrs)
            # find the first frame in which every beam ends with eos;
            # any frames after that one are invalid
            last_frame_id = len(scores) - 1
            for i, wids in enumerate(wids_list):
                if all(wid == self.eos_id for wid in wids):
                    last_frame_id = i
                    break
            max_score = -math.inf
            frame_id = -1
            pos_in_frame = -1
            for fid in range(last_frame_id + 1):
                for i, wid in enumerate(wids_list[fid]):
                    if wid == self.eos_id or fid == last_frame_id:
                        s = scores[fid][i]
                        if self.length_penalty > 0:
                            s /= math.pow((5 + fid + 1) / 6.0,
                                          self.length_penalty)
                        if s > max_score:
                            max_score = s
                            frame_id = fid
                            pos_in_frame = i
            if frame_id == -1:
                traces['pred_seq'].append([0])
            else:
                seq = [wids_list[frame_id][pos_in_frame]]
                for fid in range(frame_id, 0, -1):
                    pos_in_frame = ptrs[fid][pos_in_frame]
                    seq.append(wids_list[fid - 1][pos_in_frame])
                seq.reverse()
                traces['pred_seq'].append(seq)
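        # _pad_sequence right-pads each example's trace to a fixed max_len, similar
        # in spirit to torch.nn.utils.rnn.pad_sequence but with an explicit target
        # length so every example ends up the same size.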
        def _pad_sequence(sequences, max_len, padding_value=0):
            trailing_dims = sequences[0].size()[1:]
            out_dims = (len(sequences), max_len) + trailing_dims
            out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
            for i, tensor in enumerate(sequences):
                length = tensor.size(0)
                # use index notation to prevent duplicate references to the tensor
                out_tensor[i, :length, ...] = tensor
            return out_tensor
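        # Scores become float tensors and id/pointer traces long tensors, all padded
        # with zeros to output_length and moved to input_ids.device so the replicas
        # can be gathered into one batch.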
        # convert to tensors for DataParallel
        for k in ('pred_seq', 'scores', 'wids', 'ptrs'):
            ts_list = traces[k]
            if not isinstance(ts_list[0], torch.Tensor):
                dt = torch.float if k == 'scores' else torch.long
                ts_list = [torch.tensor(it, dtype=dt) for it in ts_list]
            traces[k] = _pad_sequence(
                ts_list, output_length, padding_value=0).to(input_ids.device)

        return traces