Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021 The Deeplab2 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Implements Axial-Blocks proposed in Axial-DeepLab [1]. | |
Axial-Blocks are based on residual bottleneck blocks, but with the 3x3 | |
convolution replaced with two axial-attention layers, one on the height-axis, | |
followed by the other on the width-axis. | |
[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation, | |
ECCV 2020 Spotlight. | |
Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille, | |
Liang-Chieh Chen. | |
""" | |
import tensorflow as tf | |
from deeplab2.model import utils | |
from deeplab2.model.layers import activations | |
from deeplab2.model.layers import axial_layers | |
from deeplab2.model.layers import convolutions | |
from deeplab2.model.layers import squeeze_and_excite | |
class AxialBlock(tf.keras.layers.Layer):
  """An AxialBlock as a building block for an Axial-ResNet model.

  We implement the Axial-Block proposed in [1] in a general way that also
  includes convolutional residual blocks, such as the basic block and the
  bottleneck block (w/ and w/o Switchable Atrous Convolution).

  A basic block consists of two 3x3 convolutions and a residual connection. It
  is the main building block for wide-resnet variants.

  A bottleneck block consists of consecutive 1x1, 3x3, 1x1 convolutions and a
  residual connection. It is the main building block for standard resnet
  variants.

  An axial block consists of a 1x1 input convolution, a self-attention layer
  (either axial-attention or global attention), a 1x1 output convolution, and a
  residual connection. It is the main building block for axial-resnet variants.

  Note: We apply the striding in the first spatial operation (i.e. 3x3
  convolution or self-attention layer).
  """

  def __init__(self,
               filters_list,
               kernel_size=3,
               strides=1,
               atrous_rate=1,
               use_squeeze_and_excite=False,
               use_sac=False,
               bn_layer=tf.keras.layers.BatchNormalization,
               activation='relu',
               name=None,
               conv_kernel_weight_decay=0.0,
               basic_block_second_conv_atrous_rate=None,
               attention_type=None,
               axial_layer_config=None):
    """Initializes an AxialBlock.

    Args:
      filters_list: A list of filter numbers in the residual block. We currently
        support filters_list with two or three elements. Two elements specify
        the filters for two consecutive 3x3 convolutions, while three elements
        specify the filters for three convolutions (1x1, 3x3, and 1x1).
      kernel_size: The size of the convolution kernels (default: 3).
      strides: The strides of the block (default: 1).
      atrous_rate: The atrous rate of the 3x3 convolutions (default: 1). If this
        residual block is a basic block, it is recommended to specify correct
        basic_block_second_conv_atrous_rate for the second 3x3 convolution.
        Otherwise, the second conv will also use atrous rate, which might cause
        atrous inconsistency with different output strides, as tested in
        axial_block_groups_test.test_atrous_consistency_basic_block.
      use_squeeze_and_excite: A boolean flag indicating whether
        squeeze-and-excite (SE) is used.
      use_sac: A boolean, using the Switchable Atrous Convolution (SAC) or not.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      activation: A string specifying the activation function to apply.
      name: A string specifying the name of the layer (default: None).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
      basic_block_second_conv_atrous_rate: An integer, the atrous rate for the
        second convolution of basic block. This is necessary to ensure atrous
        consistency with different output_strides. Defaults to atrous_rate.
      attention_type: A string, type of attention to apply. Support 'axial' and
        'global'. None or 'none' (case-insensitive) selects a regular
        convolution instead of attention.
      axial_layer_config: A dict, an argument dictionary for the axial layer.
        The dictionary is copied; the caller's object is never modified.

    Raises:
      ValueError: If filters_list does not have two or three elements.
      ValueError: If attention_type is not supported.
      ValueError: If double_global_attention is True in axial_layer_config
        while attention_type is 'axial'.
    """
    super(AxialBlock, self).__init__(name=name)

    self._filters_list = filters_list
    self._strides = strides
    self._use_squeeze_and_excite = use_squeeze_and_excite
    self._bn_layer = bn_layer
    self._activate_fn = activations.get_activation(activation)
    self._attention_type = attention_type
    # Kept for build(), which may need to create the shortcut projection.
    self._conv_kernel_weight_decay = conv_kernel_weight_decay

    # Work on a private copy: the 'axial' branch below removes a key, and that
    # must not leak into a dictionary the caller may reuse for other blocks.
    axial_layer_config = (
        {} if axial_layer_config is None else dict(axial_layer_config))

    if basic_block_second_conv_atrous_rate is None:
      basic_block_second_conv_atrous_rate = atrous_rate

    if len(filters_list) == 3:
      # Three consecutive convolutions: 1x1, 3x3, and 1x1.
      self._conv1_bn_act = convolutions.Conv2DSame(
          filters_list[0], 1, 'conv1_bn_act',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation=activation,
          conv_kernel_weight_decay=conv_kernel_weight_decay)

      if attention_type is None or attention_type.lower() == 'none':
        self._conv2_bn_act = convolutions.Conv2DSame(
            filters_list[1], kernel_size, 'conv2_bn_act',
            strides=strides,
            atrous_rate=atrous_rate,
            use_bias=False,
            use_bn=True,
            bn_layer=bn_layer,
            activation=activation,
            use_switchable_atrous_conv=use_sac,
            # We default to use global context in SAC if use_sac is True. This
            # setting is experimentally found effective.
            use_global_context_in_sac=use_sac,
            conv_kernel_weight_decay=conv_kernel_weight_decay)
      elif attention_type == 'axial':
        # The key is dropped before forwarding the config to AxialAttention2D;
        # a truthy value is rejected instead of being silently ignored.
        if axial_layer_config.pop('double_global_attention', False):
          raise ValueError('Double_global_attention takes no effect in '
                           'AxialAttention2D.')
        self._attention = axial_layers.AxialAttention2D(
            strides=strides,
            filters=filters_list[1],
            name='attention',
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **axial_layer_config)
      elif attention_type == 'global':
        self._attention = axial_layers.GlobalAttention2D(
            strides=strides,
            filters=filters_list[1],
            name='attention',
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **axial_layer_config)
      else:
        raise ValueError(attention_type + ' is not supported.')

      # Here we apply a batch norm with gamma initialized at zero. This ensures
      # that at random initialization of the model, the skip connections
      # dominate all residual blocks. In this way, all the skip connections
      # construct an identity mapping that passes the gradients (without any
      # distortion from the randomly initialized blocks) to all residual blocks.
      # This trick helps training at early epochs.
      # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
      # https://arxiv.org/abs/1706.02677
      self._conv3_bn = convolutions.Conv2DSame(
          filters_list[2], 1, 'conv3_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay)
    elif len(filters_list) == 2:
      # Two consecutive convolutions: 3x3 and 3x3.
      self._conv1_bn_act = convolutions.Conv2DSame(
          filters_list[0], kernel_size, 'conv1_bn_act',
          strides=strides,
          atrous_rate=atrous_rate,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation=activation,
          use_switchable_atrous_conv=use_sac,
          use_global_context_in_sac=use_sac,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
      # Here we apply a batch norm with gamma initialized at zero. This ensures
      # that at random initialization of the model, the skip connections
      # dominate all residual blocks. In this way, all the skip connections
      # construct an identity mapping that passes the gradients (without any
      # distortion from the randomly initialized blocks) to all residual blocks.
      # This trick helps training at early epochs.
      # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
      # https://arxiv.org/abs/1706.02677
      self._conv2_bn = convolutions.Conv2DSame(
          filters_list[1], kernel_size, 'conv2_bn',
          strides=1,
          atrous_rate=basic_block_second_conv_atrous_rate,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          use_switchable_atrous_conv=use_sac,
          use_global_context_in_sac=use_sac,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
    else:
      raise ValueError('Expect filters_list to have length 2 or 3; got %d' %
                       len(filters_list))

    if self._use_squeeze_and_excite:
      self._squeeze_and_excite = squeeze_and_excite.SimplifiedSqueezeAndExcite(
          filters_list[-1])

  def build(self, input_shape_list):
    """Creates the shortcut projection when the channel count changes.

    Args:
      input_shape_list: A sequence of input shapes matching the tuple passed to
        call(). Only the first entry, the shape of input_tensor laid out as
        [batch, height, width, channels], is inspected.
    """
    input_tensor_shape = input_shape_list[0]
    self._shortcut = None
    # A 1x1 projection is only needed when the input channels differ from the
    # block's output channels; otherwise call() uses the (possibly strided)
    # identity as the shortcut.
    if input_tensor_shape[3] != self._filters_list[-1]:
      self._shortcut = convolutions.Conv2DSame(
          self._filters_list[-1], 1, 'shortcut',
          strides=self._strides,
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, inputs):
    """Performs a forward pass.

    We have to define drop_path_random_mask outside the layer call and pass it
    into the layer, because recompute_grad (gradient checkpointing) does not
    allow any randomness within the function call. In addition, recompute_grad
    only supports float tensors as inputs. For this reason, the training flag
    should be also passed as a float tensor. For the same reason, we cannot
    support passing drop_path_random_mask as None. Instead, we ask the users to
    pass only the first two tensors when drop path is not used.

    Args:
      inputs: A tuple of 2 or 3 tensors, containing
        input_tensor should be an input tensor of type tf.Tensor with shape
          [batch, height, width, channels].
        float_tensor_training should be a float tensor of 0.0 or 1.0, whether
          the model is in training mode.
        (optional) drop_path_random_mask is a drop path random mask of type
          tf.Tensor with shape [batch, 1, 1, 1].

    Returns:
      outputs: two tensors. The first tensor does not use the last activation
        function. The second tensor uses the activation. We return non-activated
        output to support MaX-DeepLab which uses non-activated feature for the
        stacked decoders.

    Raises:
      ValueError: If the length of inputs is not 2 or 3.
    """
    if len(inputs) not in (2, 3):
      raise ValueError('The length of inputs should be either 2 or 3.')

    # Unpack the inputs.
    input_tensor, float_tensor_training, drop_path_random_mask = (
        utils.pad_sequence_with_none(inputs, target_length=3))

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor outside this call, and now we cast it back to a boolean tensor.
    training = tf.cast(float_tensor_training, tf.bool)

    shortcut = input_tensor
    if self._shortcut is not None:
      shortcut = self._shortcut(shortcut, training=training)
    elif self._strides != 1:
      # No projection layer: subsample the identity path so that its spatial
      # size matches the strided residual branch.
      shortcut = shortcut[:, ::self._strides, ::self._strides, :]

    if len(self._filters_list) == 3:
      # Bottleneck / axial block: 1x1 -> (3x3 conv | attention) -> 1x1.
      x = self._conv1_bn_act(input_tensor, training=training)
      if (self._attention_type is None or
          self._attention_type.lower() == 'none'):
        x = self._conv2_bn_act(x, training=training)
      else:
        x = self._attention(x, training=training)
        x = self._activate_fn(x)
      x = self._conv3_bn(x, training=training)
    else:
      # __init__ accepts only two or three filters, so this is the basic block:
      # two consecutive 3x3 convolutions.
      x = self._conv1_bn_act(input_tensor, training=training)
      x = self._conv2_bn(x, training=training)

    if self._use_squeeze_and_excite:
      x = self._squeeze_and_excite(x)

    if drop_path_random_mask is not None:
      x = x * drop_path_random_mask
    x = x + shortcut
    return x, self._activate_fn(x)