Tzktz's picture
Upload 7664 files
6fc683c verified
# coding=utf-8
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import logging
import math
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.loss import _Loss
class LabelSmoothingLoss(_Loss):
    """Label-smoothed loss computed as KL(q || p).

    q is the smoothed ground-truth distribution: ``1 - label_smoothing`` on the
    gold token, ``label_smoothing / (tgt_vocab_size - 2)`` on every other token,
    and 0 on ``ignore_index``. Positions whose target equals ``ignore_index``
    get an all-zero q and so contribute zero loss. p is the model distribution.

    ``forward`` returns the un-reduced loss per position. The ``reduction``
    argument is only forwarded to ``_Loss`` and is not applied here.
    """

    def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None,
                 reduction='mean'):
        assert 0.0 < label_smoothing <= 1.0
        super(LabelSmoothingLoss, self).__init__(
            size_average=size_average, reduce=reduce, reduction=reduction)
        self.ignore_index = ignore_index
        assert label_smoothing > 0
        assert tgt_vocab_size > 0
        # The smoothing mass is shared by all classes except the gold token and ignore_index.
        fill_value = label_smoothing / (tgt_vocab_size - 2)
        smoothed_prior = torch.full((tgt_vocab_size,), fill_value)
        smoothed_prior[ignore_index] = 0
        # Shape (1, tgt_vocab_size); repeated per position in forward().
        self.register_buffer('one_hot', smoothed_prior.unsqueeze(0))
        self.confidence = 1.0 - label_smoothing
        self.tgt_vocab_size = tgt_vocab_size

    def forward(self, output, target):
        """Return the per-position smoothed loss.

        Args:
            output (FloatTensor): batch_size x num_pos x n_classes. It is passed as
                the ``input`` of ``F.kl_div``, which expects log-probabilities.
            target (LongTensor): batch_size x num_pos gold class indices.

        Returns:
            FloatTensor of shape batch_size x num_pos.
        """
        assert self.tgt_vocab_size == output.size(2)
        batch_size, num_pos = target.shape[:2]
        flat_output = output.view(-1, self.tgt_vocab_size)
        flat_target = target.view(-1)
        smoothed = self.one_hot.repeat(flat_target.size(0), 1)
        # Put the confidence mass on the gold class of each position.
        smoothed.scatter_(1, flat_target.unsqueeze(1), self.confidence)
        # Padding / ignored positions get an all-zero distribution, hence zero loss.
        smoothed.masked_fill_((flat_target == self.ignore_index).unsqueeze(1), 0)
        per_class = F.kl_div(flat_output, smoothed, reduction='none')
        return per_class.view(batch_size, num_pos, -1).sum(2)
logger = logging.getLogger(__name__)

# Google BERT checkpoints mirrored on S3 as ``<name>.tar.gz`` archives.
_BERT_S3_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert/"
_GOOGLE_BERT_MODELS = (
    'bert-base-uncased',
    'bert-large-uncased',
    'bert-base-cased',
    'bert-large-cased',
    'bert-base-multilingual-uncased',
    'bert-base-multilingual-cased',
    'bert-base-chinese',
)

# UniLM checkpoints live in one blob container and share the same read-only SAS query string.
_UNILM_BLOB_PREFIX = "https://conversationhub.blob.core.windows.net/beit-share-public/ckpt/"
_UNILM_BLOB_SAS = ("?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z"
                   "&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D")
# (public name, checkpoint file stem); the ``unilm-*`` names alias the ``unilm1-*`` files.
_UNILM_CHECKPOINTS = (
    ('unilm-base-cased', 'unilm1-base-cased'),
    ('unilm-large-cased', 'unilm1-large-cased'),
    ('unilm1-base-cased', 'unilm1-base-cased'),
    ('unilm1-large-cased', 'unilm1-large-cased'),
    ('unilm1.2-base-uncased', 'unilm1.2-base-uncased'),
)

# Model name -> download URL.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    name: _BERT_S3_PREFIX + name + ".tar.gz" for name in _GOOGLE_BERT_MODELS
}
PRETRAINED_MODEL_ARCHIVE_MAP.update(
    (alias, _UNILM_BLOB_PREFIX + stem + ".bin" + _UNILM_BLOB_SAS)
    for alias, stem in _UNILM_CHECKPOINTS
)

CONFIG_NAME = 'config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
def gelu(x):
    """Exact (erf-based) GELU activation: ``x * Phi(x)``.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    # Standard normal CDF evaluated at x.
    normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * normal_cdf
def swish(x):
    """Swish activation: the input scaled by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
