Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021 The Deeplab2 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Implements Axial-ResNets proposed in Axial-DeepLab [1]. | |
[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation, | |
ECCV 2020 Spotlight. | |
Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille, | |
Liang-Chieh Chen. | |
""" | |
import tensorflow as tf | |
from deeplab2.model import utils | |
from deeplab2.model.layers import activations | |
from deeplab2.model.layers import axial_block_groups | |
from deeplab2.model.layers import convolutions | |
from deeplab2.model.layers import resized_fuse | |
from deeplab2.model.layers import stems | |
# Add a suffix in layer names that indicate if the current layer is a part of | |
# the backbone or an extra layer, i.e. if the current layer will be pretrained | |
# or not. This name will be used when we apply 10x larger learning rates for | |
# extra parameters that have not been pretrained, in panoptic segmentation. | |
# This keyword is reserved and should not be a part of the variable names in a | |
# classification pretrained backbone. | |
EXTRA = 'extra' | |
# Similarly, we will apply 10x larger learning rates on the memory feature. | |
# This global variable name will be accessed when we build the optimizers. This | |
# keyword is reserved and should not be a part of the variable names in a | |
# classification pretrained backbone. | |
MEMORY_FEATURE = 'memory_feature' | |
class AxialResNet(tf.keras.Model): | |
"""An Axial-ResNet model as proposed in Axial-DeepLab [1] and MaX-DeepLab [2]. | |
An Axial-ResNet [1] replaces 3x3 convolutions in a Resnet by axial-attention | |
layers. A dual-path transformer [2] and a stacked decoder [2] can be used | |
optionally. In addition, this class supports scaling models with SWideRNet [3] | |
and augmenting convolutions with Switchable Atrous Convolution [4]. | |
Reference: | |
[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation, | |
ECCV 2020 Spotlight. https://arxiv.org/abs/2003.07853 | |
Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille, | |
Liang-Chieh Chen. | |
[2] MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers", | |
CVPR 2021. https://arxiv.org/abs/2012.00759 | |
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen. | |
[3] Scaling Wide Residual Networks for Panoptic Segmentation, | |
https://arxiv.org/abs/2011.11675 | |
Liang-Chieh Chen, Huiyu Wang, Siyuan Qiao. | |
[4] DetectoRS: Detecting Objects with Recursive Feature Pyramid and Switchable | |
Atrous Convolution, CVPR 2021. https://arxiv.org/abs/2006.02334 | |
Siyuan Qiao, Liang-Chieh Chen, Alan Yuille. | |
""" | |
def __init__(self, | |
name, | |
num_blocks=(3, 4, 6, 3), | |
backbone_layer_multiplier=1.0, | |
width_multiplier=1.0, | |
stem_width_multiplier=1.0, | |
output_stride=16, | |
classification_mode=False, | |
backbone_type='resnet_beta', | |
use_axial_beyond_stride=16, | |
backbone_use_transformer_beyond_stride=32, | |
extra_decoder_use_transformer_beyond_stride=32, | |
backbone_decoder_num_stacks=0, | |
backbone_decoder_blocks_per_stage=1, | |
extra_decoder_num_stacks=0, | |
extra_decoder_blocks_per_stage=1, | |
max_num_mask_slots=128, | |
num_mask_slots=128, | |
memory_channels=256, | |
base_transformer_expansion=1.0, | |
global_feed_forward_network_channels=256, | |
high_resolution_output_stride=4, | |
activation='relu', | |
block_group_config=None, | |
bn_layer=tf.keras.layers.BatchNormalization, | |
conv_kernel_weight_decay=0.0): | |
"""Initializes an AxialResNet model. | |
Args: | |
name: A string, the name of the model. | |
num_blocks: A list of 4 integers. It denotes the number of blocks to | |
include in the last 4 stages or block groups. Each group consists of | |
blocks that output features of the same resolution. Defaults to (3, 4, | |
6, 3) as in MaX-DeepLab-S. | |
backbone_layer_multiplier: A float, layer_multiplier for the backbone, | |
excluding the STEM. This flag controls the number of layers. Defaults to | |
1.0 as in MaX-DeepLab-S. | |
width_multiplier: A float, the channel multiplier for the block groups. | |
Defaults to 1.0 as in MaX-DeepLab-S. | |
stem_width_multiplier: A float, the channel multiplier for stem | |
convolutions. Defaults to 1.0 as in MaX-DeepLab-S. | |
output_stride: An integer, the maximum ratio of input to output spatial | |
resolution. Defaults to 16 as in MaX-DeepLab-S. | |
classification_mode: A boolean, whether to perform in a classification | |
mode. If it is True, this function directly returns backbone feature | |
endpoints. Note that these feature endpoints can also be used directly | |
for Panoptic-DeepLab or Motion-DeepLab. If it is False, this function | |
builds MaX-DeepLab extra decoder layers and extra transformer layers. | |
Defaults to False as in MaX-DeepLab. | |
backbone_type: A string, the type of backbone. Supports 'resnet', | |
'resnet_beta', and 'wider_resnet'. It controls both the stem type and | |
the residual block type. Defaults to 'resnet_beta' as in MaX-DeepLab-S. | |
use_axial_beyond_stride: An integer, the stride beyond which we use axial | |
attention. Set to 0 if no axial attention is desired. Defaults to 16 as | |
in MaX-DeepLab. | |
backbone_use_transformer_beyond_stride: An integer, the stride beyond | |
which we use a memory path transformer block on top of a regular pixel | |
path block, in the backbone. Set to 0 if no transformer block is desired | |
in the backbone. Defaults to 32 as in MaX-DeepLab-S. | |
extra_decoder_use_transformer_beyond_stride: An integer, the stride beyond | |
which we use a memory path transformer block on top of a regular pixel | |
path block, in the extra decoder stages. Set to 0 if no transformer | |
block is desired in the extra decoder stages. Defaults to 32 as in | |
MaX-DeepLab-S. | |
backbone_decoder_num_stacks: An integer, the number of decoder stacks | |
(introduced in MaX-DeepLab) that we use in the backbone. The stacked | |
decoders are applied in a stacked hour-glass style. Defaults to 0 as in | |
MaX-DeepLab-S. | |
backbone_decoder_blocks_per_stage: An integer, the number of consecutive | |
residual blocks to apply for each decoder stage, in the backbone. | |
Defaults to 1 as in MaX-DeepLab-S. | |
extra_decoder_num_stacks: An integer, the number of decoder stacks | |
(introduced in MaX-DeepLab) that we use in the extra decoder layers. It | |
is different from backbone_decoder_blocks_per_stage in that the extra | |
decoder stacks will be trained from scratch on segmentation tasks, | |
instead of pretrained on ImageNet classification. Defaults to 0 as in | |
MaX-DeepLab-S. | |
extra_decoder_blocks_per_stage: An integer, the number of consecutive | |
residual blocks to apply for each decoder stage, in the extra decoder | |
stages. Defaults to 1 as in MaX-DeepLab-S. | |
max_num_mask_slots: An integer, the maximum possible number of mask slots | |
that will be used. This will be used in a pretraining-finetuning use | |
case with different num_mask_slots: We can set max_num_mask_slots to the | |
maximum possible num_mask_slots, and then the saved checkpoint can be | |
loaded for finetuning with a different num_mask_slots. Defaults to 128 | |
as in MaX-DeepLab. | |
num_mask_slots: An integer, the number of mask slots that will be used. | |
Defaults to 128 as in MaX-DeepLab-S. | |
memory_channels: An integer, the number of channels for the whole memory | |
path. Defaults to 256 as in MaX-DeepLab-S. | |
base_transformer_expansion: A float, the base width expansion rate for | |
transformer layers. Defaults to 1.0 as in MaX-DeepLab-S. | |
global_feed_forward_network_channels: An integer, the number of channels | |
in the final global feed forward network, i.e. the mask feature head and | |
the mask class head. Defaults to 256 as in MaX-DeepLab-S. | |
high_resolution_output_stride: An integer, the final decoding output | |
stride. Defaults to 4 as in MaX-DeepLab-S. | |
activation: A string, type of activation function to apply. Support | |
'relu', 'swish' (or 'silu'), 'gelu', 'approximated_gelu', and 'elu'. | |
block_group_config: An argument dictionary that will be passed to | |
block_group. | |
bn_layer: An optional tf.keras.layers.Layer that computes the | |
normalization (default: tf.keras.layers.BatchNormalization). | |
conv_kernel_weight_decay: A float, the weight decay for convolution | |
kernels. | |
Raises: | |
ValueError: If backbone_type is not one of 'resnet', 'resnet_beta', or | |
'wider_resnet'. | |
ValueError: If extra_decoder_blocks_per_stage is not greater than zero. | |
""" | |
super(AxialResNet, self).__init__(name=name) | |
if extra_decoder_blocks_per_stage <= 0: | |
raise ValueError( | |
'Extra_decoder_blocks_per_stage should be great than zero.') | |
if block_group_config is None: | |
block_group_config = {} | |
# Compute parameter lists for block_groups. We consider five stages so that | |
# it is general enough to cover fully axial resnets and wider resnets. | |
total_strides_list = [1, 2, 4, 8, 16] | |
# Append 3 blocks for the first stage of fully axial resnets and wider | |
# resnets. | |
num_blocks_list = [3] + utils.scale_int_list(list(num_blocks), | |
backbone_layer_multiplier) | |
strides_list = [2] * 5 | |
# Expand the transformer and the block filters with the stride. | |
transformer_expansions_list = [] | |
filters_list = [] | |
for index, stride in enumerate(total_strides_list): | |
# Reduce the number of channels when we apply transformer to low level | |
# features (stride = 2, 4, or 8). The base_transformer_expansion is used | |
# for stride = 16, i.e. the standard output_stride for MaX-DeepLab-S. | |
transformer_expansions_list.append(base_transformer_expansion * stride / | |
16.0) | |
# Compute the base number of filters in each stage. For example, the last | |
# stage of ResNet50 has an input stride of 16, then we compute the base | |
# number of filters for a bottleneck block as 16 * 32 = 512, which is the | |
# number of filters for the 3x3 convolution in those blocks. | |
if backbone_type == 'wider_resnet' and index == 0: | |
# SWideRNet variants use stem_width_multiplier for the first block. | |
filters_list.append(int(round(stride * 32 * stem_width_multiplier))) | |
else: | |
filters_list.append(int(round(stride * 32 * width_multiplier))) | |
self._num_mask_slots = None | |
# Initialize memory_feature only when a transformer block is used. | |
self._use_memory_feature = (backbone_use_transformer_beyond_stride or | |
(extra_decoder_use_transformer_beyond_stride and | |
(not classification_mode))) | |
if self._use_memory_feature: | |
self._memory_feature_shape = (1, max_num_mask_slots, memory_channels) | |
self._memory_feature_initializer = ( | |
tf.keras.initializers.TruncatedNormal(stddev=1.0)) | |
self._memory_feature_regularizer = tf.keras.regularizers.l2( | |
conv_kernel_weight_decay) | |
if num_mask_slots: | |
self._num_mask_slots = num_mask_slots | |
# Use a convolutional stem except fully axial cases. | |
stem_channels = int(round(64 * stem_width_multiplier)) | |
self._activation_fn = activations.get_activation(activation) | |
if use_axial_beyond_stride == 1: | |
self._stem = tf.identity | |
first_block_index = 0 | |
elif backbone_type.lower() == 'wider_resnet': | |
self._stem = convolutions.Conv2DSame( | |
output_channels=stem_channels, | |
kernel_size=3, | |
name='stem', | |
strides=2, | |
use_bias=False, | |
use_bn=True, | |
bn_layer=bn_layer, | |
activation='none', | |
conv_kernel_weight_decay=conv_kernel_weight_decay) | |
# Wider ResNet has five residual block stages, so we start from index 0. | |
first_block_index = 0 | |
# Since we have applied the first strided convolution here, we do not use | |
# a stride for the first stage (which will operate on stride 2). | |
strides_list[0] = 1 | |
total_strides_list[0] = 2 | |
elif backbone_type.lower() == 'resnet_beta': | |
self._stem = stems.InceptionSTEM( | |
bn_layer=bn_layer, | |
width_multiplier=stem_width_multiplier, | |
conv_kernel_weight_decay=conv_kernel_weight_decay, | |
activation=activation) | |
first_block_index = 1 | |
elif backbone_type.lower() == 'resnet': | |
self._stem = convolutions.Conv2DSame( | |
output_channels=stem_channels, | |
kernel_size=7, | |
name='stem', | |
strides=2, | |
use_bias=False, | |
use_bn=True, | |
bn_layer=bn_layer, | |
activation='none', | |
conv_kernel_weight_decay=conv_kernel_weight_decay) | |
first_block_index = 1 | |
else: | |
raise ValueError(backbone_type + ' is not supported.') | |
self._first_block_index = first_block_index | |
# Apply standard ResNet block groups. We use first_block_index to | |
# distinguish models with 4 stages and those with 5 stages. | |
for index in range(first_block_index, 5): | |
current_name = '_stage{}'.format(index + 1) | |
utils.safe_setattr(self, current_name, axial_block_groups.BlockGroup( | |
filters=filters_list[index], | |
num_blocks=num_blocks_list[index], | |
name=utils.get_layer_name(current_name), | |
original_resnet_stride=strides_list[index], | |
original_resnet_input_stride=total_strides_list[index], | |
output_stride=output_stride, | |
backbone_type=backbone_type, | |
use_axial_beyond_stride=use_axial_beyond_stride, | |
use_transformer_beyond_stride=( | |
backbone_use_transformer_beyond_stride), | |
transformer_expansion=transformer_expansions_list[index], | |
activation=activation, | |
bn_layer=bn_layer, | |
conv_kernel_weight_decay=conv_kernel_weight_decay, | |
**block_group_config)) | |
self._backbone_decoder_num_stacks = backbone_decoder_num_stacks | |
self._classification_mode = classification_mode | |
self._extra_decoder_num_stacks = extra_decoder_num_stacks | |
self._output_stride = output_stride | |
self._high_resolution_output_stride = high_resolution_output_stride | |
self._width_multiplier = width_multiplier | |
self._activation = activation | |
self._bn_layer = bn_layer | |
self._conv_kernel_weight_decay = conv_kernel_weight_decay | |
self._backbone_use_transformer_beyond_stride = ( | |
backbone_use_transformer_beyond_stride) | |
self._extra_decoder_use_transformer_beyond_stride = ( | |
extra_decoder_use_transformer_beyond_stride) | |
# Keep track of the current stack so that we know when to stop. | |
current_stack = 0 | |
# Track whether we are building the backbone. This will affect the backbone | |
# related arguments, local learning rate, and so on. | |
current_is_backbone = True | |
if backbone_decoder_num_stacks == 0: | |
# No stacked decoder is used in the backbone, so we have finished building | |
# the backbone. We either return the classification endpoints, or continue | |
# building a non-backbone decoder for panoptic segmentation. | |
if self._classification_mode: | |
return | |
else: | |
current_is_backbone = False | |
if not current_is_backbone: | |
# Now that we have finished building the backbone and no stacked decoder | |
# is used in the backbone, so we start to build extra (i.e., non-backbone) | |
# layers for panoptic segmentation. | |
current_name = '_stage5_' + EXTRA | |
utils.safe_setattr( | |
self, current_name, axial_block_groups.BlockGroup( | |
filters=filters_list[-1], | |
num_blocks=extra_decoder_blocks_per_stage, | |
name=utils.get_layer_name(current_name), | |
original_resnet_stride=1, | |
original_resnet_input_stride=32, | |
output_stride=output_stride, | |
backbone_type=backbone_type, | |
use_axial_beyond_stride=use_axial_beyond_stride, | |
use_transformer_beyond_stride=( | |
extra_decoder_use_transformer_beyond_stride), | |
transformer_expansion=base_transformer_expansion, | |
activation=activation, | |
bn_layer=bn_layer, | |
conv_kernel_weight_decay=conv_kernel_weight_decay, | |
**block_group_config)) | |
# Compute parameter lists for stacked decoder. | |
total_decoder_num_stacks = ( | |
backbone_decoder_num_stacks + extra_decoder_num_stacks) | |
# Use a function to compute the next stride. | |
next_stride_fn = lambda x: x // 2 | |
current_decoder_stride = output_stride | |
decoder_stage = 0 | |
# Exit if we have enough stacks and reach the decoding output stride. | |
while (current_stack < total_decoder_num_stacks or | |
current_decoder_stride > high_resolution_output_stride): | |
decoder_stage += 1 | |
current_decoder_stride = next_stride_fn(current_decoder_stride) | |
if current_decoder_stride == output_stride: | |
current_stack += 1 | |
# Always use blocks from the last resnet stage if the current stride is | |
# output stride (the largest stride). | |
original_resnet_input_stride = 32 | |
# Switch the decoder direction if we reach the largest stride. | |
next_stride_fn = lambda x: x // 2 | |
else: | |
original_resnet_input_stride = current_decoder_stride | |
# Scale channels according to the strides. | |
decoder_channels = original_resnet_input_stride * 64 * width_multiplier | |
current_transformer_expansion = ( | |
base_transformer_expansion * current_decoder_stride / 16.0) | |
# Apply a decoder block group for building the backbone. | |
if current_is_backbone: | |
current_name = '_decoder_stage{}'.format(decoder_stage) | |
utils.safe_setattr( | |
self, current_name, axial_block_groups.BlockGroup( | |
filters=decoder_channels // 4, | |
num_blocks=backbone_decoder_blocks_per_stage, | |
name=utils.get_layer_name(current_name), | |
original_resnet_stride=1, | |
original_resnet_input_stride=original_resnet_input_stride, | |
output_stride=output_stride, | |
backbone_type=backbone_type, | |
use_axial_beyond_stride=use_axial_beyond_stride, | |
use_transformer_beyond_stride=( | |
backbone_use_transformer_beyond_stride), | |
transformer_expansion=current_transformer_expansion, | |
activation=activation, | |
bn_layer=bn_layer, | |
conv_kernel_weight_decay=conv_kernel_weight_decay, | |
**block_group_config)) | |
if (current_decoder_stride == output_stride and | |
current_stack == backbone_decoder_num_stacks): | |
# Now that we have finished building the backbone, we either return the | |
# classification endpoints, or continue building a non-backbone decoder | |
# for panoptic segmentation. | |
if classification_mode: | |
return | |
else: | |
current_is_backbone = False | |
# Apply a decoder block group for building the extra layers. | |
if not current_is_backbone: | |
# Continue building an extra (i.e., non-backbone) decoder for panoptic | |
# segmentation. | |
current_name = '_decoder_stage{}_{}'.format(decoder_stage, EXTRA) | |
utils.safe_setattr( | |
self, current_name, axial_block_groups.BlockGroup( | |
filters=decoder_channels // 4, | |
num_blocks=extra_decoder_blocks_per_stage, | |
name=utils.get_layer_name(current_name), | |
original_resnet_stride=1, | |
original_resnet_input_stride=original_resnet_input_stride, | |
output_stride=output_stride, | |
backbone_type=backbone_type, | |
use_axial_beyond_stride=use_axial_beyond_stride, | |
use_transformer_beyond_stride=( | |
extra_decoder_use_transformer_beyond_stride), | |
transformer_expansion=current_transformer_expansion, | |
activation=activation, | |
bn_layer=bn_layer, | |
conv_kernel_weight_decay=conv_kernel_weight_decay, | |
**block_group_config)) | |
if current_decoder_stride == high_resolution_output_stride: | |
next_stride_fn = lambda x: x * 2 | |
# Assert that we have already returned if we are building a classifier. | |
assert not classification_mode | |
if (backbone_use_transformer_beyond_stride or | |
extra_decoder_use_transformer_beyond_stride): | |
# Build extra memory path feed forward networks for the class feature and | |
# the mask feature. | |
current_name = '_class_feature_' + EXTRA | |
utils.safe_setattr( | |
self, current_name, convolutions.Conv1D( | |
global_feed_forward_network_channels, | |
utils.get_layer_name(current_name), | |
use_bias=False, | |
use_bn=True, | |
bn_layer=bn_layer, | |
activation=activation, | |
conv_kernel_weight_decay=conv_kernel_weight_decay)) | |
current_name = '_mask_feature_' + EXTRA | |
utils.safe_setattr( | |
self, current_name, convolutions.Conv1D( | |
global_feed_forward_network_channels, | |
utils.get_layer_name(current_name), | |
use_bias=False, | |
use_bn=True, | |
bn_layer=bn_layer, | |
activation=activation, | |
conv_kernel_weight_decay=conv_kernel_weight_decay)) | |
def build(self, input_shape): | |
"""Builds model weights and input shape dependent sub-layers.""" | |
if self._use_memory_feature: | |
self._memory_feature = self.add_weight( | |
name=MEMORY_FEATURE, | |
shape=self._memory_feature_shape, | |
initializer=self._memory_feature_initializer, | |
regularizer=self._memory_feature_regularizer) | |
else: | |
self._memory_feature = None | |
# Go through the loop to build the ResizedFuse layers. | |
current_stack = 0 | |
# Track whether we are building the backbone. This will affect the backbone | |
# related arguments, local learning rate, and so on. | |
current_is_backbone = self._backbone_decoder_num_stacks != 0 | |
total_decoder_num_stacks = ( | |
self._backbone_decoder_num_stacks + self._extra_decoder_num_stacks) | |
next_stride_fn = lambda x: x // 2 | |
current_decoder_stride = self._output_stride | |
decoder_stage = 0 | |
while (current_stack < total_decoder_num_stacks or | |
current_decoder_stride > self._high_resolution_output_stride): | |
decoder_stage += 1 | |
current_decoder_stride = next_stride_fn(current_decoder_stride) | |
if current_decoder_stride == self._output_stride: | |
current_stack += 1 | |
original_resnet_input_stride = 32 | |
next_stride_fn = lambda x: x // 2 | |
else: | |
original_resnet_input_stride = current_decoder_stride | |
# Compute the decoder_channels according to original_resnet_input_stride. | |
# For example, at stride 4 with width multiplier = 1, we use 4 * 64 = 256 | |
# channels, which is the same as a standard ResNet. | |
decoder_channels = int(round( | |
original_resnet_input_stride * 64 * self._width_multiplier)) | |
decoder_height, decoder_width = utils.scale_mutable_sequence( | |
input_shape[1:3], 1.0 / current_decoder_stride) | |
if current_is_backbone: | |
current_name = '_decoder_stage{}_resized_fuse'.format(decoder_stage) | |
else: | |
current_name = '_decoder_stage{}_{}_resized_fuse'.format( | |
decoder_stage, EXTRA) | |
utils.safe_setattr( | |
self, current_name, resized_fuse.ResizedFuse( | |
name=utils.get_layer_name(current_name), | |
height=decoder_height, | |
width=decoder_width, | |
num_channels=decoder_channels, | |
activation=self._activation, | |
bn_layer=self._bn_layer, | |
conv_kernel_weight_decay=self._conv_kernel_weight_decay)) | |
if (current_decoder_stride == self._output_stride and | |
current_stack == self._backbone_decoder_num_stacks): | |
# Now that we have finished building the backbone, we either return the | |
# classification endpoints, or continue building a non-backbone decoder | |
# for panoptic segmentation. | |
if self._classification_mode: | |
return | |
current_is_backbone = False | |
if current_decoder_stride == self._high_resolution_output_stride: | |
next_stride_fn = lambda x: x * 2 | |
def call_encoder_before_stacked_decoder(self, inputs, training=False): | |
"""Performs a forward pass of the encoder before stacking decoders. | |
Args: | |
inputs: An input [batch, height, width, channel] tensor. | |
training: A boolean, whether the model is in training mode. | |
Returns: | |
current_output: An output tensor with shape [batch, new_height, new_width, | |
new_channel]. | |
activated_output: An activated output tensor with shape [batch, | |
new_height, new_width, new_channel]. | |
memory_feature: None if no transformer is used. A [batch, num_memory, | |
memory_channel] tensor if transformer is used. | |
endpoints: A dict, the network endpoints that might be used by DeepLab. | |
""" | |
memory_feature = self._memory_feature | |
if self._use_memory_feature: | |
if self._num_mask_slots: | |
memory_feature = self._memory_feature[:, :self._num_mask_slots, :] | |
memory_feature = tf.tile(memory_feature, | |
[tf.shape(inputs)[0], 1, 1]) | |
endpoints = {} | |
output = self._stem(inputs) | |
activated_output = self._activation_fn(output) | |
endpoints['stage1'] = output | |
endpoints['res1'] = activated_output | |
# Apply standard ResNet block groups. We use first_block_index to | |
# distinguish models with 4 stages and those with 5 stages. | |
for index in range(self._first_block_index, 5): | |
current_name = '_stage{}'.format(index + 1) | |
current_output, activated_output, memory_feature = ( | |
getattr(self, current_name)( | |
(activated_output, memory_feature), training=training)) | |
endpoints[utils.get_layer_name(current_name)] = current_output | |
activated_output_name = 'res{}'.format(index + 1) | |
endpoints[activated_output_name] = activated_output | |
return current_output, activated_output, memory_feature, endpoints | |
def call_stacked_decoder(self, | |
current_output, | |
activated_output, | |
memory_feature, | |
endpoints, | |
training=False): | |
"""Performs a forward pass of the stacked decoders. | |
Args: | |
current_output: An output tensor with shape [batch, new_height, new_width, | |
new_channel]. | |
activated_output: An activated output tensor with shape [batch, | |
new_height, new_width, new_channel]. | |
memory_feature: None if no transformer is used. A [batch, num_memory, | |
memory_channel] tensor if transformer is used. | |
endpoints: A dict, the network endpoints that might be used by DeepLab. | |
training: A boolean, whether the model is in training mode. | |
Returns: | |
memory_feature: None if no transformer is used. A [batch, num_memory, | |
memory_channel] tensor if transformer is used. | |
high_resolution_outputs: A list of decoded tensors with | |
high_resolution_output_stride. | |
backbone_output: An output tensor of the backbone, with output_stride. | |
endpoints: A dict, the network endpoints that might be used by DeepLab. | |
""" | |
# Keep track of the current stack so that we know when to stop. | |
current_stack = 0 | |
# Track whether we are building the backbone. This will affect the backbone | |
# related arguments, local learning rate, and so on. | |
current_is_backbone = True | |
high_resolution_outputs = [] | |
if self._backbone_decoder_num_stacks == 0: | |
# Keep track of the backbone output, since it might be used as the | |
# semantic feature output. | |
backbone_output = activated_output | |
# Now that we have finished building the backbone, we either return the | |
# classification logits, or continue building a non-backbone decoder for | |
# panoptic segmentation. | |
if self._classification_mode: | |
endpoints['backbone_output'] = backbone_output | |
return None, None, None, endpoints | |
else: | |
current_is_backbone = False | |
if not current_is_backbone: | |
# Build extra layers if we have finished building the backbone. | |
current_name = '_stage5_' + EXTRA | |
current_output, activated_output, memory_feature = ( | |
getattr(self, current_name)( | |
(activated_output, memory_feature), training=training)) | |
# Compute parameter lists for stacked decoder. | |
total_decoder_num_stacks = ( | |
self._backbone_decoder_num_stacks + self._extra_decoder_num_stacks) | |
# Keep track of all endpoints that will be used in the stacked decoder. | |
stride_to_features = {} | |
stride_to_features[min(2, self._output_stride)] = [endpoints['stage1']] | |
stride_to_features[min(4, self._output_stride)] = [endpoints['stage2']] | |
stride_to_features[min(8, self._output_stride)] = [endpoints['stage3']] | |
stride_to_features[min(16, self._output_stride)] = [endpoints['stage4']] | |
# Only keep the last endpoint from the backbone with the same resolution, | |
# i.e., if the output stride is 16, the current output will override | |
# the stride 16 endpoint, endpoints['res4']. | |
stride_to_features[min(32, self._output_stride)] = [current_output] | |
# Use a function to compute the next stride. | |
next_stride_fn = lambda x: x // 2 | |
current_decoder_stride = self._output_stride | |
decoder_stage = 0 | |
# Exit if we have enough stacks and reach the decoding output stride. | |
while (current_stack < total_decoder_num_stacks or | |
current_decoder_stride > self._high_resolution_output_stride): | |
decoder_stage += 1 | |
current_decoder_stride = next_stride_fn(current_decoder_stride) | |
if current_decoder_stride == self._output_stride: | |
current_stack += 1 | |
# Switch the decoder direction if we reach the largest stride. | |
next_stride_fn = lambda x: x // 2 | |
# Include the current feature and two previous features from the target | |
# resolution in the decoder. We select two because it contains one upward | |
# feature and one downward feature, but better choices are possible. | |
decoder_features_list = ( | |
[current_output] + | |
stride_to_features[current_decoder_stride][-2:]) | |
# Fuse and resize features with striding, resizing and 1x1 convolutions. | |
if current_is_backbone: | |
current_name = '_decoder_stage{}_resized_fuse'.format(decoder_stage) | |
else: | |
current_name = '_decoder_stage{}_{}_resized_fuse'.format( | |
decoder_stage, EXTRA) | |
activated_output = getattr(self, current_name)( | |
decoder_features_list, training=training) | |
# Apply a decoder block group for building the backbone. | |
if current_is_backbone: | |
current_name = '_decoder_stage{}'.format(decoder_stage) | |
current_output, activated_output, memory_feature = ( | |
getattr(self, current_name)( | |
(activated_output, memory_feature), training=training)) | |
if (current_decoder_stride == self._output_stride and | |
current_stack == self._backbone_decoder_num_stacks): | |
# Keep track of the backbone output, since it might be used as the | |
# semantic feature output. | |
backbone_output = activated_output | |
# Now that we have finished building the backbone, we either return the | |
# classification logits, or continue building a non-backbone decoder for | |
# panoptic segmentation. | |
if self._classification_mode: | |
endpoints['backbone_output'] = backbone_output | |
return None, None, None, endpoints | |
else: | |
current_is_backbone = False | |
# Apply a decoder block group for building the extra layers. | |
if not current_is_backbone: | |
current_name = '_decoder_stage{}_{}'.format(decoder_stage, EXTRA) | |
current_output, activated_output, memory_feature = ( | |
getattr(self, current_name)( | |
(activated_output, memory_feature), training=training)) | |
# Append the current feature into the feature dict for possible later | |
# usage. | |
stride_to_features[current_decoder_stride].append(current_output) | |
if current_decoder_stride == self._high_resolution_output_stride: | |
high_resolution_outputs.append(activated_output) | |
next_stride_fn = lambda x: x * 2 | |
return memory_feature, high_resolution_outputs, backbone_output, endpoints | |
def call_extra_endpoints(self, | |
memory_feature, | |
high_resolution_outputs, | |
backbone_output, | |
endpoints, | |
training=False): | |
"""Performs a forward pass to generate extra endpoints. | |
Args: | |
memory_feature: None if no transformer is used. A [batch, num_memory, | |
memory_channel] tensor if transformer is used. | |
high_resolution_outputs: A list of decoded tensors with | |
high_resolution_output_stride. | |
backbone_output: An output tensor of the backbone, with output_stride. | |
endpoints: A dict, the network endpoints that might be used by DeepLab. | |
training: A boolean, whether the model is in training mode. | |
Returns: | |
endpoints: A dict, the network endpoints that might be used by DeepLab. | |
""" | |
# Assert that we have already returned if we are building a classifier. | |
assert not self._classification_mode | |
if (self._backbone_use_transformer_beyond_stride or | |
self._extra_decoder_use_transformer_beyond_stride): | |
# Build extra memory path feed forward networks for the class feature and | |
# the mask feature. | |
class_feature = getattr(self, '_class_feature_' + EXTRA)( | |
memory_feature, training=training) | |
mask_feature = getattr(self, '_mask_feature_' + EXTRA)( | |
memory_feature, training=training) | |
endpoints['transformer_class_feature'] = class_feature | |
endpoints['transformer_mask_feature'] = mask_feature | |
# Output the last high resolution feature as panoptic feature. | |
endpoints['feature_panoptic'] = high_resolution_outputs[-1] | |
# Avoid sharing our panoptic feature with the semantic auxiliary loss. So we | |
# use the backbone feature or the decoded backbone feature for the semantic | |
# segmentation head (i.e. the auxiliary loss). | |
if self._extra_decoder_num_stacks: | |
endpoints['feature_semantic'] = ( | |
high_resolution_outputs[self._backbone_decoder_num_stacks]) | |
else: | |
endpoints['feature_semantic'] = backbone_output | |
endpoints['backbone_output'] = backbone_output | |
return endpoints | |
def call(self, inputs, training=False): | |
"""Performs a forward pass. | |
Args: | |
inputs: An input [batch, height, width, channel] tensor. | |
training: A boolean, whether the model is in training mode. | |
Returns: | |
endpoints: A dict, the network endpoints that might be used by DeepLab. | |
""" | |
current_output, activated_output, memory_feature, endpoints = ( | |
self.call_encoder_before_stacked_decoder(inputs, training=training)) | |
memory_feature, high_resolution_outputs, backbone_output, endpoints = ( | |
self.call_stacked_decoder(current_output, | |
activated_output, | |
memory_feature, | |
endpoints, | |
training=training)) | |
if self._classification_mode: | |
return endpoints | |
endpoints = self.call_extra_endpoints(memory_feature, | |
high_resolution_outputs, | |
backbone_output, | |
endpoints, | |
training=training) | |
return endpoints | |