# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements Axial-Blocks proposed in Axial-DeepLab [1].

Axial-Blocks are based on residual bottleneck blocks, but with the 3x3
convolution replaced with two axial-attention layers, one on the height-axis,
followed by the other on the width-axis.

[1] Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation,
    ECCV 2020 Spotlight.
      Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
      Liang-Chieh Chen.
"""
import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import axial_layers
from deeplab2.model.layers import convolutions
from deeplab2.model.layers import squeeze_and_excite


class AxialBlock(tf.keras.layers.Layer):
  """An AxialBlock as a building block for an Axial-ResNet model.

  We implement the Axial-Block proposed in [1] in a general way that also
  includes convolutional residual blocks, such as the basic block and the
  bottleneck block (w/ and w/o Switchable Atrous Convolution).

  A basic block consists of two 3x3 convolutions and a residual connection. It
  is the main building block for wide-resnet variants.

  A bottleneck block consists of consecutive 1x1, 3x3, 1x1 convolutions and a
  residual connection. It is the main building block for standard resnet
  variants.

  An axial block consists of a 1x1 input convolution, a self-attention layer
  (either axial-attention or global attention), a 1x1 output convolution, and a
  residual connection. It is the main building block for axial-resnet variants.

  Note: We apply the striding in the first spatial operation (i.e. 3x3
  convolution or self-attention layer).
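
  A minimal usage sketch follows; the feature shape, filters, and settings
  below are illustrative only, not values prescribed by this layer:

    block = AxialBlock(filters_list=[64, 64, 256],
                       strides=1,
                       attention_type='axial')
    features = tf.zeros([2, 65, 65, 128])
    # The layer is called with a tuple; the second element is the training
    # flag as a float tensor (0.0 casts to training=False). See call().
    unactivated, activated = block((features, tf.constant(0.0)))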
  """

  def __init__(self,
               filters_list,
               kernel_size=3,
               strides=1,
               atrous_rate=1,
               use_squeeze_and_excite=False,
               use_sac=False,
               bn_layer=tf.keras.layers.BatchNormalization,
               activation='relu',
               name=None,
               conv_kernel_weight_decay=0.0,
               basic_block_second_conv_atrous_rate=None,
               attention_type=None,
               axial_layer_config=None):
    """Initializes an AxialBlock.

    Args:
      filters_list: A list of filter numbers in the residual block. We currently
        support filters_list with two or three elements. Two elements specify
        the filters for two consecutive 3x3 convolutions, while three elements
        specify the filters for three convolutions (1x1, 3x3, and 1x1).
      kernel_size: The size of the convolution kernels (default: 3).
      strides: The strides of the block (default: 1).
      atrous_rate: The atrous rate of the 3x3 convolutions (default: 1). If this
        residual block is a basic block, it is recommended to also specify a
        correct basic_block_second_conv_atrous_rate for the second 3x3
        convolution. Otherwise, the second convolution will use this atrous rate
        as well, which might cause atrous inconsistency across different output
        strides, as tested in
        axial_block_groups_test.test_atrous_consistency_basic_block.
      use_squeeze_and_excite: A boolean flag indicating whether
        squeeze-and-excite (SE) is used.
      use_sac: A boolean, using the Switchable Atrous Convolution (SAC) or not.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      activation: A string specifying the activation function to apply.
      name: A string specifying the name of the layer (default: None).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
      basic_block_second_conv_atrous_rate: An integer, the atrous rate for the
        second convolution of the basic block. This is necessary to ensure
        atrous consistency across different output_strides. Defaults to
        atrous_rate.
      attention_type: A string, the type of attention to apply. Supported values
        are 'axial' and 'global'.
      axial_layer_config: A dict, an argument dictionary for the axial layer.

    Raises:
      ValueError: If filters_list does not have two or three elements.
      ValueError: If attention_type is not supported.
      ValueError: If double_global_attention is True in axial_layer_config.
    """
    super(AxialBlock, self).__init__(name=name)

    self._filters_list = filters_list
    self._strides = strides
    self._use_squeeze_and_excite = use_squeeze_and_excite
    self._bn_layer = bn_layer
    self._activate_fn = activations.get_activation(activation)
    self._attention_type = attention_type

    if axial_layer_config is None:
      axial_layer_config = {}

    if basic_block_second_conv_atrous_rate is None:
      basic_block_second_conv_atrous_rate = atrous_rate

    if len(filters_list) == 3:
      # Three consecutive convolutions: 1x1, 3x3, and 1x1.
      self._conv1_bn_act = convolutions.Conv2DSame(
          filters_list[0], 1, 'conv1_bn_act',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation=activation,
          conv_kernel_weight_decay=conv_kernel_weight_decay)

      if attention_type is None or attention_type.lower() == 'none':
        self._conv2_bn_act = convolutions.Conv2DSame(
            filters_list[1], kernel_size, 'conv2_bn_act',
            strides=strides,
            atrous_rate=atrous_rate,
            use_bias=False,
            use_bn=True,
            bn_layer=bn_layer,
            activation=activation,
            use_switchable_atrous_conv=use_sac,
            # We default to use global context in SAC if use_sac is True. This
            # setting is experimentally found effective.
            use_global_context_in_sac=use_sac,
            conv_kernel_weight_decay=conv_kernel_weight_decay)
      elif attention_type == 'axial':
        if 'double_global_attention' in axial_layer_config:
          if axial_layer_config['double_global_attention']:
            raise ValueError('double_global_attention has no effect in '
                             'AxialAttention2D.')
          del axial_layer_config['double_global_attention']
        self._attention = axial_layers.AxialAttention2D(
            strides=strides,
            filters=filters_list[1],
            name='attention',
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **axial_layer_config)
      elif attention_type == 'global':
        self._attention = axial_layers.GlobalAttention2D(
            strides=strides,
            filters=filters_list[1],
            name='attention',
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **axial_layer_config)
      else:
        raise ValueError(attention_type + ' is not supported.')

      # Here we apply a batch norm with gamma initialized at zero. This ensures
      # that at random initialization of the model, the skip connections
      # dominate all residual blocks. In this way, all the skip connections
      # construct an identity mapping that passes the gradients (without any
      # distortion from the randomly initialized blocks) to all residual blocks.
      # This trick helps training at early epochs.
      # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
      # https://arxiv.org/abs/1706.02677
      self._conv3_bn = convolutions.Conv2DSame(
          filters_list[2], 1, 'conv3_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay)
    elif len(filters_list) == 2:
      # Two consecutive convolutions: 3x3 and 3x3.
      self._conv1_bn_act = convolutions.Conv2DSame(
          filters_list[0], kernel_size, 'conv1_bn_act',
          strides=strides,
          atrous_rate=atrous_rate,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation=activation,
          use_switchable_atrous_conv=use_sac,
          use_global_context_in_sac=use_sac,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
      # Here we apply a batch norm with gamma initialized at zero. This ensures
      # that at random initialization of the model, the skip connections
      # dominate all residual blocks. In this way, all the skip connections
      # construct an identity mapping that passes the gradients (without any
      # distortion from the randomly initialized blocks) to all residual blocks.
      # This trick helps training at early epochs.
      # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
      # https://arxiv.org/abs/1706.02677
      self._conv2_bn = convolutions.Conv2DSame(
          filters_list[1], kernel_size, 'conv2_bn',
          strides=1,
          atrous_rate=basic_block_second_conv_atrous_rate,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          use_switchable_atrous_conv=use_sac,
          use_global_context_in_sac=use_sac,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
    else:
      raise ValueError('Expect filters_list to have length 2 or 3; got %d' %
                       len(filters_list))

    if self._use_squeeze_and_excite:
      self._squeeze_and_excite = squeeze_and_excite.SimplifiedSqueezeAndExcite(
          filters_list[-1])
    self._conv_kernel_weight_decay = conv_kernel_weight_decay

  def build(self, input_shape_list):
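    # This layer is called with a tuple of tensors, so `input_shape_list`
    # contains one shape per tuple element; the first is the feature map shape.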
    input_tensor_shape = input_shape_list[0]
    self._shortcut = None
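    # A 1x1 projection shortcut is only needed when the number of input
    # channels differs from the number of output channels of this block.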
    if input_tensor_shape[3] != self._filters_list[-1]:
      self._shortcut = convolutions.Conv2DSame(
          self._filters_list[-1], 1, 'shortcut',
          strides=self._strides,
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, inputs):
    """Performs a forward pass.

    We have to define drop_path_random_mask outside the layer call and pass it
    into the layer, because recompute_grad (gradient checkpointing) does not
    allow any randomness within the function call. In addition, recompute_grad
    only supports float tensors as inputs. For this reason, the training flag
    should also be passed as a float tensor. For the same reason, we cannot
    support passing drop_path_random_mask as None. Instead, users should pass
    only the first two tensors when drop path is not used.

    Args:
      inputs: A tuple of 2 or 3 tensors, containing
        input_tensor: An input tensor of type tf.Tensor with shape
          [batch, height, width, channels].
        float_tensor_training: A float tensor of 0.0 or 1.0, indicating whether
          the model is in training mode.
        drop_path_random_mask: An optional drop path random mask of type
          tf.Tensor with shape [batch, 1, 1, 1].

    Returns:
      outputs: Two tensors. The first tensor is the output before the last
        activation function; the second tensor is the activated output. We
        return the non-activated output to support MaX-DeepLab, which uses
        non-activated features for the stacked decoders.

    Raises:
      ValueError: If the length of inputs is not 2 or 3.
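
    Example (a sketch; `block` is an AxialBlock instance, and the shapes and
    values below are illustrative):

      features = tf.zeros([2, 33, 33, 256])
      # Without drop path, pass only (input_tensor, float_tensor_training).
      pre_act, post_act = block((features, tf.constant(1.0)))
      # With drop path, additionally pass a [batch, 1, 1, 1] random mask.
      drop_path_random_mask = tf.ones([2, 1, 1, 1])
      pre_act, post_act = block(
          (features, tf.constant(1.0), drop_path_random_mask))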
    """
    if len(inputs) not in (2, 3):
      raise ValueError('The length of inputs should be either 2 or 3.')

    # Unpack the inputs.
    input_tensor, float_tensor_training, drop_path_random_mask = (
        utils.pad_sequence_with_none(inputs, target_length=3))

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor outside this call, and now we cast it back to a boolean tensor.
    training = tf.cast(float_tensor_training, tf.bool)

    shortcut = input_tensor
    if self._shortcut is not None:
      shortcut = self._shortcut(shortcut, training=training)
    elif self._strides != 1:
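      # No channel projection is needed, so downsample the identity shortcut by
      # strided slicing to match the spatial resolution of the block output.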
      shortcut = shortcut[:, ::self._strides, ::self._strides, :]

    if len(self._filters_list) == 3:
      x = self._conv1_bn_act(input_tensor, training=training)
      if (self._attention_type is None or
          self._attention_type.lower() == 'none'):
        x = self._conv2_bn_act(x, training=training)
      else:
        x = self._attention(x, training=training)
        x = self._activate_fn(x)
      x = self._conv3_bn(x, training=training)
    elif len(self._filters_list) == 2:
      x = self._conv1_bn_act(input_tensor, training=training)
      x = self._conv2_bn(x, training=training)

    if self._use_squeeze_and_excite:
      x = self._squeeze_and_excite(x)

    if drop_path_random_mask is not None:
      x = x * drop_path_random_mask
    x = x + shortcut
    return x, self._activate_fn(x)