# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements dual path transformer layers proposed in MaX-DeepLab [1].

Dual-path transformer introduces a global memory path in addition to a CNN
path, allowing bi-directional communication with any CNN layers.

[1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
    CVPR 2021.
      Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""

import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import convolutions


class AttentionOperation(tf.keras.layers.Layer):
  """Computes standard 1D multi-head attention with query, key, and value."""

  def __init__(self,
               name,
               activation,
               transformer_activation,
               bn_layer=tf.keras.layers.BatchNormalization):
    """Initializes an AttentionOperation layer.

    Args:
      name: A string, the name of this layer.
      activation: A string, type of activation function to apply.
      transformer_activation: A string, type of activation function for
        self-attention. Supports 'sigmoid' and 'softmax'.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
    """
    super(AttentionOperation, self).__init__(name=name)
    # batch_norm_similarity has shape [batch, num_heads, num_query, num_key],
    # where num_query and num_key usually equal the spatial dimensions, i.e.,
    # height, width, or length, so batch norm is applied to axis=1 only.
    self._batch_norm_similarity = bn_layer(axis=1,
                                           name='batch_norm_similarity')
    # batch_norm_retrieved_value is done on shape [batch, num_heads, length,
    # value_channels], which will be reshaped to the output shape [batch,
    # length, value_channels * num_heads], so we apply batch norm on the
    # effective channel dimension -- value_channels * num_heads.
    self._batch_norm_retrieved_value = bn_layer(
        axis=[1, 3], name='batch_norm_retrieved_value')
    self._activation_fn = activations.get_activation(activation)
    self._transformer_activation_fn = activations.get_activation(
        transformer_activation)

  def call(self, inputs, training=False):
    """Performs an AttentionOperation.

    Args:
      inputs: A tuple of (query, key, value), where query is a [batch,
        num_head, query_length, channels] tensor, key is a [batch, num_head,
        key_length, channels] tensor, and value is a [batch, key_length,
        num_head, value_channels] tensor.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: A [batch, query_length, num_head * value_channels] tensor, the
        retrieved value.
    """
    # Decode query, key, and value from inputs.
    query, key, value = inputs
    # Compute attention similarity.
    similarity_logits = tf.einsum('bhld,bhmd->bhlm', query, key)
    similarity_logits = self._batch_norm_similarity(
        similarity_logits, training=training)
    # Apply a transformer attention activation function, e.g. softmax.
    attention_weights = self._transformer_activation_fn(similarity_logits)
    # Retrieve the value content.
    retrieved_value = tf.einsum(
        'bhlm,bmhd->bhld', attention_weights, value)
    retrieved_value = self._batch_norm_retrieved_value(
        retrieved_value, training=training)
    retrieved_value = self._activation_fn(retrieved_value)
    # Reshape the output.
    return utils.transpose_and_reshape_for_attention_operation(
        retrieved_value)
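

# Illustrative shape walk-through for AttentionOperation (added comment; the
# example sizes are arbitrary and not from the original code). With batch=2,
# num_heads=8, query_length=128, key_length=4096, 16 key channels per head,
# and 32 value channels per head:
#   query:  [2, 8, 128, 16]
#   key:    [2, 8, 4096, 16]
#   value:  [2, 4096, 8, 32]
#   tf.einsum('bhld,bhmd->bhlm', query, key)     -> [2, 8, 128, 4096]
#   tf.einsum('bhlm,bmhd->bhld', weights, value) -> [2, 8, 128, 32]
#   transpose-and-reshape output                 -> [2, 128, 8 * 32] = [2, 128, 256]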


class DualPathTransformerLayer(tf.keras.layers.Layer):
  """Applies a dual path transformer layer, as proposed in MaX-DeepLab [1].

  Dual-path transformer layer takes a pixel space input and a memory space
  input, and performs memory2pixel attention, pixel2memory attention, and
  memory2memory self-attention. Note that the pixel2pixel self-attention or
  convolution in the pixel space is implemented in axial_layers.py and
  axial_blocks.py. Thus, the pixel2pixel operation is not included in this
  DualPathTransformerLayer implementation. Please use this class together with
  a residual block with axial-attention, global-attention, or convolution in
  order to construct the full dual path transformer in the paper.

  [1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
      CVPR 2021.
        Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  def __init__(self,
               name='dual_path_transformer_layer',
               activation='relu',
               filters=128,
               num_heads=8,
               bottleneck_expansion=2,
               key_expansion=1,
               value_expansion=2,
               feed_forward_network_channels=2048,
               use_memory_self_attention=True,
               use_pixel2memory_feedback_attention=True,
               transformer_activation='softmax',
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a DualPathTransformerLayer.

    This function implements a dual path transformer layer between a pixel
    space and a memory space, as described in the MaX-DeepLab paper. In this
    dual path transformer, the memory2pixel cross attention and the memory
    self-attention share a single activation, e.g. softmax.

    Reference:
      MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.

    Args:
      name: A string, the name of this dual path transformer layer.
      activation: A string, type of activation function to apply.
      filters: An integer, the base number of channels for the layer.
      num_heads: An integer, the number of heads in multi-head attention.
      bottleneck_expansion: A float, the channel expansion ratio for the
        bottleneck.
      key_expansion: A float, the channel expansion ratio for keys.
      value_expansion: A float, the channel expansion ratio for values.
      feed_forward_network_channels: An integer, the number of channels for the
        feed_forward_network. Zero means no feed_forward_network will be
        applied.
      use_memory_self_attention: A boolean, whether to apply the memory space
        self-attention.
      use_pixel2memory_feedback_attention: A boolean, whether to apply the
        pixel2memory feedback attention.
      transformer_activation: A string, type of activation function for
        self-attention. Supports 'sigmoid' and 'softmax'.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If filters * key_expansion is not divisible by num_heads.
      ValueError: If filters * value_expansion is not divisible by num_heads.
    """
    super(DualPathTransformerLayer, self).__init__(name=name)
    bottleneck_channels = int(round(filters * bottleneck_expansion))
    total_key_depth = int(round(filters * key_expansion))
    total_value_depth = int(round(filters * value_expansion))
    if total_key_depth % num_heads:
      raise ValueError('Total_key_depth should be divisible by num_heads.')
    if total_value_depth % num_heads:
      raise ValueError('Total_value_depth should be divisible by num_heads.')

    # Compute query key value with one convolution and a batch norm layer. The
    # initialization std is standard transformer initialization (without batch
    # norm), as used in SASA and ViT. In our case, we use batch norm by
    # default, so it does not require careful tuning. If one wants to remove
    # all batch norms in axial attention, this standard initialization should
    # still be good, but a more careful initialization is encouraged.
    initialization_std = bottleneck_channels ** -0.5

    self._memory_conv1_bn_act = convolutions.Conv1D(
        bottleneck_channels, 'memory_conv1_bn_act',
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation=activation,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    self._pixel_conv1_bn_act = convolutions.Conv1D(
        bottleneck_channels, 'pixel_conv1_bn_act',
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation=activation,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    # We always compute the query for memory space, since it gathers
    # information from the pixel space and thus cannot be removed. We compute
    # the key and value for memory space only when they are necessary (i.e.
    # either use_memory_self_attention or use_pixel2memory_feedback_attention).
    if use_memory_self_attention or use_pixel2memory_feedback_attention:
      self._memory_qkv_conv_bn = convolutions.Conv1D(
          total_key_depth * 2 + total_value_depth, 'memory_qkv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    else:
      # Compute memory query only if memory key and value are not used.
      self._memory_query_conv_bn = convolutions.Conv1D(
          total_key_depth, 'memory_query_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))

    # For the pixel space, we always compute the key and value, since they
    # provide information for the memory space and thus cannot be removed. We
    # compute the query for pixel space only when it is necessary (i.e.
    # use_pixel2memory_feedback_attention is True).
    if use_pixel2memory_feedback_attention:
      self._pixel_qkv_conv_bn = convolutions.Conv1D(
          total_key_depth * 2 + total_value_depth, 'pixel_qkv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    else:
      self._pixel_kv_conv_bn = convolutions.Conv1D(
          total_key_depth + total_value_depth, 'pixel_kv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))

    self._memory_attention = AttentionOperation(
        'memory_attention', activation, transformer_activation,
        bn_layer=bn_layer)
    if use_pixel2memory_feedback_attention:
      self._pixel_attention = AttentionOperation(
          'pixel_attention', activation, transformer_activation,
          bn_layer=bn_layer)

    self._use_memory_self_attention = use_memory_self_attention
    self._use_pixel2memory_feedback_attention = (
        use_pixel2memory_feedback_attention)
    self._total_key_depth = total_key_depth
    self._total_value_depth = total_value_depth
    self._num_heads = num_heads
    self._bn_layer = bn_layer
    self._conv_kernel_weight_decay = conv_kernel_weight_decay
    self._activation = activation
    self._activation_fn = activations.get_activation(activation)
    self._feed_forward_network_channels = feed_forward_network_channels
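
    # Worked example of the channel arithmetic above (added comment), using
    # the default arguments: filters=128, bottleneck_expansion=2,
    # key_expansion=1, value_expansion=2, and num_heads=8 give
    # bottleneck_channels=256, total_key_depth=128 (16 per head),
    # total_value_depth=256 (32 per head), qkv convolutions with
    # 128 * 2 + 256 = 512 output channels, and
    # initialization_std = 256 ** -0.5 = 0.0625.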

  def build(self, input_shape_list):
    pixel_shape, memory_shape = input_shape_list[:2]
    # Here we follow ResNet bottleneck blocks: we apply a batch norm with gamma
    # initialized at zero, followed by drop path and an activation function.
    # Initializing this gamma at zero ensures that at random initialization of
    # the model, the skip connections dominate all residual blocks. In this
    # way, all the skip connections construct an identity mapping that passes
    # the gradients (without any distortion from the randomly initialized
    # blocks) to all residual blocks. This helps training at early epochs.
    # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
    # https://arxiv.org/abs/1706.02677
    self._memory_conv3_bn = convolutions.Conv1D(
        memory_shape[-1], 'memory_conv3_bn',
        use_bias=False,
        use_bn=True,
        bn_layer=self._bn_layer,
        bn_gamma_initializer='zeros',
        activation='none',
        conv_kernel_weight_decay=self._conv_kernel_weight_decay)

    if self._feed_forward_network_channels > 0:
      self._memory_ffn_conv1_bn_act = convolutions.Conv1D(
          self._feed_forward_network_channels, 'memory_ffn_conv1_bn_act',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          activation=self._activation,
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)
      # Again, we follow ResNet bottleneck blocks: we apply a batch norm with
      # gamma initialized at zero, followed by drop path and an activation
      # function.
      self._memory_ffn_conv2_bn = convolutions.Conv1D(
          memory_shape[-1], 'memory_ffn_conv2_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

    if self._use_pixel2memory_feedback_attention:
      self._pixel_conv3_bn = convolutions.Conv1D(
          pixel_shape[-1], 'pixel_conv3_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, inputs):
    """Performs a forward pass.

    We have to define drop_path_masks outside the layer call and pass it into
    the layer call, because recompute_grad (gradient checkpointing) does not
    allow any randomness within the function call. In addition, recompute_grad
    only supports float tensors as inputs. For this reason, the training flag
    should also be passed as a float tensor. For the same reason, we cannot
    support passing drop_path_random_mask as None. Instead, we ask the users to
    pass only the first three tensors when drop path is not used.

    Args:
      inputs: A tuple of 3 or 6 tensors, containing
        pixel_space_input should be a [batch, num_pixel, pixel_space_channels]
          tensor.
        memory_space_input should be a [batch, num_memory,
          memory_space_channels] tensor.
        float_tensor_training should be a float tensor of 0.0 or 1.0, whether
          the model is in training mode.
        (optional) pixel_space_drop_path_mask is a drop path mask tensor of
          shape [batch, 1, 1] for the pixel space.
        (optional) memory_space_attention_drop_path_mask is a drop path mask
          tensor of shape [batch, 1, 1] for the memory space.
        (optional) memory_space_feed_forward_network_drop_path_mask is a drop
          path mask tensor of shape [batch, 1, 1] for the memory space feed
          forward network.

    Returns:
      pixel_space_output: A [batch, num_pixel, pixel_space_channels] tensor.
      activated_pixel_space_output: A [batch, num_pixel, pixel_space_channels]
        tensor, activated pixel_space_output.
      memory_space_output: A [batch, num_memory, memory_space_channels]
        tensor.

    Raises:
      ValueError: If the length of inputs is not 3 or 6.
    """
    if len(inputs) not in (3, 6):
      raise ValueError('The length of inputs should be either 3 or 6.')

    # Unpack the inputs.
    (pixel_space_input, memory_space_input, float_tensor_training,
     pixel_space_drop_path_mask, memory_space_attention_drop_path_mask,
     memory_space_feed_forward_network_drop_path_mask) = (
         utils.pad_sequence_with_none(inputs, target_length=6))

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor outside this call, and now we cast it back to a boolean tensor.
    training = tf.cast(float_tensor_training, tf.bool)

    # Decode the inputs shapes.
    pixel_shape = pixel_space_input.get_shape().as_list()
    memory_shape = memory_space_input.get_shape().as_list()

    # Similar to the ResNet bottleneck design, we do an input down projection
    # in both the pixel space and the memory space.
    memory_space = self._memory_conv1_bn_act(memory_space_input,
                                             training=training)
    # Pixel space input is not activated.
    pixel_space = self._pixel_conv1_bn_act(
        self._activation_fn(pixel_space_input), training=training)

    if (self._use_memory_self_attention or
        self._use_pixel2memory_feedback_attention):
      memory_space_qkv = self._memory_qkv_conv_bn(memory_space,
                                                  training=training)
      # Split, reshape, and transpose the query, key, and value.
      memory_query, memory_key, memory_value = (
          tf.split(memory_space_qkv, [
              self._total_key_depth, self._total_key_depth,
              self._total_value_depth], axis=-1))
      memory_key = utils.reshape_and_transpose_for_attention_operation(
          memory_key, self._num_heads)
      memory_value = tf.reshape(memory_value, [
          -1, memory_shape[1], self._num_heads,
          self._total_value_depth // self._num_heads])
    else:
      # Compute memory query only if memory key and value are not used.
      memory_query = self._memory_query_conv_bn(memory_space,
                                                training=training)
    # Reshape and transpose the query.
    memory_query = utils.reshape_and_transpose_for_attention_operation(
        memory_query, self._num_heads)

    if self._use_pixel2memory_feedback_attention:
      pixel_space_qkv = self._pixel_qkv_conv_bn(pixel_space,
                                                training=training)
      # Split the query, key, and value.
      pixel_query, pixel_key, pixel_value = tf.split(
          pixel_space_qkv, [
              self._total_key_depth, self._total_key_depth,
              self._total_value_depth], axis=-1)
      pixel_query = utils.reshape_and_transpose_for_attention_operation(
          pixel_query, self._num_heads)
    else:
      pixel_space_kv = self._pixel_kv_conv_bn(pixel_space, training=training)
      # Split the key and the value.
      pixel_key, pixel_value = tf.split(pixel_space_kv, [
          self._total_key_depth, self._total_value_depth], axis=-1)
    # Reshape and transpose the key and the value.
    pixel_key = utils.reshape_and_transpose_for_attention_operation(
        pixel_key, self._num_heads)
    pixel_value = tf.reshape(pixel_value, [
        -1, pixel_shape[1], self._num_heads,
        self._total_value_depth // self._num_heads])

    # Compute memory space attention.
    if not self._use_memory_self_attention:
      # If memory self attention is not used, then only memory2pixel cross
      # attention is used for the memory space. In this case, the key and the
      # value are simply pixel_key and pixel_value.
      memory_attention_key = pixel_key
      memory_attention_value = pixel_value
    else:
      # If we also use memory self attention, the key and the value are the
      # concatenation of keys and values in both the pixel space and the
      # memory space.
      memory_attention_key = tf.concat([pixel_key, memory_key], axis=2)
      memory_attention_value = tf.concat([pixel_value, memory_value], axis=1)

    memory_space = self._memory_attention(
        (memory_query, memory_attention_key, memory_attention_value),
        training=training)
    memory_space = self._memory_conv3_bn(memory_space, training=training)
    if memory_space_attention_drop_path_mask is not None:
      memory_space = memory_space * memory_space_attention_drop_path_mask
    memory_space_output = self._activation_fn(
        memory_space_input + memory_space)

    # Apply an optional feed-forward network to the memory space.
    if self._feed_forward_network_channels > 0:
      memory_space = self._memory_ffn_conv1_bn_act(memory_space_output,
                                                   training=training)
      memory_space = self._memory_ffn_conv2_bn(memory_space,
                                               training=training)
      if memory_space_feed_forward_network_drop_path_mask is not None:
        memory_space = (memory_space *
                        memory_space_feed_forward_network_drop_path_mask)
      memory_space_output = self._activation_fn(
          memory_space_output + memory_space)

    # Compute pixel space attention and the output projection only when
    # pixel2memory_feedback_attention is used.
    if self._use_pixel2memory_feedback_attention:
      pixel_space = self._pixel_attention(
          (pixel_query, memory_key, memory_value), training=training)
      pixel_space = self._pixel_conv3_bn(pixel_space, training=training)
      if pixel_space_drop_path_mask is not None:
        pixel_space = pixel_space * pixel_space_drop_path_mask
      pixel_space_output = pixel_space_input + pixel_space
    else:
      # If pixel2memory_feedback_attention is not used, the pixel_space_input
      # is not changed.
      pixel_space_output = pixel_space_input
    activated_pixel_space_output = self._activation_fn(pixel_space_output)

    # Return the pixel space output and memory space output. Note that we
    # return the pixel space output with and without the activation function,
    # because our decoder might use non-activated features.
    return (pixel_space_output,
            activated_pixel_space_output,
            memory_space_output)
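

if __name__ == '__main__':
  # Minimal usage sketch (added for illustration; not part of the original
  # deeplab2 module). It assumes the deeplab2 package and its dependencies are
  # importable, and simply exercises the 3-tensor input contract documented in
  # DualPathTransformerLayer.call with arbitrary example shapes: 4096 pixel
  # positions with 512 channels and 128 memory slots with 256 channels.
  layer = DualPathTransformerLayer(name='example_dual_path_layer')
  pixel_space_input = tf.random.normal([2, 4096, 512])
  memory_space_input = tf.random.normal([2, 128, 256])
  # The training flag is passed as a float tensor (see the call docstring).
  # Drop path masks are omitted, so only the first three tensors are passed.
  float_tensor_training = tf.constant(0.0, dtype=tf.float32)
  pixel_output, activated_pixel_output, memory_output = layer(
      (pixel_space_input, memory_space_input, float_tensor_training))
  print('pixel_space_output shape:', pixel_output.shape)
  print('activated_pixel_space_output shape:', activated_pixel_output.shape)
  print('memory_space_output shape:', memory_output.shape)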