Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021 The Deeplab2 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""This script contains STEMs for neural networks. | |
The `STEM` is defined as the first few convolutions that process the input | |
image to a spatially smaller feature map (e.g., output stride = 2). | |
Reference code: | |
https://github.com/tensorflow/models/blob/master/research/deeplab/core/resnet_v1_beta.py | |
""" | |
import tensorflow as tf | |
from deeplab2.model.layers import convolutions | |
layers = tf.keras.layers | |
class InceptionSTEM(tf.keras.layers.Layer):
  """An InceptionSTEM layer.

  This class builds an InceptionSTEM layer which can be used as the first few
  layers in a neural network. In particular, InceptionSTEM contains three
  consecutive 3x3 convolutions, all created with normalization enabled
  (`use_bn=True`) and without a bias term. The first convolution is created
  with stride 2 and the other two with stride 1. The first two convolutions
  are created with the requested activation; the third one is created with
  activation 'none'.

  Reference:
  - Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, and Alexander Alemi.
    "Inception-v4, inception-resnet and the impact of residual connections on
    learning." In AAAI, 2017.
  """

  def __init__(self,
               bn_layer=tf.keras.layers.BatchNormalization,
               width_multiplier=1.0,
               conv_kernel_weight_decay=0.0,
               activation='relu'):
    """Creates the InceptionSTEM layer.

    Args:
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
      width_multiplier: A float multiplier, controlling the value of
        convolution output channels. The three convolutions use
        int(64 * width_multiplier), int(64 * width_multiplier) and
        int(128 * width_multiplier) output channels, respectively.
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
      activation: A string specifying an activation function to be used in this
        stem. It is applied to the first two convolutions only.
    """
    super(InceptionSTEM, self).__init__(name='stem')

    # Settings common to all three convolutions of the stem.
    shared_kwargs = dict(
        kernel_size=3,
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    # The sub-layers are created in the order in which they are applied.
    self._conv1_bn_act = convolutions.Conv2DSame(
        output_channels=int(64 * width_multiplier),
        name='conv1_bn_act',
        strides=2,
        activation=activation,
        **shared_kwargs)
    self._conv2_bn_act = convolutions.Conv2DSame(
        output_channels=int(64 * width_multiplier),
        name='conv2_bn_act',
        strides=1,
        activation=activation,
        **shared_kwargs)
    # The last convolution is deliberately created without an activation, so
    # the stem output is not activated.
    self._conv3_bn = convolutions.Conv2DSame(
        output_channels=int(128 * width_multiplier),
        name='conv3_bn',
        strides=1,
        activation='none',
        **shared_kwargs)

  def call(self, input_tensor, training=False):
    """Performs a forward pass.

    Args:
      input_tensor: An input tensor of type tf.Tensor with shape [batch, height,
        width, channels].
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      A single output tensor: the result of the third convolution, which is
      created with activation 'none' (i.e., the output is not activated).
    """
    x = self._conv1_bn_act(input_tensor, training=training)
    x = self._conv2_bn_act(x, training=training)
    x = self._conv3_bn(x, training=training)
    return x