# Activation name (as stored in ``BertConfig.hidden_act``) -> callable.
ACT2FN = dict(gelu=gelu, relu=F.relu, swish=swish)
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.

    Holds the standard BERT hyper-parameters plus the extra options read by the
    other modules in this file. Build one from a vocabulary size (int), from a
    JSON file path (str), or through `from_dict` / `from_json_file`.
    """

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 relax_projection=0,
                 new_pos_ids=False,
                 initializer_range=0.02,
                 task_idx=None,
                 fp32_embedding=False,
                 ffn_type=0,
                 label_smoothing=None,
                 num_qkv=0,
                 seg_emb=False,
                 source_type_id=0,
                 target_type_id=1,
                 no_segment_embedding=False, **kwargs):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel` (int),
                or the path to a JSON config file (str). With a path, every key of the JSON
                object is copied verbatim onto this instance, and the remaining keyword
                arguments are ignored (no defaults are applied).
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`. A value of 0 forces `no_segment_embedding` to True.
            relax_projection, new_pos_ids, task_idx, fp32_embedding, ffn_type,
            label_smoothing, num_qkv, seg_emb, source_type_id, target_type_id:
                Additional model options, stored as-is for the modules that read them.
            initializer_range: The stddev of the normal initializer used for
                initializing all weight matrices.
            no_segment_embedding: If True, the embedding modules build no token-type embedding.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If the first argument is neither an int nor a str.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.relax_projection = relax_projection
            self.new_pos_ids = new_pos_ids
            self.initializer_range = initializer_range
            self.task_idx = task_idx
            self.fp32_embedding = fp32_embedding
            self.ffn_type = ffn_type
            self.label_smoothing = label_smoothing
            self.num_qkv = num_qkv
            self.seg_emb = seg_emb
            self.no_segment_embedding = no_segment_embedding
            self.source_type_id = source_type_id
            self.target_type_id = target_type_id
            # Without any token types there is nothing for a segment embedding to encode.
            if type_vocab_size == 0:
                self.no_segment_embedding = True
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config of this class from a Python dictionary of parameters.

        Starts from the keyword defaults of ``__init__`` (vocab size -1), then
        overwrites or adds every key of ``json_object`` as an attribute.
        """
        # Use ``cls`` so subclasses (and from_json_file on them) return their own type.
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file, **kwargs):
        """Constructs a config from a JSON file; ``kwargs`` override keys from the file."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        json_info = json.loads(text)
        for k, v in kwargs.items():
            json_info[k] = v
        return cls.from_dict(json_info)

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, 2-space indent, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")

    class BertLayerNorm(nn.Module):
        """Pure-PyTorch layer normalization used when apex is not installed.

        Normalizes over the last dimension (epsilon added inside the square
        root), then applies a learned elementwise scale and shift.
        """

        def __init__(self, hidden_size, eps=1e-5):
            """Construct a layernorm module in the TF style (epsilon inside the square root)."""
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            mean = x.mean(-1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(-1, keepdim=True)
            normalized = centered / torch.sqrt(variance + self.variance_epsilon)
            return self.weight * normalized + self.bias
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The sum is followed by LayerNorm and dropout. Optional config flags:
    ``no_segment_embedding`` (omit token-type embeddings), ``fp32_embedding``
    (cast the summed embeddings to half before LayerNorm) and ``new_pos_ids``
    (four position tables, one selected per example by ``task_idx``).
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size)
        # Optional flags are read with getattr so configs that lack them (e.g. a plain
        # BERT config.json loaded by path through BertConfig(str)) still work.
        if getattr(config, 'no_segment_embedding', False):
            self.token_type_embeddings = None
        else:
            self.token_type_embeddings = nn.Embedding(
                config.type_vocab_size, config.hidden_size)
        self.fp32_embedding = getattr(config, 'fp32_embedding', False)
        self.num_pos_emb = 4 if getattr(config, 'new_pos_ids', False) else 1
        # With num_pos_emb > 1 each row packs num_pos_emb tables side by side.
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size * self.num_pos_emb)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None):
        """Embed ``input_ids`` (batch, seq).

        Defaults: position ids ``0..seq-1`` and all-zero token types.
        ``task_idx`` is only used when ``num_pos_emb > 1``, where it indexes the
        position table per example (presumably a (batch,) LongTensor; confirm with callers).
        """
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        if self.num_pos_emb > 1:
            num_batch = position_embeddings.size(0)
            num_pos = position_embeddings.size(1)
            # (batch, pos, num_pos_emb * hid) -> (batch, pos, num_pos_emb, hid) -> per-example table.
            position_embeddings = position_embeddings.view(
                num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
        embeddings = words_embeddings + position_embeddings
        if self.token_type_embeddings is not None:
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)
        if self.fp32_embedding:
            embeddings = embeddings.half()
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class LayoutlmEmbeddings(nn.Module):
    """Construct LayoutLM input embeddings.

    Sum of 1-D position embeddings, 2-D layout embeddings looked up from
    ``bbox`` (left/upper/right/lower coordinates plus box height and width),
    word embeddings (skipped when ``config.layoutlm_only_layout_flag`` is set)
    and optional token-type embeddings, followed by LayerNorm and dropout.
    """

    def __init__(self, config):
        super(LayoutlmEmbeddings, self).__init__()
        self.only_layout = config.layoutlm_only_layout_flag
        if not self.only_layout:
            self.word_embeddings = nn.Embedding(
                config.vocab_size, config.hidden_size, padding_idx=0
            )
        else:
            self.word_embeddings = None
        self.x_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.y_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.h_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        self.w_position_embeddings = nn.Embedding(
            config.max_2d_position_embeddings, config.hidden_size
        )
        # Optional flags are read with getattr so configs that lack them still work.
        if getattr(config, 'no_segment_embedding', False):
            self.token_type_embeddings = None
        else:
            self.token_type_embeddings = nn.Embedding(
                config.type_vocab_size, config.hidden_size)
        self.fp32_embedding = getattr(config, 'fp32_embedding', False)
        self.num_pos_emb = 4 if getattr(config, 'new_pos_ids', False) else 1
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size * self.num_pos_emb)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, bbox, token_type_ids=None, position_ids=None, task_idx=None):
        """Embed a batch of tokens with their layout boxes.

        Args:
            input_ids: (batch, seq) token ids. They also fix the shape of the default
                position and token-type ids.
            bbox: integer tensor read as ``bbox[:, :, 0..3]`` = (left, upper, right, lower).
                Coordinates and the derived height/width must be valid indices into the
                ``max_2d_position_embeddings`` tables.
            token_type_ids / position_ids: optional; default to zeros and ``0..seq-1``.
            task_idx: per-example position-table index, used only when ``num_pos_emb > 1``.
        """
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        if self.num_pos_emb > 1:
            num_batch = position_embeddings.size(0)
            num_pos = position_embeddings.size(1)
            position_embeddings = position_embeddings.view(
                num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
        left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
        upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
        right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
        lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
        h_position_embeddings = self.h_position_embeddings(
            bbox[:, :, 3] - bbox[:, :, 1]
        )
        w_position_embeddings = self.w_position_embeddings(
            bbox[:, :, 2] - bbox[:, :, 0]
        )
        embeddings = (
            position_embeddings
            + left_position_embeddings
            + upper_position_embeddings
            + right_position_embeddings
            + lower_position_embeddings
            + h_position_embeddings
            + w_position_embeddings
        )
        if not self.only_layout:
            words_embeddings = self.word_embeddings(input_ids)
            embeddings = embeddings + words_embeddings
        if self.token_type_embeddings is not None:
            embeddings = embeddings + self.token_type_embeddings(token_type_ids)
        if self.fp32_embedding:
            embeddings = embeddings.half()
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Extensions over vanilla BERT:
      * ``config.num_qkv > 1``: several query/key/value projection sets, one of
        which is chosen per position by ``mask_qkv``;
      * ``config.seg_emb``: an additional segment-aware query/segment score term;
      * ``history_states`` prepended to the keys/values;
      * key/value caches and histories for incremental decoding.
    Setting the ``UNI_DEBUG_FLAG`` environment variable stores head-averaged
    attention probabilities of the first example in ``debug_attention_probs``.
    """

    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        if hasattr(config, 'num_qkv') and (config.num_qkv > 1):
            self.num_qkv = config.num_qkv
        else:
            self.num_qkv = 1
        self.query = nn.Linear(
            config.hidden_size, self.all_head_size * self.num_qkv)
        self.key = nn.Linear(config.hidden_size,
                             self.all_head_size * self.num_qkv)
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size * self.num_qkv)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.uni_debug_flag = True if os.getenv(
            'UNI_DEBUG_FLAG', '') else False
        if self.uni_debug_flag:
            self.register_buffer('debug_attention_probs',
                                 torch.zeros((512, 512)))
        if hasattr(config, 'seg_emb') and config.seg_emb:
            self.b_q_s = nn.Parameter(torch.zeros(
                1, self.num_attention_heads, 1, self.attention_head_size))
            self.seg_emb = nn.Embedding(
                config.type_vocab_size, self.all_head_size)
        else:
            self.b_q_s = None
            self.seg_emb = None

    def transpose_for_scores(self, x, mask_qkv=None):
        """Split heads: (batch, pos, num_qkv * all_head) -> (batch, head, pos, head_hid).

        With ``num_qkv > 1`` one projection set is kept. The first set is kept when
        ``mask_qkv`` is None; a fixed set when it is an int; and a per-position set
        when it is a (batch, pos) LongTensor, truncated to ``pos`` if longer.
        """
        if self.num_qkv > 1:
            # (batch, pos, num_qkv, head, head_hid). The last dim is the per-head size:
            # using all_head_size here made the view fail for any model with more than one head.
            sz = x.size()[:-1] + (self.num_qkv,
                                  self.num_attention_heads, self.attention_head_size)
            x = x.view(*sz)
            if mask_qkv is None:
                x = x[:, :, 0, :, :]
            elif isinstance(mask_qkv, int):
                x = x[:, :, mask_qkv, :, :]
            else:
                # mask_qkv: (batch, pos)
                if mask_qkv.size(1) > sz[1]:
                    mask_qkv = mask_qkv[:, :sz[1]]
                # -> x: (batch, pos, head, head_hid)
                x = x.gather(2, mask_qkv.view(sz[0], sz[1], 1, 1, 1).expand(
                    sz[0], sz[1], 1, sz[3], sz[4])).squeeze(2)
        else:
            sz = x.size()[:-1] + (self.num_attention_heads,
                                  self.attention_head_size)
            # (batch, pos, head, head_hid)
            x = x.view(*sz)
        # (batch, head, pos, head_hid)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, history_states=None,
                mask_qkv=None, seg_ids=None, key_history=None, value_history=None,
                key_cache=None, value_cache=None,
                ):
        """Compute the attention context for ``hidden_states`` (batch, pos, hidden).

        Args:
            attention_mask: additive mask, added to the (batch, head, q_pos, k_pos) scores.
            history_states: optional states concatenated before ``hidden_states`` (dim=1)
                to form the keys/values; queries come from ``hidden_states`` only.
            mask_qkv: projection-set selector, see ``transpose_for_scores``.
            seg_ids: segment ids; used (and required) only when ``seg_emb`` is enabled.
            key_history / value_history: a tensor is prepended to this call's
                keys/values (dim=-2); a list instead receives this call's keys/values.
            key_cache / value_cache: lists of earlier projected keys/values. The new
                projection is appended and the concatenation (dim=1) is used.

        Returns:
            (batch, q_pos, all_head_size) context tensor.
        """
        if history_states is None:
            mixed_query_layer = self.query(hidden_states)
            # Key bias is omitted (possible issue: https://github.com/NVIDIA/apex/issues/131).
            # It would only add q.b_k, the same for every key of a query, which softmax cancels.
            mixed_key_layer = F.linear(hidden_states, self.key.weight)
            mixed_value_layer = self.value(hidden_states)
        else:
            x_states = torch.cat((history_states, hidden_states), dim=1)
            mixed_query_layer = self.query(hidden_states)
            # possible issue: https://github.com/NVIDIA/apex/issues/131
            mixed_key_layer = F.linear(x_states, self.key.weight)
            mixed_value_layer = self.value(x_states)
        if key_cache is not None and isinstance(key_cache, list):
            key_cache.append(mixed_key_layer)
            mixed_key_layer = torch.cat(key_cache, dim=1)
        if value_cache is not None and isinstance(value_cache, list):
            value_cache.append(mixed_value_layer)
            mixed_value_layer = torch.cat(value_cache, dim=1)
        query_layer = self.transpose_for_scores(mixed_query_layer, mask_qkv)
        key_layer = self.transpose_for_scores(mixed_key_layer, mask_qkv)
        value_layer = self.transpose_for_scores(mixed_value_layer, mask_qkv)
        if key_history is not None and not isinstance(key_history, list):
            key_layer = torch.cat((key_history, key_layer), dim=-2)
            value_layer = torch.cat((value_history, value_layer), dim=-2)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch, head, pos, pos); the query is pre-scaled by 1/sqrt(head_size).
        attention_scores = torch.matmul(
            query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
        if self.seg_emb is not None:
            seg_rep = self.seg_emb(seg_ids)
            # (batch, pos, head, head_hid)
            seg_rep = seg_rep.view(seg_rep.size(0), seg_rep.size(
                1), self.num_attention_heads, self.attention_head_size)
            qs = torch.einsum('bnih,bjnh->bnij',
                              query_layer + self.b_q_s, seg_rep)
            attention_scores = attention_scores + qs
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = F.softmax(attention_scores, dim=-1)
        if self.uni_debug_flag:
            _pos = attention_probs.size(-1)
            self.debug_attention_probs[:_pos, :_pos].copy_(
                attention_probs[0].mean(0).view(_pos, _pos))
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        if isinstance(key_history, list):
            key_history.append(key_layer)
        if isinstance(value_history, list):
            value_history.append(value_layer)
        return context_layer
class BertSelfOutput(nn.Module):
    """Post-attention block.

    Applies a dense projection, dropout, a residual connection to the
    attention input, and LayerNorm.
    """

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=1e-5)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        residual_sum = projected + input_tensor
        return self.LayerNorm(residual_sum)
