# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This script contains STEMs for neural networks.

The `STEM` is defined as the first few convolutions that process the input
image to a spatially smaller feature map (e.g., output stride = 2).
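
For example, with output stride = 2, an input image of spatial size 224 x 224
is downsampled to a 112 x 112 feature map by the stride-2 convolution in the
stem.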


Reference code:
https://github.com/tensorflow/models/blob/master/research/deeplab/core/resnet_v1_beta.py
"""
import tensorflow as tf

from deeplab2.model.layers import convolutions

layers = tf.keras.layers


class InceptionSTEM(tf.keras.layers.Layer):
  """A InceptionSTEM layer.

  This class builds an InceptionSTEM layer which can be used to as the first
  few layers in a neural network. In particular, InceptionSTEM contains three
  consecutive 3x3 colutions.

  Reference:
  - Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, and Alexander Alemi.
    "Inception-v4, inception-resnet and the impact of residual connections on
    learning." In AAAI, 2017.
  """

  def __init__(self,
               bn_layer=tf.keras.layers.BatchNormalization,
               width_multiplier=1.0,
               conv_kernel_weight_decay=0.0,
               activation='relu'):
    """Creates the InceptionSTEM layer.

    Args:
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
      width_multiplier: A float multiplier, controlling the value of
        convolution output channels.
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
      activation: A string specifying an activation function to be used in this
        stem.
    """
    super(InceptionSTEM, self).__init__(name='stem')

    self._conv1_bn_act = convolutions.Conv2DSame(
        output_channels=int(64 * width_multiplier),
        kernel_size=3,
        name='conv1_bn_act',
        strides=2,
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation=activation,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    self._conv2_bn_act = convolutions.Conv2DSame(
        output_channels=int(64 * width_multiplier),
        kernel_size=3,
        name='conv2_bn_act',
        strides=1,
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation=activation,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    self._conv3_bn = convolutions.Conv2DSame(
        output_channels=int(128 * width_multiplier),
        kernel_size=3,
        strides=1,
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation='none',
        name='conv3_bn',
        conv_kernel_weight_decay=conv_kernel_weight_decay)

  def call(self, input_tensor, training=False):
    """Performs a forward pass.

    Args:
      input_tensor: An input tensor of type tf.Tensor with shape [batch, height,
        width, channels].
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      The output tensor of the stem. The final convolution is not followed by
        an activation, so the returned tensor is not activated.
    """
    x = self._conv1_bn_act(input_tensor, training=training)
    x = self._conv2_bn_act(x, training=training)
    x = self._conv3_bn(x, training=training)
    return x
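

if __name__ == '__main__':
  # Minimal usage sketch (illustrative only, not part of the library API):
  # builds the stem with default arguments and runs a forward pass on a dummy
  # image batch. Assumes TensorFlow and deeplab2 are installed.
  stem = InceptionSTEM()
  dummy_images = tf.zeros([1, 224, 224, 3])
  features = stem(dummy_images, training=False)
  # The stride-2 first convolution halves the spatial resolution and the last
  # convolution outputs 128 channels, so this should print (1, 112, 112, 128).
  print(features.shape)

  # width_multiplier scales the channel counts, e.g. 0.5 yields a 64-channel
  # output instead of 128.
  slim_stem = InceptionSTEM(width_multiplier=0.5)
  print(slim_stem(dummy_images, training=False).shape)  # (1, 112, 112, 64)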