# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements convolutional and attentional residual block groups."""

import math
import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import axial_blocks
from deeplab2.model.layers import drop_path
from deeplab2.model.layers import dual_path_transformer
from deeplab2.model.layers import positional_encodings
from deeplab2.model.layers import recompute_grad as recompute_grad_lib

# We will apply 10x larger learning rates to transformer layers. This global
# variable name will be accessed when we build the optimizers. This keyword is
# reserved and should not be part of the variable names in a classification
# pretrained backbone.
TRANSFORMER = 'transformer'


def _get_current_names(index):
  current_name = '_block{}'.format(index + 1)
  transformer_current_name = '_block{}_{}'.format(index + 1, TRANSFORMER)
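  # E.g., index 0 yields ('_block1', '_block1_transformer'). BlockGroup uses
  # these strings as attribute names for its per-block sub-layers.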
  return current_name, transformer_current_name


class BlockGroup(tf.keras.layers.Layer):
  """Applies a group of residual blocks with dual path transformer layers [1].

  An optional dual-path transformer layer is inserted after each residual block.
  The transformer layer performs memory2pixel attention, pixel2memory attention,
  and memory2memory self-attention, while the standard residual block applies
  the pixel2pixel axial-attention, global-attention, or spatial convolution.

  Reference:
  [1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
      CVPR 2021. https://arxiv.org/abs/2012.00759
        Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  def __init__(self,
               filters,
               num_blocks,
               name,
               original_resnet_stride,
               original_resnet_input_stride,
               output_stride=16,
               backbone_type='resnet_beta',
               positional_encoding_type=None,
               use_global_beyond_stride=0,
               use_axial_beyond_stride=16,
               use_transformer_beyond_stride=32,
               use_sac_beyond_stride=0,
               use_squeeze_and_excite=False,
               conv_use_recompute_grad=False,
               axial_use_recompute_grad=True,
               recompute_within_stride=0,
               transformer_use_recompute_grad=False,
               transformer_expansion=1,
               drop_path_keep_prob=0.8,
               drop_path_beyond_stride=16,
               drop_path_schedule='constant',
               activation='relu',
               attention_bottleneck_expansion=2,
               axial_layer_config=None,
               dual_path_transformer_layer_config=None,
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a BlockGroup layer.

    Args:
      filters: An integer, the base number of channels for this block group.
      num_blocks: An integer, the number of blocks for this block group.
      name: A string, the name of the block group.
      original_resnet_stride: An integer, the original resnet stride for this
        block, usually 1 or 2. The stride will be applied if
        original_resnet_input_stride is smaller than the desired output_stride.
        Otherwise, the stride will not be applied, and atrous convolution will
        be used after the first block.
      original_resnet_input_stride: An integer, the total input stride in the
        original resnet. For example, the total input stride for the last stage
        of the original resnet is 16, and the total output stride is 32. This
        stride differs from the true stride of the feature in that we might use
        atrous convolution to change both the input and output stride to, e.g.
        8, but its original resnet input stride remains the same. In this case,
        we also use the original resnet input stride to compute the atrous rate.
      output_stride: An integer, the desired output_stride for the ResNet.
      backbone_type: A string, the type of the backbone. Supports 'resnet',
        'resnet_beta', and 'wider_resnet'. The 'resnet' refers to the original
        resnet with a 7x7 convolutional stem. The 'resnet_beta' means a resnet
        but with an inception stem. The 'wider_resnet' is a wider variant of
        resnet with extensively used 3x3 convolutions.
      positional_encoding_type: A string, type of the positional encoding.
        Supports '2D', '1D', and None.
      use_global_beyond_stride: An integer, the stride beyond which we use
        global attention. Set to 0 if no global attention is desired. Defaults
        to 0, i.e. we do not use global attention.
      use_axial_beyond_stride: An integer, the stride beyond which we use axial
        attention. Note that use_global_beyond_stride has a higher priority,
        i.e. we use global attention if the stride is also beyond
        use_global_beyond_stride. Set to 0 if no axial attention is desired.
        Defaults to 16 as in MaX-DeepLab.
      use_transformer_beyond_stride: An integer, the stride beyond which we use
        a transformer layer. Set to 0 if no transformer is desired. Defaults to
        32 as in MaX-DeepLab-S.
      use_sac_beyond_stride: An integer. Use the Switchable Atrous Convolution
        (SAC) beyond the specified stride. For example, if
        `use_sac_beyond_stride` = 16, SAC will be applied to the network stage
        whose output stride >= 16 (i.e., 16 and 32). Set to 0 or -1 to disable
        it. Defaults to 0 as SAC is not used in MaX-DeepLab.
      use_squeeze_and_excite: A boolean, whether squeeze-and-excite (SE) is
        used. Defaults to False as SE is not used in MaX-DeepLab.
      conv_use_recompute_grad: A boolean, whether to use the gradient
        checkpointing trick for convolutional blocks. This trick reduces
        accelerator memory usage, but takes longer to compute gradients.
        Defaults to False since convolutional layers are memory efficient.
      axial_use_recompute_grad: A boolean, whether to use the gradient
        checkpointing trick for axial blocks. This trick reduces accelerator
        memory usage, but takes longer to compute gradients. Defaults to True
        since it saves memory for axial blocks.
      recompute_within_stride: An integer, the stride within which we use the
        gradient checkpointing trick. This trick reduces accelerator memory
        usage, but takes longer to compute gradients. Defaults to 0 (do not
        recompute any layer).
      transformer_use_recompute_grad: A boolean, whether to use the gradient
        checkpointing trick for dual-path transformer blocks. This trick reduces
        accelerator memory usage, but takes longer to compute gradients.
        Defaults to False.
      transformer_expansion: An integer, the expansion ratio for the transformer
        bottleneck.
      drop_path_keep_prob: A float, the keep probability for dropping path.
        Defaults to 0.8 as in MaX-DeepLab-S.
      drop_path_beyond_stride: An integer, the stride beyond which we apply drop
        path augmentation. Defaults to 16 as in MaX-DeepLab-S.
      drop_path_schedule: A string, the drop path schedule. Currently, we
        support 'constant': use the same drop path keep probability for all
        stages, and 'linear': linearly decrease the drop path keep probability
        from 1.0 at 0-th stage (or STEM) to `drop_path_keep_prob` at last stage.
      activation: A string, type of activation function to apply. Supports
        'relu', 'swish' (or 'silu'), 'gelu', 'approximated_gelu', and 'elu'.
      attention_bottleneck_expansion: An integer, the expansion ratio for
        axial attention blocks.
      axial_layer_config: A dict, an argument dictionary for the axial layer.
      dual_path_transformer_layer_config: A dict, an argument dictionary for the
        transformer.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If backbone_type is not one of 'resnet', 'resnet_beta', or
        'wider_resnet'.
      ValueError: If original_resnet_input_stride is not a power of 2.
      ValueError: If output_stride is not a power of 2.
    """
    if original_resnet_input_stride & (original_resnet_input_stride - 1):
      raise ValueError('original_resnet_input_stride is not a power of 2.')
    if output_stride & (output_stride - 1):
      raise ValueError('output_stride is not a power of 2.')

    super(BlockGroup, self).__init__(name=name)
    self._add_absolute_positional_encoding = None
    self._activation_fn = activations.get_activation(activation)
    self._num_blocks = num_blocks
    self._drop_path_keep_prob = []
    self._recompute_grad = []
    self._transformer_use_recompute_grad = transformer_use_recompute_grad
    if dual_path_transformer_layer_config is None:
      dual_path_transformer_layer_config = {}
    original_resnet_current_stride = original_resnet_input_stride

    use_sac = (original_resnet_input_stride * original_resnet_stride >=
               use_sac_beyond_stride > 0)

    recompute_grad = (original_resnet_input_stride * original_resnet_stride <=
                      recompute_within_stride)

    for index in range(num_blocks):
      current_name, transformer_current_name = _get_current_names(index)

      # Compute the current strides. If there is a stride for this block group,
      # we do it in the first residual block.
      if index == 0 and original_resnet_input_stride < output_stride:
        current_strides = original_resnet_stride
      else:
        current_strides = 1

      # Compute the current atrous rate.
      if original_resnet_current_stride > output_stride:
        atrous_rate = original_resnet_current_stride // output_stride
      else:
        atrous_rate = 1
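      # E.g., original_resnet_current_stride = 32 with output_stride = 16 gives
      # an atrous rate of 2, compensating for the stride that was not applied.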

      # Compute the atrous rate for the second conv in the first basic block.
      if (index == 0 and original_resnet_input_stride * original_resnet_stride >
          output_stride):
        basic_block_second_conv_atrous_rate = (
            original_resnet_input_stride * original_resnet_stride //
            output_stride)
      else:
        basic_block_second_conv_atrous_rate = atrous_rate

      # Compute the current drop_path_keep_prob.
      current_stage = math.log2(original_resnet_current_stride) - 1
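      # E.g., original_resnet_current_stride values 4, 8, 16, and 32 correspond
      # to stages 1, 2, 3, and 4, with the stem counted as stage 0.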
      if original_resnet_current_stride >= drop_path_beyond_stride:
        current_drop_path_keep_prob = drop_path.get_drop_path_keep_prob(
            drop_path_keep_prob, drop_path_schedule,
            current_stage=int(round(current_stage)),
            num_stages=4)
      else:
        current_drop_path_keep_prob = 1.0

      # Compute which block_fn to use for this residual block.
      if original_resnet_current_stride >= use_global_beyond_stride > 0:
        attention_type = 'global'
        recompute_grad = axial_use_recompute_grad or recompute_grad
        filters_list = [filters * attention_bottleneck_expansion,
                        filters,
                        filters * 4]
      elif original_resnet_current_stride >= use_axial_beyond_stride > 0:
        attention_type = 'axial'
        recompute_grad = axial_use_recompute_grad or recompute_grad
        filters_list = [filters * attention_bottleneck_expansion,
                        filters,
                        filters * 4]
      elif backbone_type == 'resnet' or backbone_type == 'resnet_beta':
        attention_type = None
        recompute_grad = conv_use_recompute_grad or recompute_grad
        filters_list = [filters,
                        filters,
                        filters * 4]
      elif backbone_type == 'wider_resnet':
        if original_resnet_input_stride * original_resnet_stride < 32:
          # Wider-ResNet uses conv basic blocks except in the last stage.
          attention_type = None
          recompute_grad = conv_use_recompute_grad or recompute_grad
          filters_list = [filters * 4,
                          filters * 4]
        else:
          # Wider-ResNet uses an expanded bottleneck block in the last stage.
          attention_type = None
          recompute_grad = conv_use_recompute_grad or recompute_grad
          filters_list = [filters,
                          filters * 2,
                          filters * 4]
      else:
        raise ValueError(backbone_type + ' is not supported.')

      self._drop_path_keep_prob.append(current_drop_path_keep_prob)
      # Apply the residual block.
      # The inputs to block_fn should be activated features.
      block_fn = axial_blocks.AxialBlock(
          filters_list,
          kernel_size=3,
          strides=current_strides,
          atrous_rate=atrous_rate,
          use_squeeze_and_excite=use_squeeze_and_excite,
          use_sac=use_sac,
          bn_layer=bn_layer,
          activation=activation,
          name=current_name[1:],
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          basic_block_second_conv_atrous_rate=(
              basic_block_second_conv_atrous_rate),
          attention_type=attention_type,
          axial_layer_config=axial_layer_config)
      self._recompute_grad.append(recompute_grad)
      utils.safe_setattr(self, current_name, block_fn)

      # Modify the original_resnet_stride according to the strides.
      if index == 0 and original_resnet_stride > 1:
        original_resnet_current_stride *= original_resnet_stride
        # Add absolute positional encoding if we will apply global attention
        # beyond this stride.
        if original_resnet_current_stride == use_global_beyond_stride > 0:
          self._add_absolute_positional_encoding = (
              positional_encodings.AddAbsolutePositionalEncoding(
                  'add_absolute_positional_encoding',
                  positional_encoding_type, bn_layer, conv_kernel_weight_decay))
      if original_resnet_current_stride >= use_transformer_beyond_stride > 0:
        # Apply a dual-path transformer.
        transformer_block_fn = dual_path_transformer.DualPathTransformerLayer(
            name=transformer_current_name[1:],
            filters=int(128 * transformer_expansion),
            activation=activation,
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **dual_path_transformer_layer_config)
        utils.safe_setattr(self, transformer_current_name, transformer_block_fn)
      else:
        utils.safe_setattr(self, transformer_current_name, None)
    # Avoid using recompute_grad for the first call that builds the sub-layers.
    # Otherwise, recompute_grad will not track newly built model parameters.
    self._first_building_call = True

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: A tuple of two tensors. The first tensor is a pixel_space_input
        with shape [batch, height, width, pixel_channels]. The second tensor is
        memory_space_input with shape [batch, length, memory_channels]. The
        memory_space_input is used only if a transformer layer is present;
        otherwise, it is returned unmodified.
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      output: An output [batch, height, width, filters * 4] tensor.
      activated_output: An activated output [batch, height, width, filters * 4]
        tensor.
      memory_space_output: A memory space output [batch, length,
        memory_channels] tensor.
    """
    # The pixel space inputs are activated features.
    activated_features, memory_space_output = inputs

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor and cast it back after we go through the recompute_grad wrap.
    float_tensor_training = tf.cast(training, tf.float32)

    for index in range(self._num_blocks):
      current_name, transformer_current_name = _get_current_names(index)
      block_fn_no_recompute = getattr(
          self, current_name)
      transformer_block_fn_no_recompute = getattr(
          self, transformer_current_name)
      current_drop_path_keep_prob = self._drop_path_keep_prob[index]

      # Wrap the layer if we want to recompute it in the backward pass.
      if (self._recompute_grad[index] and training):
        # The seed is not actually used since we do not have any random
        # operation in the recomputed function. The purpose of the provided seed
        # is to prevent recompute_grad from generating a new seed variable which
        # is not compatible with model exporting.
        block_fn = recompute_grad_lib.recompute_grad(
            block_fn_no_recompute, seed=tf.constant(0, tf.int32))
      else:
        block_fn = block_fn_no_recompute

      # The inputs to block_fn should be activated features.
      block_fn_inputs = [activated_features, float_tensor_training]
      # We have to define drop_path_masks outside the layer call and pass it
      # into the layer, because tf.recompute_grad (gradient checkpointing) does
      # not allow any randomness within the function call. In addition,
      # recompute_grad functions can only take Tensors as inputs, so we do not
      # pass the drop_path_random_mask (when it is None) into block_fn.
      if current_drop_path_keep_prob < 1.0 and training:
        drop_path_random_mask = drop_path.generate_drop_path_random_mask(
            activated_features, current_drop_path_keep_prob)

        block_fn_inputs.append(drop_path_random_mask)

      # Build the sub-layers when the block_fn is called for the first time.
      # Otherwise, recompute_grad will not track newly built model parameters.
      if self._first_building_call:
        _ = block_fn_no_recompute(tuple(block_fn_inputs))
      # Apply the residual block.
      features, activated_features = block_fn(tuple(block_fn_inputs))

      if index == 0 and self._add_absolute_positional_encoding is not None:
        features = self._add_absolute_positional_encoding(features,
                                                          training=training)
        activated_features = self._activation_fn(features)

      if transformer_block_fn_no_recompute is not None:
        # Reshape pixel space features from 4D to 3D.
        _, height, width, channels = features.get_shape().as_list()
        features = tf.reshape(
            features, [-1, height * width, channels])

        # Wrap the layer if we want to recompute it in the backward pass.
        if (self._transformer_use_recompute_grad and training):
          # The seed is not actually used since we do not have any random
          # operation in the recomputed function. The purpose of the provided
          # seed is to prevent recompute_grad from generating a new seed
          # variable which is not compatible with model exporting.
          transformer_block_fn = recompute_grad_lib.recompute_grad(
              transformer_block_fn_no_recompute, seed=tf.constant(0, tf.int32))
        else:
          transformer_block_fn = transformer_block_fn_no_recompute

        transformer_block_fn_input_list = [
            features, memory_space_output, float_tensor_training]
        # We have to define drop_path_masks outside the layer call and pass it
        # into the layer, because recompute_grad (gradient checkpointing) does
        # not allow any randomness within the function call. In addition,
        # recompute_grad functions can only take Tensors as inputs, so we do not
        # pass the drop_path_masks (when they are None) into
        # transformer_block_fn.
        if current_drop_path_keep_prob < 1.0 and training:
          # Drop path random mask for pixel space attention.
          pixel_space_drop_path_mask = drop_path.generate_drop_path_random_mask(
              memory_space_output, current_drop_path_keep_prob)
          # Drop path random mask for memory space attention.
          memory_space_attention_drop_path_mask = (
              drop_path.generate_drop_path_random_mask(
                  memory_space_output, current_drop_path_keep_prob))
          # Drop path random mask for memory space feed-forward network.
          memory_space_feed_forward_network_drop_path_mask = (
              drop_path.generate_drop_path_random_mask(
                  memory_space_output, current_drop_path_keep_prob))
          transformer_block_fn_input_list += [
              pixel_space_drop_path_mask,
              memory_space_attention_drop_path_mask,
              memory_space_feed_forward_network_drop_path_mask]

        # Build the sub-layers when the transformer_block_fn is called for the
        # first time. Otherwise, recompute_grad will not track newly built model
        # parameters.
        if self._first_building_call:
          _ = transformer_block_fn_no_recompute(
              tuple(transformer_block_fn_input_list))
        # Apply a dual-path transformer.
        features, activated_features, memory_space_output = (
            transformer_block_fn(tuple(transformer_block_fn_input_list)))

        # Reshape pixel space features back to 4D.
        features = tf.reshape(features, [-1, height, width, channels])
        activated_features = tf.reshape(activated_features,
                                        [-1, height, width, channels])
    # Now the first call has finished and the sub-layers have been built.
    self._first_building_call = False
    # We also return the non-activated output so that the function is compatible
    # with a decoder that takes a non-activated tensor as input.
    return features, activated_features, memory_space_output