class BertAttention(nn.Module):
    """Attention sub-layer: ``BertSelfAttention`` followed by ``BertSelfOutput``."""

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask, history_states=None,
                mask_qkv=None, seg_ids=None, key_history=None, value_history=None):
        """Attend over ``input_tensor`` and apply the residual output block.

        All optional arguments are forwarded unchanged to ``BertSelfAttention``.
        """
        context = self.self(input_tensor,
                            attention_mask,
                            history_states=history_states,
                            mask_qkv=mask_qkv,
                            seg_ids=seg_ids,
                            key_history=key_history,
                            value_history=value_history)
        return self.output(context, input_tensor)
class BertIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size, then the activation."""

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string names an entry of ACT2FN; anything else is used as the callable itself.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Feed-forward output block.

    Projects intermediate_size back to hidden_size, then applies dropout, a
    residual connection and LayerNorm.
    """

    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        residual_sum = projected + input_tensor
        return self.LayerNorm(residual_sum)
class TransformerFFN(nn.Module):
    """Multiplicative feed-forward block used when ``config.ffn_type`` is set.

    Computes ``out = LayerNorm(dropout(output(wx0(x) * g)) + x)``, where
    ``g = x`` for ffn_type 1 and ``g = wx1(x)`` for ffn_type 2.
    """

    def __init__(self, config):
        super(TransformerFFN, self).__init__()
        self.ffn_type = config.ffn_type
        assert self.ffn_type in (1, 2)
        width = config.hidden_size
        self.wx0 = nn.Linear(width, width)
        if self.ffn_type == 2:
            # Second projection only exists for the type-2 gate.
            self.wx1 = nn.Linear(width, width)
        self.output = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=1e-5)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        first = self.wx0(x)
        second = self.wx1(x) if self.ffn_type == 2 else x
        out = self.dropout(self.output(first * second))
        return self.LayerNorm(out + x)
