"""Implements convolutional and attentional residual block groups.""" |
|
|
|
import math |
|
import tensorflow as tf |
|
|
|
from deeplab2.model import utils |
|
from deeplab2.model.layers import activations |
|
from deeplab2.model.layers import axial_blocks |
|
from deeplab2.model.layers import drop_path |
|
from deeplab2.model.layers import dual_path_transformer |
|
from deeplab2.model.layers import positional_encodings |
|
from deeplab2.model.layers import recompute_grad as recompute_grad_lib |
|
|
|
|
|
|
|
|
|
|
|
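# Name suffix for the optional dual-path transformer layer that follows a
# residual block (see _get_current_names below).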
TRANSFORMER = 'transformer'


def _get_current_names(index):
  current_name = '_block{}'.format(index + 1)
  transformer_current_name = '_block{}_{}'.format(index + 1, TRANSFORMER)
  return current_name, transformer_current_name


class BlockGroup(tf.keras.layers.Layer):
  """Applies a group of residual blocks with dual-path transformer layers [1].

  An optional dual-path transformer layer is inserted after each residual
  block. The transformer layer performs memory2pixel attention, pixel2memory
  attention, and memory2memory self-attention, while the standard residual
  block applies the pixel2pixel axial-attention, global-attention, or spatial
  convolution.

  Reference:
    [1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
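
  A minimal usage sketch (illustrative only; the argument values and tensor
  shapes below are arbitrary):

    block_group = BlockGroup(
        filters=256,
        num_blocks=3,
        name='stage4',
        original_resnet_stride=2,
        original_resnet_input_stride=8)
    pixel_features = tf.zeros([2, 65, 65, 512])
    memory_features = tf.zeros([2, 128, 256])
    features, activated_features, memory_features = block_group(
        (pixel_features, memory_features), training=False)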
""" |
|
|
|
def __init__(self, |
|
filters, |
|
num_blocks, |
|
name, |
|
original_resnet_stride, |
|
original_resnet_input_stride, |
|
output_stride=16, |
|
backbone_type='resnet_beta', |
|
positional_encoding_type=None, |
|
use_global_beyond_stride=0, |
|
use_axial_beyond_stride=16, |
|
use_transformer_beyond_stride=32, |
|
use_sac_beyond_stride=0, |
|
use_squeeze_and_excite=False, |
|
conv_use_recompute_grad=False, |
|
axial_use_recompute_grad=True, |
|
recompute_within_stride=0, |
|
transformer_use_recompute_grad=False, |
|
transformer_expansion=1, |
|
drop_path_keep_prob=0.8, |
|
drop_path_beyond_stride=16, |
|
drop_path_schedule='constant', |
|
activation='relu', |
|
attention_bottleneck_expansion=2, |
|
axial_layer_config=None, |
|
dual_path_transformer_layer_config=None, |
|
bn_layer=tf.keras.layers.BatchNormalization, |
|
conv_kernel_weight_decay=0.0): |
|
"""Initializes a BlockGroup layer. |
|
|
|
Args: |
|
filters: An integer, the base number of channels for this block group. |
|
num_blocks: An integer, the number of blocks for this block group. |
|
name: A string, the name of the block group. |
|
original_resnet_stride: An integer, the original resnet stride for this |
|
block, usually 1 or 2. The stride will be applied if |
|
original_resnet_input_stride is smaller than the desired output_stride. |
|
Otherwise, the stride will not be applied, and atrous convolution will |
|
be used after the first block. |
|
original_resnet_input_stride: An integer, the total input stride in the |
|
original resnet. For example, the total input stride for the last stage |
|
of the original resnet is 16, and the total output stride is 32. This |
|
stride differs from the true stride of the feature in that we might use |
|
atrous convolution to change both the input and output stride to, e.g. |
|
8, but its original resnet input stride remains the same. In this case, |
|
we also use the original resnet input stride to compute the atrous rate. |
|
output_stride: An integer, the desired output_stride for the ResNet. |
|
backbone_type: A string, the type of the backbone. Supports 'resnet', |
|
'resnet_beta', and 'wider_resnet'. The 'resnet' refers to the original |
|
resnet with a 7x7 convolutional stem. The 'resnet_beta' means a resnet |
|
but with an inception stem. The 'wider_resnet' is a wider variant of |
|
resnet with extensively used 3x3 convolutions. |
|
positional_encoding_type: A string, type of the positional encoding. |
|
Support '2D', '1D', and None. |
|
use_global_beyond_stride: An integer, the stride beyond which we use |
|
global attention. Set to 0 if no global attention is desired. Defaults |
|
to 0, i.e. we do not use global attention. |
|
use_axial_beyond_stride: An integer, the stride beyond which we use axial |
|
attention. Note that use_global_beyond_stride has a higher priority, |
|
i.e. we use global attention if the stride is also beyond |
|
use_global_beyond_stride. Set to 0 if no axial attention is desired. |
|
Defaults to 16 as in MaX-DeepLab. |
|
use_transformer_beyond_stride: An integer, the stride beyond which we use |
|
a transformer layer. Set to 0 if no transformer is desired. Defaults to |
|
32 as in MaX-DeepLab-S. |
|
use_sac_beyond_stride: An integer. Use the Switchable Atrous Convolution |
|
(SAC) beyond the specified stride. For example, if |
|
`use_sac_beyond_stride` = 16, SAC will be applied to the network stage |
|
whose output stride >= 16 (i.e., 16 and 32). Set to 0 or -1 to disable |
|
it. Defaults to 0 as SAC is not used in MaX-DeepLab. |
|
use_squeeze_and_excite: A boolean, whether squeeze-and-excite (SE) is |
|
used. Defaults to False as SE is not used in MaX-DeepLab. |
|
conv_use_recompute_grad: A boolean, whether to use the gradient |
|
checkpointing trick for convolutional blocks. This trick reduces |
|
accelerator memory usage, but takes longer to compute gradients. |
|
Defaults to False since convolutional layers are memory efficient. |
|
axial_use_recompute_grad: A boolean, whether to use the gradient |
|
checkpointing trick for axial blocks. This trick reduces accelerator |
|
memory usage, but takes longer to compute gradients. Defaults to True |
|
since it saves memory for axial blocks. |
|
recompute_within_stride: An integer, the stride within which we use the |
|
gradient checkpointing trick. This trick reduces accelerator memory |
|
usage, but takes longer to compute gradients. Defaults to 0 (do not |
|
recompute any layer). |
|
transformer_use_recompute_grad: A boolean, whether to use the gradient |
|
checkpointing trick for dual-path transformer blocks. This trick reduces |
|
accelerator memory usage, but takes longer to compute gradients. |
|
Defaults to False. |
|
transformer_expansion: An integer, the expansion ratio for the transformer |
|
bottleneck. |
|
drop_path_keep_prob: A float, the keep probability for dropping path. |
|
Defaults to 0.8 as in MaX-DeepLab-S. |
|
drop_path_beyond_stride: An integer, the stride beyond which we apply drop |
|
path augmentation. Defaults to 16 as in MaX-DeepLab-S. |
|
drop_path_schedule: A string, the drop path schedule. Currently, we |
|
support 'constant': use the same drop path keep probability for all |
|
stages, and 'linear': linearly decrease the drop path keep probability |
|
from 1.0 at 0-th stage (or STEM) to `drop_path_keep_prob` at last stage. |
|
activation: A string, type of activation function to apply. Support |
|
'relu', 'swish' (or 'silu'), 'gelu', 'approximated_gelu', and 'elu'. |
|
attention_bottleneck_expansion: An integer, the expansion ratio for |
|
axial attention blocks. |
|
axial_layer_config: A dict, an argument dictionary for the axial layer. |
|
dual_path_transformer_layer_config: A dict, an argument dictionary for the |
|
transformer. |
|
bn_layer: An optional tf.keras.layers.Layer that computes the |
|
normalization (default: tf.keras.layers.BatchNormalization). |
|
conv_kernel_weight_decay: A float, the weight decay for convolution |
|
kernels. |
|
|
|
Raises: |
|
ValueError: If backbone_type is not one of 'resnet', 'resnet_beta', or |
|
'wider_resnet'. |
|
ValueError: original_resnet_input_stride is not power of 2. |
|
ValueError: output_stride is not power of 2. |
|
""" |
|
if original_resnet_input_stride & (original_resnet_input_stride - 1): |
|
raise ValueError('original_resnet_input_stride is not power of 2.') |
|
if output_stride & (output_stride - 1): |
|
raise ValueError('output_stride is not power of 2.') |
|
|
|
super(BlockGroup, self).__init__(name=name) |
|
self._add_absolute_positional_encoding = None |
|
self._activation_fn = activations.get_activation(activation) |
|
self._num_blocks = num_blocks |
|
self._drop_path_keep_prob = [] |
|
self._recompute_grad = [] |
|
self._transformer_use_recompute_grad = transformer_use_recompute_grad |
|
if dual_path_transformer_layer_config is None: |
|
dual_path_transformer_layer_config = {} |
|
original_resnet_current_stride = original_resnet_input_stride |
|
|
|
use_sac = (original_resnet_input_stride * original_resnet_stride >= |
|
use_sac_beyond_stride > 0) |
|
|
|
recompute_grad = (original_resnet_input_stride * original_resnet_stride <= |
|
recompute_within_stride) |
|
|
|
for index in range(num_blocks): |
|
current_name, transformer_current_name = _get_current_names(index) |
|
|
|
|
|
|
|
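      # Compute the strides for this block: if this block group has a stride
      # and we have not yet reached the target output_stride, apply it in the
      # first block only.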
if index == 0 and original_resnet_input_stride < output_stride: |
|
current_strides = original_resnet_stride |
|
else: |
|
current_strides = 1 |
|
|
|
|
|
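      # Compute the atrous rate: dilate the convolution when the original
      # resnet stride has already exceeded the desired output_stride.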
if original_resnet_current_stride > output_stride: |
|
atrous_rate = original_resnet_current_stride // output_stride |
|
else: |
|
atrous_rate = 1 |
|
|
|
|
|
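      # Compute the atrous rate for the second convolution of a basic block in
      # the first block: that convolution runs after the (possibly removed)
      # stride, so its rate reflects the updated total stride.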
      if (index == 0 and original_resnet_input_stride * original_resnet_stride >
          output_stride):
        basic_block_second_conv_atrous_rate = (
            original_resnet_input_stride * original_resnet_stride //
            output_stride)
      else:
        basic_block_second_conv_atrous_rate = atrous_rate
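      # Compute the drop path keep probability for the current stage. Stages
      # are indexed from 0 (stride 2) to 4 (stride 32), i.e. the stage index
      # is log2(stride) - 1.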
      current_stage = math.log2(original_resnet_current_stride) - 1
      if original_resnet_current_stride >= drop_path_beyond_stride:
        current_drop_path_keep_prob = drop_path.get_drop_path_keep_prob(
            drop_path_keep_prob, drop_path_schedule,
            current_stage=int(round(current_stage)),
            num_stages=4)
      else:
        current_drop_path_keep_prob = 1.0
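      # Select the attention type and the channel counts for this block, and
      # decide whether to recompute gradients. Global attention takes priority
      # over axial attention.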
if original_resnet_current_stride >= use_global_beyond_stride > 0: |
|
attention_type = 'global' |
|
recompute_grad = axial_use_recompute_grad or recompute_grad |
|
filters_list = [filters * attention_bottleneck_expansion, |
|
filters, |
|
filters * 4] |
|
elif original_resnet_current_stride >= use_axial_beyond_stride > 0: |
|
attention_type = 'axial' |
|
recompute_grad = axial_use_recompute_grad or recompute_grad |
|
filters_list = [filters * attention_bottleneck_expansion, |
|
filters, |
|
filters * 4] |
|
elif backbone_type == 'resnet' or backbone_type == 'resnet_beta': |
|
attention_type = None |
|
recompute_grad = conv_use_recompute_grad or recompute_grad |
|
filters_list = [filters, |
|
filters, |
|
filters * 4] |
|
elif backbone_type == 'wider_resnet': |
|
if original_resnet_input_stride * original_resnet_stride < 32: |
|
|
|
attention_type = None |
|
recompute_grad = conv_use_recompute_grad or recompute_grad |
|
filters_list = [filters * 4, |
|
filters * 4] |
|
else: |
|
|
|
attention_type = None |
|
recompute_grad = conv_use_recompute_grad or recompute_grad |
|
filters_list = [filters, |
|
filters * 2, |
|
filters * 4] |
|
else: |
|
raise ValueError(backbone_type + ' is not supported.') |
|
|
|
self._drop_path_keep_prob.append(current_drop_path_keep_prob) |
|
|
|
|
|
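      # Build the residual block (convolutional, axial-attention, or
      # global-attention, depending on attention_type) for this index.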
      block_fn = axial_blocks.AxialBlock(
          filters_list,
          kernel_size=3,
          strides=current_strides,
          atrous_rate=atrous_rate,
          use_squeeze_and_excite=use_squeeze_and_excite,
          use_sac=use_sac,
          bn_layer=bn_layer,
          activation=activation,
          name=current_name[1:],
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          basic_block_second_conv_atrous_rate=(
              basic_block_second_conv_atrous_rate),
          attention_type=attention_type,
          axial_layer_config=axial_layer_config)
      self._recompute_grad.append(recompute_grad)
      utils.safe_setattr(self, current_name, block_fn)
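      # The first block may apply the stride of this block group, so update
      # the running original resnet stride accordingly.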
if index == 0 and original_resnet_stride > 1: |
|
original_resnet_current_stride *= original_resnet_stride |
|
|
|
|
|
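      # If global attention starts at this stride, add an absolute positional
      # encoding before it is used.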
      if original_resnet_current_stride == use_global_beyond_stride > 0:
        self._add_absolute_positional_encoding = (
            positional_encodings.AddAbsolutePositionalEncoding(
                'add_absolute_positional_encoding',
                positional_encoding_type, bn_layer, conv_kernel_weight_decay))
      if original_resnet_current_stride >= use_transformer_beyond_stride > 0:
        transformer_block_fn = dual_path_transformer.DualPathTransformerLayer(
            name=transformer_current_name[1:],
            filters=int(128 * transformer_expansion),
            activation=activation,
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **dual_path_transformer_layer_config)
        utils.safe_setattr(self, transformer_current_name, transformer_block_fn)
      else:
        utils.safe_setattr(self, transformer_current_name, None)
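
    # On the first call, the sub-layers are called once outside of the
    # recompute_grad wrapper so that their variables are created; this flag
    # tracks whether that building call has happened yet.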
    self._first_building_call = True

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: A tuple of two tensors. The first tensor is a pixel_space_input
        with shape [batch, height, width, pixel_channels]. The second tensor
        is a memory_space_input with shape [batch, length, memory_channels].
        The memory input is used only if a transformer layer is present;
        otherwise, it is returned unmodified.
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      output: An output [batch, height, width, filters * 4] tensor.
      activated_output: An activated output [batch, height, width,
        filters * 4] tensor.
      memory_space_output: A memory space output [batch, length,
        memory_channels] tensor.
    """
    activated_features, memory_space_output = inputs
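    # The wrapped block functions only take tensor inputs, so cast the boolean
    # `training` flag to a float tensor before passing it along.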
    float_tensor_training = tf.cast(training, tf.float32)

    for index in range(self._num_blocks):
      current_name, transformer_current_name = _get_current_names(index)
      block_fn_no_recompute = getattr(
          self, current_name)
      transformer_block_fn_no_recompute = getattr(
          self, transformer_current_name)
      current_drop_path_keep_prob = self._drop_path_keep_prob[index]
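      # Wrap the block with gradient checkpointing (recompute_grad) during
      # training if requested, trading extra compute for lower memory usage.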
      if self._recompute_grad[index] and training:
        block_fn = recompute_grad_lib.recompute_grad(
            block_fn_no_recompute, seed=tf.constant(0, tf.int32))
      else:
        block_fn = block_fn_no_recompute
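      # The block takes a flat tuple of tensors: the input features, the
      # training flag as a float tensor, and optionally a drop path mask.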
      block_fn_inputs = [activated_features, float_tensor_training]
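      # During training, pass a pre-generated random mask into the block so
      # that its residual path is dropped with probability
      # (1 - current_drop_path_keep_prob).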
if current_drop_path_keep_prob < 1.0 and training: |
|
drop_path_random_mask = drop_path.generate_drop_path_random_mask( |
|
activated_features, current_drop_path_keep_prob) |
|
|
|
block_fn_inputs.append(drop_path_random_mask) |
|
|
|
|
|
|
|
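      # Call the unwrapped block once on the first forward pass so that its
      # variables are created outside of the recompute_grad wrapper.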
      if self._first_building_call:
        _ = block_fn_no_recompute(tuple(block_fn_inputs))
      features, activated_features = block_fn(tuple(block_fn_inputs))

      if index == 0 and self._add_absolute_positional_encoding is not None:
        features = self._add_absolute_positional_encoding(features,
                                                          training=training)
        activated_features = self._activation_fn(features)
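      # If a dual-path transformer follows this block, flatten the pixel
      # features to [batch, height * width, channels] and run the transformer
      # together with the memory features.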
      if transformer_block_fn_no_recompute is not None:
        _, height, width, channels = features.get_shape().as_list()
        features = tf.reshape(features, [-1, height * width, channels])
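        # As above, optionally wrap the transformer with gradient
        # checkpointing during training.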
        if self._transformer_use_recompute_grad and training:
          transformer_block_fn = recompute_grad_lib.recompute_grad(
              transformer_block_fn_no_recompute, seed=tf.constant(0, tf.int32))
        else:
          transformer_block_fn = transformer_block_fn_no_recompute
        transformer_block_fn_input_list = [
            features, memory_space_output, float_tensor_training]
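        # During training, generate separate drop path masks for the three
        # transformer paths: the pixel space attention, the memory space
        # attention, and the memory space feed-forward network.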
if current_drop_path_keep_prob < 1.0 and training: |
|
|
|
pixel_space_drop_path_mask = drop_path.generate_drop_path_random_mask( |
|
memory_space_output, current_drop_path_keep_prob) |
|
|
|
memory_space_attention_drop_path_mask = ( |
|
drop_path.generate_drop_path_random_mask( |
|
memory_space_output, current_drop_path_keep_prob)) |
|
|
|
memory_space_feed_forward_network_drop_path_mask = ( |
|
drop_path.generate_drop_path_random_mask( |
|
memory_space_output, current_drop_path_keep_prob)) |
|
transformer_block_fn_input_list += [ |
|
pixel_space_drop_path_mask, |
|
memory_space_attention_drop_path_mask, |
|
memory_space_feed_forward_network_drop_path_mask] |
|
|
|
|
|
|
|
|
|
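        # Build the transformer sub-layer on the first forward pass so that
        # its variables are created outside of the recompute_grad wrapper.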
        if self._first_building_call:
          _ = transformer_block_fn_no_recompute(
              tuple(transformer_block_fn_input_list))
        features, activated_features, memory_space_output = (
            transformer_block_fn(tuple(transformer_block_fn_input_list)))
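        # Restore the pixel features to their original 4D spatial layout.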
        features = tf.reshape(features, [-1, height, width, channels])
        activated_features = tf.reshape(activated_features,
                                        [-1, height, width, channels])

    self._first_building_call = False

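    # Return both the non-activated and the activated pixel features, together
    # with the (possibly updated) memory features.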
    return features, activated_features, memory_space_output