class BertLayer(nn.Module):
    """One Transformer block: attention followed by a feed-forward sub-layer.

    The feed-forward part is ``TransformerFFN`` when ``config.ffn_type`` is
    truthy, otherwise the standard ``BertIntermediate`` + ``BertOutput`` pair.
    """

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.ffn_type = config.ffn_type
        if not self.ffn_type:
            self.intermediate = BertIntermediate(config)
            self.output = BertOutput(config)
        else:
            self.ffn = TransformerFFN(config)

    def forward(self, hidden_states, attention_mask, history_states=None,
                mask_qkv=None, seg_ids=None, key_history=None, value_history=None):
        attended = self.attention(hidden_states, attention_mask,
                                  history_states=history_states,
                                  mask_qkv=mask_qkv,
                                  seg_ids=seg_ids,
                                  key_history=key_history,
                                  value_history=value_history)
        if self.ffn_type:
            return self.ffn(attended)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules."""

    def __init__(self, config):
        super(BertEncoder, self).__init__()
        layer = BertLayer(config)
        # Deep copies of one layer, so every layer starts from the same initial weights.
        self.layer = nn.ModuleList([copy.deepcopy(layer)
                                    for _ in range(config.num_hidden_layers)])

    def _history_for_layer(self, history, i):
        """Pick the key/value history argument for layer ``i``.

        Non-list values (including None) yield None. A list shorter than the
        layer stack is still being filled, so it is passed through for the layer
        to append to. A full list yields that layer's entry ``history[i]``.
        """
        if not isinstance(history, list):
            return None
        return history if len(history) < len(self.layer) else history[i]

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None,
                prev_encoded_layers=None, mask_qkv=None, seg_ids=None, key_history=None, value_history=None):
        """Run the layer stack.

        If ``prev_embedding`` and ``prev_encoded_layers`` are given (they must come
        together), layer 0 attends additionally over ``prev_embedding`` and layer
        ``i + 1`` over ``prev_encoded_layers[i]``. Otherwise, key/value histories
        (see ``_history_for_layer``) are forwarded to each layer.

        Returns:
            list of hidden states: one per layer if ``output_all_encoded_layers``,
            else a single-element list holding the last layer's output.
        """
        # history embedding and encoded layer must be simultaneously given
        assert (prev_embedding is None) == (prev_encoded_layers is None)
        all_encoder_layers = []
        if (prev_embedding is not None) and (prev_encoded_layers is not None):
            history_states = prev_embedding
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(
                    hidden_states, attention_mask, history_states=history_states, mask_qkv=mask_qkv, seg_ids=seg_ids)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
                if prev_encoded_layers is not None:
                    history_states = prev_encoded_layers[i]
        else:
            for i, layer_module in enumerate(self.layer):
                # Each history is checked against its own length. Previously value_history
                # was sized with len(key_history), which raised when only value_history was a list.
                set_key = self._history_for_layer(key_history, i)
                set_value = self._history_for_layer(value_history, i)
                hidden_states = layer_module(
                    hidden_states, attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids,
                    key_history=set_key, value_history=set_value)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
class BertPooler(nn.Module):
    """Pool a sequence by projecting its first token's state and applying tanh."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only the state at position 0 is used.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm applied before the output projection.

    When ``config.relax_projection > 1`` the output width is
    ``hidden_size * relax_projection`` instead of ``hidden_size``.
    """

    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        relax = getattr(config, 'relax_projection', 0)
        hid_size = config.hidden_size * relax if relax > 1 else config.hidden_size
        self.dense = nn.Linear(config.hidden_size, hid_size)
        self.LayerNorm = BertLayerNorm(hid_size, eps=1e-5)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
class LayoutlmSPLMPredictionHead(nn.Module):
    """Score each target position against every source embedding.

    The transformed states are dotted with ``src_emb``
    (``einsum('btf,bsf->bts')``), and a learned bias of length ``src_len`` is
    added along the source axis. With ``config.fp32_embedding`` the transform
    and the operands are cast to half precision.
    """

    def __init__(self, config, src_len):
        super(LayoutlmSPLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
        # Output-only bias, one entry per source position.
        self.bias = nn.Parameter(torch.zeros(src_len))
        relax = getattr(config, 'relax_projection', 0)
        self.relax_projection = relax if relax > 1 else 0
        self.fp32_embedding = config.fp32_embedding
        # Casts to half when fp32_embedding is set; identity otherwise.
        self.type_converter = lambda tensor: tensor.half() if self.fp32_embedding else tensor
        # Whether the one-time cast of ``transform`` has been done.
        self.converted = False

    def forward(self, hidden_states, src_emb, task_idx=None):
        """Return (batch, tgt_pos, src_pos) scores.

        ``task_idx`` selects one relaxed projection per example when
        ``relax_projection > 1``.
        """
        if not self.converted:
            self.converted = True
            if self.fp32_embedding:
                self.transform.half()
        to_type = self.type_converter
        hidden_states = self.transform(to_type(hidden_states))
        if self.relax_projection > 1:
            num_batch, num_pos = hidden_states.size(0), hidden_states.size(1)
            # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid)
            hidden_states = hidden_states.view(
                num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
        # to_type is the identity unless fp32_embedding is set, so a single expression covers both cases.
        scores = torch.einsum('btf,bsf->bts', to_type(hidden_states), to_type(src_emb))
        return scores + to_type(self.bias)
class LayoutlmSPPreTrainingHeads(nn.Module):
    """Source-pointer prediction head plus an optional sequence-level classifier."""

    def __init__(self, config, src_len, num_labels=2):
        super(LayoutlmSPPreTrainingHeads, self).__init__()
        self.predictions = LayoutlmSPLMPredictionHead(config, src_len)
        self.seq_relationship = nn.Linear(config.hidden_size, num_labels)

    def forward(self, sequence_output, pooled_output, src_emb, task_idx=None):
        """Return ``(prediction_scores, seq_relationship_score)``.

        The second element is None when ``pooled_output`` is None.
        """
        scores = self.predictions(sequence_output, src_emb, task_idx)
        relationship = None if pooled_output is None else self.seq_relationship(pooled_output)
        return scores, relationship
class PreTrainedBertModel(nn.Module):
    """Abstract base class that handles weights initialization and
    loading of pretrained weights for the models in this file.
    """

    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedBertModel, self).__init__()
        if not isinstance(config, BertConfig):
            class_name = self.__class__.__name__
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    class_name, class_name
                ))
        self.config = config

    def init_bert_weights(self, module):
        """Initialize the weights of a single submodule.

        Linear/Embedding weights are drawn from N(0, config.initializer_range);
        BertLayerNorm gets weight 1 and bias 0; Linear biases are zeroed.
        (Presumably applied by subclasses via ``self.apply``; not shown here.)
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name, config, state_dict=None, cache_dir=None, *inputs, **kwargs):
        """Instantiate ``cls(config, *inputs, **kwargs)`` and load pretrained weights into it.

        Params:
            pretrained_model_name: a local directory. When ``state_dict`` is None, the
                weights are read from ``<pretrained_model_name>/pytorch_model.bin``
                (WEIGHTS_NAME). Model names are NOT resolved through
                PRETRAINED_MODEL_ARCHIVE_MAP here, and nothing is downloaded.
            config: the BertConfig used to build the model.
            state_dict: an optional state dictionary to load instead of reading the file.
                TF-style LayerNorm names are renamed in place (``gamma`` -> ``weight``,
                ``beta`` -> ``bias``).
            cache_dir: accepted for compatibility; unused.
            *inputs, **kwargs: extra constructor arguments for ``cls``. Config-level
                options listed below are dropped from ``kwargs`` first.

        Returns:
            The model, with ``missing_keys`` set to the parameters not found in the
            checkpoint. Missing/unexpected keys and load errors are logged at INFO level.
        """
        logger.info("Model config {}".format(config))
        # clean the arguments in kwargs
        for arg_clean in ('config_path', 'type_vocab_size', 'relax_projection', 'new_pos_ids', 'task_idx',
                          'max_position_embeddings', 'fp32_embedding', 'ffn_type', 'label_smoothing',
                          'hidden_dropout_prob', 'attention_probs_dropout_prob', 'num_qkv', 'seg_emb',
                          'word_emb_map', 'num_labels', 'num_rel', 'num_sentlvl_labels'):
            kwargs.pop(arg_clean, None)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
            state_dict = torch.load(os.path.join(pretrained_model_name, WEIGHTS_NAME), map_location='cpu')
        # Rename TF-style LayerNorm parameters. 'beta' is checked first so that a key containing
        # both names ends up with the 'beta' replacement, matching the previous precedence.
        renames = {}
        for old_key in state_dict.keys():
            if 'beta' in old_key:
                renames[old_key] = old_key.replace('beta', 'bias')
            elif 'gamma' in old_key:
                renames[old_key] = old_key.replace('gamma', 'weight')
        for old_key, new_key in renames.items():
            state_dict[new_key] = state_dict.pop(old_key)
        missing_keys, unexpected_keys, error_msgs = [], [], []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Models without a ``bert`` submodule are matched against the 'bert.'-prefixed checkpoint keys.
        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
        model.missing_keys = missing_keys
        if missing_keys:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if unexpected_keys:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if error_msgs:
            logger.info('\n'.join(error_msgs))
        return model
class BertModel(PreTrainedBertModel):
    """BERT model ("Bidirectional Encoder Representations from Transformers").

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] (or
            [batch_size, from_seq_length, to_seq_length]) with indices selected in [0, 1]. It's a mask to be
            used if the input sequence length is smaller than the max input sequence length in the current
            batch. It's the mask that we typically use for attention when a batch has varying length
            sentences. When omitted, every position is attended.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as
            described below. Default: `True`.
        `mask_qkv`, `task_idx`, `key_history`, `value_history`, `position_ids`: optional extras passed
            through unchanged to the embeddings and/or the encoder.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states
              at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large),
              each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
              to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first token of the
            input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
    """

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def rescale_some_parameters(self):
        """Divide each encoder layer's `attention.output.dense` and `output.dense` weights,
        in place, by sqrt(2 * layer_number), where layer_number is 1-based."""
        for layer_id, layer in enumerate(self.encoder.layer):
            scale = math.sqrt(2.0 * (layer_id + 1))
            layer.attention.output.dense.weight.data.div_(scale)
            layer.output.dense.weight.data.div_(scale)

    def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask):
        """Convert a 0/1 attention mask into an additive mask for the attention scores.

        Args:
            input_ids: token ids; only used to build an all-ones mask when
                `attention_mask` is None.
            token_type_ids: unused; kept so the signature stays compatible with callers.
            attention_mask: None, a 2D [batch, to_seq] mask or a 3D
                [batch, from_seq, to_seq] mask with 1 for attended positions.

        Returns:
            A tensor with the dtype of the model parameters, broadcastable to
            [batch_size, num_heads, from_seq_length, to_seq_length]: 0.0 where attention
            is allowed and -10000.0 where it is masked.

        Raises:
            NotImplementedError: if `attention_mask` is neither 2D nor 3D.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        if attention_mask.dim() == 2:
            # [batch, to_seq] -> [batch, 1, 1, to_seq]
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            # [batch, from_seq, to_seq] -> [batch, 1, from_seq, to_seq]
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError(
                "attention_mask must be 2D or 3D, got {}D".format(attention_mask.dim()))
        # Cast to the parameter dtype for fp16 compatibility.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)
        # 1 -> 0.0 (keep), 0 -> -10000.0. Adding this to the raw scores before the softmax
        # effectively removes the masked positions.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
                mask_qkv=None, task_idx=None, key_history=None, value_history=None, position_ids=None):
        """Run embeddings, encoder and pooler; see the class docstring for inputs/outputs."""
        extended_attention_mask = self.get_extended_attention_mask(
            input_ids, token_type_ids, attention_mask)
        embedding_output = self.embeddings(
            input_ids, token_type_ids, task_idx=task_idx, position_ids=position_ids)
        encoded_layers = self.encoder(embedding_output, extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers,
                                      mask_qkv=mask_qkv, seg_ids=token_type_ids,
                                      key_history=key_history, value_history=value_history)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
class LayoutlmModel(PreTrainedBertModel):
    """BERT model whose embeddings also consume per-token layout features (LayoutLM style).

    Same structure as BertModel, except that it uses `LayoutlmEmbeddings` and that
    `input_ids` carries additional per-token features after the token id.

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a tensor of shape [batch_size, sequence_length, 1 + num_layout_features].
            `input_ids[:, :, 0]` holds the word token indices in the vocabulary; `input_ids[:, :, 1:]`
            is passed to the embeddings as layout features (presumably bounding-box coordinates --
            confirm against LayoutlmEmbeddings).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] (or
            [batch_size, from_seq_length, to_seq_length]) with indices selected in [0, 1]. It's a mask to be
            used if the input sequence length is smaller than the max input sequence length in the current
            batch. When omitted, every position is attended.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as
            described below. Default: `True`.
        `mask_qkv`, `task_idx`, `key_history`, `value_history`, `position_ids`: optional extras passed
            through unchanged to the embeddings and/or the encoder.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states
              at the end of each attention block, each a torch.FloatTensor of size
              [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
              to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] produced by the pooler from
            the last layer's hidden states.
    """

    def __init__(self, config):
        super(LayoutlmModel, self).__init__(config)
        self.embeddings = LayoutlmEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def rescale_some_parameters(self):
        """Divide each encoder layer's `attention.output.dense` and `output.dense` weights,
        in place, by sqrt(2 * layer_number), where layer_number is 1-based."""
        for layer_id, layer in enumerate(self.encoder.layer):
            scale = math.sqrt(2.0 * (layer_id + 1))
            layer.attention.output.dense.weight.data.div_(scale)
            layer.output.dense.weight.data.div_(scale)

    def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask):
        """Convert a 0/1 attention mask into an additive mask for the attention scores.

        Args:
            input_ids: 2D token ids; only used to build an all-ones mask when
                `attention_mask` is None.
            token_type_ids: unused; kept so the signature stays compatible with callers.
            attention_mask: None, a 2D [batch, to_seq] mask or a 3D
                [batch, from_seq, to_seq] mask with 1 for attended positions.

        Returns:
            A tensor with the dtype of the model parameters, broadcastable to
            [batch_size, num_heads, from_seq_length, to_seq_length]: 0.0 where attention
            is allowed and -10000.0 where it is masked.

        Raises:
            NotImplementedError: if `attention_mask` is neither 2D nor 3D.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if attention_mask.dim() == 2:
            # [batch, to_seq] -> [batch, 1, 1, to_seq]
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            # [batch, from_seq, to_seq] -> [batch, 1, from_seq, to_seq]
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError(
                "attention_mask must be 2D or 3D, got {}D".format(attention_mask.dim()))
        # Cast to the parameter dtype for fp16 compatibility.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)
        # 1 -> 0.0 (keep), 0 -> -10000.0. Adding this to the raw scores before the softmax
        # effectively removes the masked positions.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
                mask_qkv=None, task_idx=None, key_history=None, value_history=None, position_ids=None):
        """Run embeddings, encoder and pooler; see the class docstring for inputs/outputs."""
        token_ids = input_ids[:, :, 0]
        layout_features = input_ids[:, :, 1:]
        extended_attention_mask = self.get_extended_attention_mask(
            token_ids, token_type_ids, attention_mask)
        embedding_output = self.embeddings(
            token_ids, layout_features, token_type_ids, task_idx=task_idx, position_ids=position_ids)
        encoded_layers = self.encoder(embedding_output, extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers,
                                      mask_qkv=mask_qkv, seg_ids=token_type_ids,
                                      key_history=key_history, value_history=value_history)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
class BertModelIncr(BertModel):
    """Incremental-decoding variant of BertModel.

    Compared with BertModel.forward, `forward` here takes explicit position ids plus cached
    `prev_embedding` / `prev_encoded_layers` (handed to the encoder), and additionally
    returns the embedding output, so callers can feed only newly added positions each step.
    """

    def __init__(self, config):
        super(BertModelIncr, self).__init__(config)

    def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True,
                prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, task_idx=None):
        """Encode the given slice of positions, reusing cached states from earlier steps.

        Returns:
            (embedding_output, encoded_layers, pooled_output) where `encoded_layers` is the
            encoder's list of layer outputs, or just its last entry when
            `output_all_encoded_layers` is False.
        """
        additive_mask = self.get_extended_attention_mask(input_ids, token_type_ids, attention_mask)
        embedded = self.embeddings(input_ids, token_type_ids, position_ids, task_idx=task_idx)
        layer_outputs = self.encoder(
            embedded,
            additive_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            prev_embedding=prev_embedding,
            prev_encoded_layers=prev_encoded_layers,
            mask_qkv=mask_qkv,
            seg_ids=token_type_ids,
        )
        pooled = self.pooler(layer_outputs[-1])
        returned_layers = layer_outputs if output_all_encoded_layers else layer_outputs[-1]
        return embedded, returned_layers, pooled
class LayoutlmModelIncr(LayoutlmModel):
    """Incremental-decoding variant of LayoutlmModel.

    `forward` takes explicit position ids plus cached `prev_embedding` /
    `prev_encoded_layers` (handed to the encoder) and additionally returns the embedding
    output. `input_ids` is 3D: feature 0 is the token id and the remaining features are
    passed to the embeddings as layout features.
    """

    def __init__(self, config):
        super(LayoutlmModelIncr, self).__init__(config)

    def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True,
                prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, task_idx=None):
        """Encode the given slice of positions, reusing cached states from earlier steps.

        Returns:
            (embedding_output, encoded_layers, pooled_output) where `encoded_layers` is the
            encoder's list of layer outputs, or just its last entry when
            `output_all_encoded_layers` is False.
        """
        token_ids = input_ids[:, :, 0]
        layout_features = input_ids[:, :, 1:]
        additive_mask = self.get_extended_attention_mask(token_ids, token_type_ids, attention_mask)
        embedded = self.embeddings(token_ids, layout_features, token_type_ids, position_ids,
                                   task_idx=task_idx)
        layer_outputs = self.encoder(
            embedded,
            additive_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            prev_embedding=prev_embedding,
            prev_encoded_layers=prev_encoded_layers,
            mask_qkv=mask_qkv,
            seg_ids=token_type_ids,
        )
        pooled = self.pooler(layer_outputs[-1])
        returned_layers = layer_outputs if output_all_encoded_layers else layer_outputs[-1]
        return embedded, returned_layers, pooled
class LayoutlmForSeq2SeqDecoder(PreTrainedBertModel):
    """Incremental seq2seq decoder that, at each step, picks a source position to copy.

    refer to BertForPreTraining

    The encoder is `LayoutlmModelIncr` when `config.base_model_type == 'layoutlm'` and
    `BertModelIncr` otherwise. At every decoding step the `LayoutlmSPPreTrainingHeads` head
    scores the hidden state of the last fed position against the source embeddings. The
    ids stored in `input_ids` at the arg-max source position are fed back as the next
    decoder input. `forward` decodes greedily and delegates to `beam_search` when
    `search_beam_size > 1`.
    """

    def __init__(self, config, mask_word_id=0, num_labels=2, num_rel=0,
                 search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0,
                 forbid_duplicate_ngrams=False, forbid_ignore_set=None, ngram_size=3, min_len=0, mode="s2s",
                 pos_shift=False):
        super(LayoutlmForSeq2SeqDecoder, self).__init__(config)
        # True when input ids carry per-token layout features after the token id (3D input).
        self.layout_flag = config.base_model_type == 'layoutlm'
        if self.layout_flag:
            self.bert = LayoutlmModelIncr(config)
        else:
            self.bert = BertModelIncr(config)
        # note: the max source length is the max src seq length during fine tuning which includes the cls and sep
        # NOTE: we don't remove anything. the 0 is for padding
        self.cls = LayoutlmSPPreTrainingHeads(
            config, src_len=config.max_source_length, num_labels=num_labels)
        self.apply(self.init_bert_weights)
        self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
        self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1)
        self.mask_word_id = mask_word_id
        self.num_labels = num_labels
        self.num_rel = num_rel
        self.search_beam_size = search_beam_size
        self.length_penalty = length_penalty
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.forbid_duplicate_ngrams = forbid_duplicate_ngrams
        self.forbid_ignore_set = forbid_ignore_set
        self.ngram_size = ngram_size
        self.min_len = min_len
        assert mode in ("s2s", "l2r")
        self.mode = mode
        self.pos_shift = pos_shift

    # ------------------------------------------------------------------ helpers

    def _make_step_ids(self, input_ids, batch_size, token_id):
        """Build ids for one extra position per batch item holding `token_id`.

        Returns [batch_size, 1] for 2D (BERT) input, or [batch_size, 1, feat] for 3D
        layout input, with feature 0 set to `token_id` and the other features zero.
        `feat` follows `input_ids.size(-1)` instead of being hard-coded.
        """
        if not self.layout_flag:
            return input_ids.new(batch_size, 1).fill_(token_id)
        step_ids = input_ids.new_zeros(batch_size, 1, input_ids.size(-1))
        step_ids[:, :, 0] = token_id
        return step_ids

    @staticmethod
    def _copy_from_source(source_ids, positions):
        """Gather the entries of `source_ids` at `positions` ([rows, 1]) along dim 1.

        2D `source_ids` [rows, src_len] gives [rows, 1]. For 3D `source_ids`
        [rows, src_len, feat], all features of the chosen position are copied, giving
        [rows, 1, feat].
        """
        if source_ids.dim() == 2:
            return torch.gather(source_ids, 1, positions)
        index = positions.unsqueeze(-1).expand(-1, -1, source_ids.size(-1))
        return torch.gather(source_ids, 1, index)

    def _new_cache_entries(self, new_embedding, new_encoded_layers):
        """Return the part of this step's embedding / layer outputs to append to the cache.

        Without `pos_shift`, the last fed position is the [MASK] placeholder. Its real token
        is fed again at the same position on the next step, so it is not cached.
        """
        if self.pos_shift:
            return new_embedding, list(new_encoded_layers)
        return new_embedding[:, :-1, :], [x[:, :-1, :] for x in new_encoded_layers]

    @staticmethod
    def _first_expand(x, beam_size):
        """Repeat each batch item `beam_size` times along dim 0: [B, ...] -> [B * beam_size, ...].

        Copies of the same item are adjacent: item b occupies rows b*beam_size .. b*beam_size + beam_size - 1.
        """
        shape = list(x.size())
        x = x.reshape(shape[:1] + [1] + shape[1:])
        x = x.repeat(*([1, beam_size] + [1] * (len(shape) - 1)))
        return x.reshape([shape[0] * beam_size] + shape[1:])

    @staticmethod
    def _select_beam_items(x, ids, batch_size, beam_size):
        """Reorder the beams of `x` ([batch_size * beam_size, ...]) by `ids` ([batch_size, beam_size]).

        Row k of batch item b in the result is row `ids[b, k]` of that item's beams in `x`.
        """
        id_shape = list(ids.size())
        assert len(id_shape) == 2
        x_shape = list(x.size())
        x = x.reshape([batch_size, beam_size] + x_shape[1:])
        x_rank = len(x_shape) + 1
        if len(id_shape) < x_rank:
            ids = ids.reshape(id_shape + [1] * (x_rank - len(id_shape)))
            ids = ids.expand(id_shape + x_shape[1:])
        return torch.gather(x, 1, ids).reshape(x_shape)

    def _get_dup_ngram_candidates(self, seq, n):
        """Return the sorted ids that would complete an n-gram already present in `seq`.

        The n-gram prefix is the last n-1 ids of `seq`. Nothing is returned when that tail
        contains an id from `self.forbid_ignore_set`, and ids from that set are never candidates.
        """
        if len(seq) < n:
            return []
        tail = seq[-(n - 1):]
        if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail):
            return []
        cands = set()
        for i in range(len(seq) - (n - 1)):
            if not all(tail[j] == seq[i + j] for j in range(n - 1)):
                continue
            follower = seq[i + n - 1]
            if self.forbid_ignore_set and follower in self.forbid_ignore_set:
                continue
            cands.add(follower)
        return sorted(cands)

    def _best_sequence(self, scores, wids_list, ptrs):
        """Choose the best finished hypothesis for one batch item and trace it back.

        Args:
            scores: per step, the K cumulative beam scores.
            wids_list: per step, the K ids chosen by the beams.
            ptrs: per step, the K back-pointers into the previous step's beams.

        Returns:
            The id list of the best hypothesis ending in `eos_id` (or in the last valid
            frame), scored with a GNMT-style length penalty ((5 + len) / 6) ** length_penalty
            when `length_penalty > 0`. Returns `[0]` when no hypothesis is found.
        """
        # Frames after the first one in which every beam emitted eos are invalid.
        last_frame_id = len(scores) - 1
        for i, wids in enumerate(wids_list):
            if all(wid == self.eos_id for wid in wids):
                last_frame_id = i
                break
        max_score = -math.inf
        frame_id = -1
        pos_in_frame = -1
        for fid in range(last_frame_id + 1):
            for i, wid in enumerate(wids_list[fid]):
                if wid == self.eos_id or fid == last_frame_id:
                    s = scores[fid][i]
                    if self.length_penalty > 0:
                        s /= math.pow((5 + fid + 1) / 6.0, self.length_penalty)
                    if s > max_score:
                        max_score = s
                        frame_id = fid
                        pos_in_frame = i
        if frame_id == -1:
            return [0]
        seq = [wids_list[frame_id][pos_in_frame]]
        for fid in range(frame_id, 0, -1):
            pos_in_frame = ptrs[fid][pos_in_frame]
            seq.append(wids_list[fid - 1][pos_in_frame])
        seq.reverse()
        return seq

    @staticmethod
    def _pad_sequence(sequences, max_len, padding_value=0):
        """Stack tensors sharing trailing dims into [len(sequences), max_len, ...], padding dim 0."""
        trailing_dims = sequences[0].size()[1:]
        out_dims = (len(sequences), max_len) + trailing_dims
        out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
        for i, tensor in enumerate(sequences):
            length = tensor.size(0)
            # use index notation to prevent duplicate references to the tensor
            out_tensor[i, :length, ...] = tensor
        return out_tensor

    # ------------------------------------------------------------------ decoding

    def forward(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None):
        """Decode greedily, or with beam search when `search_beam_size > 1`.

        Args:
            input_ids: source ids, [batch, src_len] for BERT or [batch, src_len, feat] for
                LayoutLM (feature 0 is the token id).
            token_type_ids, position_ids: [batch, total_len] tensors covering source and target
                positions; `token_type_ids.size(1)` fixes where decoding stops.
            attention_mask: [batch, total_len, total_len] 0/1 mask, sliced per step.
            task_idx, mask_qkv: forwarded to the underlying model / head.

        Returns:
            Greedy: a tensor [batch, total_len - src_len] holding the arg-max source position
            chosen at each step. This assumes the head returns [batch, 1, src_len] scores;
            it is empty along dim 1 when there is nothing to decode.
            Beam search: the dict returned by `beam_search`.
        """
        if self.search_beam_size > 1:
            return self.beam_search(input_ids, token_type_ids, position_ids, attention_mask, task_idx=task_idx,
                                    mask_qkv=mask_qkv)
        batch_size, input_length = input_ids.size(0), input_ids.size(1)
        output_length = token_type_ids.size(1)
        output_ids = []
        prev_embedding = None
        prev_encoded_layers = None
        curr_ids = input_ids
        mask_ids = self._make_step_ids(input_ids, batch_size, self.mask_word_id)
        if self.pos_shift:
            sos_ids = self._make_step_ids(input_ids, batch_size, self.sos_id)
        src_embedding = None
        next_pos = input_length
        while next_pos < output_length:
            curr_length = curr_ids.size(1)
            if self.pos_shift:
                if next_pos == input_length:
                    # First step: the whole source followed by the start token.
                    x_input_ids = torch.cat((curr_ids, sos_ids), dim=1)
                    start_pos = 0
                else:
                    x_input_ids = curr_ids
                    start_pos = next_pos
            else:
                # Not-yet-cached ids plus a [MASK] slot whose output predicts position next_pos.
                start_pos = next_pos - curr_length
                x_input_ids = torch.cat((curr_ids, mask_ids), dim=1)
            curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1]
            curr_attention_mask = attention_mask[:, start_pos:next_pos + 1, :next_pos + 1]
            curr_position_ids = position_ids[:, start_pos:next_pos + 1]
            new_embedding, new_encoded_layers, _ = \
                self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask,
                          output_all_encoded_layers=True, prev_embedding=prev_embedding,
                          prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv)
            if src_embedding is None:
                # First step only: every fed position except the last (the slot being predicted from).
                src_embedding = new_embedding[:, :-1, :]
            last_hidden = new_encoded_layers[-1][:, -1:, :]
            prediction_scores, _ = self.cls(last_hidden, None, src_embedding, task_idx=task_idx)
            _, max_ids = torch.max(prediction_scores, dim=-1)
            output_ids.append(max_ids)

            emb_entry, layer_entries = self._new_cache_entries(new_embedding, new_encoded_layers)
            if prev_embedding is None:
                prev_embedding = emb_entry
            else:
                prev_embedding = torch.cat((prev_embedding, emb_entry), dim=1)
            if prev_encoded_layers is None:
                prev_encoded_layers = layer_entries
            else:
                prev_encoded_layers = [torch.cat((old, new), dim=1)
                                       for old, new in zip(prev_encoded_layers, layer_entries)]

            # Feed back the id(s) stored at the predicted source position.
            curr_ids = self._copy_from_source(input_ids, max_ids)
            next_pos += 1
        if not output_ids:
            # Nothing to decode: torch.cat would fail on an empty list.
            return input_ids.new_zeros(batch_size, 0)
        return torch.cat(output_ids, dim=1)

    def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None):
        """Beam-search decoding with K = `self.search_beam_size` beams per batch item.

        Takes the same arguments as `forward`. After the first step, the per-step inputs
        and caches are expanded so that the batch holds `batch_size * K` hypotheses.

        Returns:
            dict of tensors on `input_ids.device`, each padded with 0 along dim 1 to
            `token_type_ids.size(1)`:
              'pred_seq': [batch, total_len] best hypothesis per item,
              'scores':   [batch, total_len, K] cumulative beam scores per step,
              'wids':     [batch, total_len, K] ids chosen by each beam per step,
              'ptrs':     [batch, total_len, K] back-pointers into the previous step's beams.
        """
        batch_size, input_length = input_ids.size(0), input_ids.size(1)
        output_length = token_type_ids.size(1)
        K = self.search_beam_size
        prev_embedding = None
        prev_encoded_layers = None
        curr_ids = input_ids
        mask_ids = self._make_step_ids(input_ids, batch_size, self.mask_word_id)
        if self.pos_shift:
            # Only used on the first step, before any expansion.
            sos_ids = self._make_step_ids(input_ids, batch_size, self.sos_id)
        # input_ids itself is never expanded, so its per-beam copy is loop invariant.
        expanded_input_ids = self._first_expand(input_ids, K)
        total_scores = []
        beam_masks = []
        step_ids = []
        step_back_ptrs = []
        partial_seqs = []
        forbid_word_mask = None
        buf_matrix = None
        src_embedding = None
        next_pos = input_length
        while next_pos < output_length:
            curr_length = curr_ids.size(1)
            if self.pos_shift:
                if next_pos == input_length:
                    x_input_ids = torch.cat((curr_ids, sos_ids), dim=1)
                    start_pos = 0
                else:
                    x_input_ids = curr_ids
                    start_pos = next_pos
            else:
                start_pos = next_pos - curr_length
                x_input_ids = torch.cat((curr_ids, mask_ids), dim=1)
            curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1]
            curr_attention_mask = attention_mask[:, start_pos:next_pos + 1, :next_pos + 1]
            curr_position_ids = position_ids[:, start_pos:next_pos + 1]
            new_embedding, new_encoded_layers, _ = \
                self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask,
                          output_all_encoded_layers=True, prev_embedding=prev_embedding,
                          prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv)
            if src_embedding is None:
                src_embedding = new_embedding[:, :-1, :]
            if src_embedding.shape[0] != new_embedding.shape[0]:
                # From the second step on, the batch holds batch_size * K hypotheses.
                src_embedding = self._first_expand(src_embedding, K)
            last_hidden = new_encoded_layers[-1][:, -1:, :]
            prediction_scores, _ = self.cls(last_hidden, None, src_embedding, task_idx=task_idx)
            log_scores = torch.nn.functional.log_softmax(prediction_scores, dim=-1)
            # NOTE(review): `forbid_word_mask` (duplicate n-gram blocking) and `min_len` are
            # tracked but currently NOT applied to `log_scores`. The disabled application was:
            #   log_scores += forbid_word_mask * -10000.0
            #   log_scores[:, :, self.eos_id].fill_(-10000.0)  # while shorter than min_len
            kk_scores, kk_ids = torch.topk(log_scores, k=K)
            if not total_scores:
                k_ids = torch.reshape(kk_ids, [batch_size, K])
                back_ptrs = torch.zeros(batch_size, K, dtype=torch.long, device=input_ids.device)
                k_scores = torch.reshape(kk_scores, [batch_size, K])
            else:
                # Suppress continuations of beams that already emitted eos, add the running
                # scores, then keep the best K of the K * K candidates per batch item.
                last_eos = torch.reshape(beam_masks[-1], [batch_size * K, 1, 1])
                last_seq_scores = torch.reshape(total_scores[-1], [batch_size * K, 1, 1])
                kk_scores += last_eos * (-10000.0) + last_seq_scores
                kk_scores = torch.reshape(kk_scores, [batch_size, K * K])
                k_scores, k_ids = torch.topk(kk_scores, k=K)
                back_ptrs = torch.floor_divide(k_ids, K)
                kk_ids = torch.reshape(kk_ids, [batch_size, K * K])
                k_ids = torch.gather(kk_ids, 1, k_ids)
            step_back_ptrs.append(back_ptrs)
            step_ids.append(k_ids)
            beam_masks.append(torch.eq(k_ids, self.eos_id).type_as(kk_scores))
            total_scores.append(k_scores)

            is_first = prev_embedding is None
            emb_entry, layer_entries = self._new_cache_entries(new_embedding, new_encoded_layers)
            if is_first:
                prev_embedding = self._first_expand(emb_entry, K)
            else:
                prev_embedding = self._select_beam_items(
                    torch.cat((prev_embedding, emb_entry), dim=1), back_ptrs, batch_size, K)
            if prev_encoded_layers is None:
                prev_encoded_layers = [self._first_expand(x, K) for x in layer_entries]
            else:
                prev_encoded_layers = [
                    self._select_beam_items(torch.cat((old, new), dim=1), back_ptrs, batch_size, K)
                    for old, new in zip(prev_encoded_layers, layer_entries)]

            max_ids = torch.reshape(k_ids, [batch_size * K, 1])
            curr_ids = self._copy_from_source(expanded_input_ids, max_ids)

            if is_first:
                token_type_ids = self._first_expand(token_type_ids, K)
                position_ids = self._first_expand(position_ids, K)
                attention_mask = self._first_expand(attention_mask, K)
                mask_ids = self._first_expand(mask_ids, K)
                if mask_qkv is not None:
                    mask_qkv = self._first_expand(mask_qkv, K)

            if self.forbid_duplicate_ngrams:
                wids = step_ids[-1].tolist()
                ptrs = step_back_ptrs[-1].tolist()
                if is_first:
                    partial_seqs = [[wids[b][k]] for b in range(batch_size) for k in range(K)]
                else:
                    partial_seqs = [partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]]
                                    for b in range(batch_size) for k in range(K)]
                if len(partial_seqs[0]) >= self.ngram_size:
                    dup_cands = [self._get_dup_ngram_candidates(seq, self.ngram_size)
                                 for seq in partial_seqs]
                    if max(len(x) for x in dup_cands) > 0:
                        if buf_matrix is None:
                            vocab_size = log_scores.size(-1)
                            buf_matrix = np.zeros((batch_size * K, vocab_size), dtype=float)
                        else:
                            buf_matrix.fill(0)
                        for bk, cands in enumerate(dup_cands):
                            for wid in cands:
                                buf_matrix[bk, wid] = 1.0
                        forbid_word_mask = torch.tensor(buf_matrix, dtype=log_scores.dtype)
                        forbid_word_mask = torch.reshape(
                            forbid_word_mask, [batch_size * K, 1, vocab_size]).to(input_ids.device)
                    else:
                        forbid_word_mask = None
            next_pos += 1

        # [(batch, beam)]
        total_scores = [x.tolist() for x in total_scores]
        step_ids = [x.tolist() for x in step_ids]
        step_back_ptrs = [x.tolist() for x in step_back_ptrs]

        # back tracking
        traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []}
        for b in range(batch_size):
            # [(beam,)]
            scores = [x[b] for x in total_scores]
            wids_list = [x[b] for x in step_ids]
            ptrs = [x[b] for x in step_back_ptrs]
            traces['scores'].append(scores)
            traces['wids'].append(wids_list)
            traces['ptrs'].append(ptrs)
            traces['pred_seq'].append(self._best_sequence(scores, wids_list, ptrs))

        # convert to tensors for DataParallel
        for key in ('pred_seq', 'scores', 'wids', 'ptrs'):
            ts_list = traces[key]
            if not isinstance(ts_list[0], torch.Tensor):
                dt = torch.float if key == 'scores' else torch.long
                ts_list = [torch.tensor(it, dtype=dt) for it in ts_list]
            traces[key] = self._pad_sequence(
                ts_list, output_length, padding_value=0).to(input_ids.device)
        return traces