# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements relative [1, 2, 3] and global [3, 4] positional encodings.

Our Axial-Deeplab [1] proposes position-sensitive self-attention which uses
relative positional encodings for query, key, and value.

[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation,
    ECCV 2020 Spotlight.
      Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
      Liang-Chieh Chen.

[2] Self-Attention with Relative Position Representations, NAACL 2018.
      Peter Shaw, Jakob Uszkoreit, Ashish Vaswani.

[3] Tensor2Tensor for Neural Machine Translation, arXiv 2018,
    http://arxiv.org/abs/1803.07416.
      Ashish Vaswani, Samy Bengio, Eugene Brevdo, Francois Chollet,
      Aidan N. Gomez, Stephan Gouws, Llion Jones, Łukasz Kaiser,
      Nal Kalchbrenner, Niki Parmar, Ryan Sepassi, Noam Shazeer,
      Jakob Uszkoreit.

[4] Attention Is All You Need, NeurIPS 2017.
      Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
      Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin.

[5] An Image is Worth 16x16 Words: Transformers for Image Recognition at
    Scale, ICLR 2021.
      Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn,
      Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer,
      Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
"""
import tensorflow as tf | |
# MAX_SPAN defines the maximum shape of positional encoding. It is set as a
# large constant so that we can easily load and use models with global or
# different local spans, but it should not be too large so that it takes a
# reasonable amount of memory. The value 255 is larger than almost all span
# choices (e.g. 65 for local attention, 129, 193, etc.) so 255 is large enough.
# 257 will be a good choice for gpu, but 255 is more efficient on TPU which
# pads tensors to 128x.
MAX_SPAN = 255
def _compute_relative_distance_matrix(query_length, key_length):
  """Computes a relative distance matrix between queries and keys.

  We assume that the queries and the keys are centered, i.e.,
  key_length = memory_flange + query_length + memory_flange.

  The function is based on the _generate_relative_positions_matrix function in
  common_attention.py of tensor2tensor codebase:
  https://github.com/tensorflow/tensor2tensor/blob/5623deb79cfcd28f8f8c5463b58b5bd76a81fd0d/tensor2tensor/layers/common_attention.py#L1670

  Args:
    query_length: An integer, the length of queries.
    key_length: An integer, the length of keys.

  Returns:
    distance_matrix: A [query_length, key_length] tensor.

  Raises:
    ValueError: If (key_length - query_length) is odd, i.e., the assumption
      does not hold.
  """
  if (key_length - query_length) % 2:
    raise ValueError('Key_length should be query_length + 2 * memory_flange.')
  key_index = tf.range(key_length)
  query_index = tf.range(query_length) + (key_length - query_length) // 2
  distance_matrix = key_index[None, :] - query_index[:, None]
  # Shift the distance_matrix so that it is >= 0. Each entry of the
  # distance_matrix distance will index a relative positional embedding.
  distance_matrix = distance_matrix + MAX_SPAN - 1
  if query_length + (key_length - query_length) // 2 > MAX_SPAN:
    # Bug fix: tf.logging was removed in TensorFlow 2 (this file otherwise
    # uses TF2 tf.keras APIs), so tf.logging.warn would raise AttributeError
    # here. Use the TF2 module logger instead.
    tf.get_logger().warning(
        'Axial attention span is larger than MAX_SPAN. In this '
        'case, we use a single shared embedding for all positions '
        'beyond this relative distance. Please make sure, this '
        'behavior is intended.')
    # Index all positions beyond the representable relative distance with the
    # boundary (shared) embedding.
    distance_matrix = tf.clip_by_value(distance_matrix, 0, MAX_SPAN * 2 - 2)
  return distance_matrix
class RelativePositionalEncoding(tf.keras.layers.Layer):
  """Generates relative positional encoding.

  The function is based on the _generate_relative_positions_embeddings function
  in common_attention.py of tensor2tensor codebase:
  https://github.com/tensorflow/tensor2tensor/blob/5623deb79cfcd28f8f8c5463b58b5bd76a81fd0d/tensor2tensor/layers/common_attention.py#L1691
  """

  def __init__(self, query_length, key_length, depth, num_heads, name,
               initialization_std=1.0, conv_kernel_weight_decay=0.0):
    """Initializes a relative position encoding layer.

    Args:
      query_length: An integer, the length of queries.
      key_length: An integer, the length of keys.
      depth: An integer, the number of embedding channels per head.
      num_heads: An integer, the number of heads in multi-head attention.
      name: A string, the name of the embedding.
      initialization_std: A float, the initialization std for the embedding.
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
    """
    super(RelativePositionalEncoding, self).__init__(name=name)
    # Precompute the [query_length, key_length] lookup indices once; they are
    # constant for a fixed query/key geometry.
    self._relative_distance_matrix = _compute_relative_distance_matrix(
        query_length, key_length)
    self._num_heads = num_heads
    # The table covers all representable relative distances, shared across
    # heads at gather time.
    self._embedding_shape = (MAX_SPAN * 2 - 1, depth)
    self._initializer = tf.keras.initializers.TruncatedNormal(
        stddev=initialization_std)
    self._regularizer = tf.keras.regularizers.l2(conv_kernel_weight_decay)

  def build(self, input_shape):
    """Creates the relative embedding table; independent of the input shape."""
    del input_shape
    self._embeddings = self.add_weight(
        name='embeddings',
        shape=self._embedding_shape,
        initializer=self._initializer,
        regularizer=self._regularizer,
        trainable=True)

  def call(self, inputs):
    """Gathers the relative positional encoding; the input is ignored.

    Returns:
      output: A [num_heads, query, key, depth] tensor, the relative positional
        encodings for each head and each query-key-pair.
    """
    del inputs
    # Look up one embedding per (query, key) pair via the distance indices,
    # then replicate across heads.
    gathered = tf.gather(self._embeddings, self._relative_distance_matrix)
    expanded = tf.expand_dims(gathered, axis=0)
    return tf.tile(expanded, [self._num_heads, 1, 1, 1])
class AddAbsolutePositionalEncoding(tf.keras.layers.Layer):
  """Adds a learnable absolute positional encoding to the input feature.

  Supports both 1D and 2D versions of the positional encoding: (1) 1D
  positional encoding represents each row index with an embedding, and
  represents each column index with another embedding. This results in a total
  of (height + width) learnable embedding vectors. (2) 2D positional encoding
  adds independent embeddings to each input grid position. This choice uses a
  total of (height * width) learnable embedding vectors.
  """

  def __init__(self, name, positional_encoding_type=None,
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes an AddAbsolutePositionEmbedding layer.

    Args:
      name: A string specifying the name of the layer.
      positional_encoding_type: A string, type of the positional encoding.
        Support '2D', '1D', 'none', and None. The feature is returned as is if
        positional_encoding_type is 'none' or None.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If positional_encoding_type is not one of '1D', '2D',
        'none', and None.
    """
    super(AddAbsolutePositionalEncoding, self).__init__(name=name)
    # Bug fix: the original validation built an eager list via any([...]),
    # which called positional_encoding_type.lower() even when the (documented
    # and supported) value None was passed, raising AttributeError. Use
    # short-circuit evaluation instead.
    if not (positional_encoding_type is None or
            positional_encoding_type.lower() in ('none', '2d', '1d')):
      raise ValueError(positional_encoding_type + ' is not supported.')
    self._positional_encoding_type = positional_encoding_type
    # This initialization std is tuned for global attention, but it does not
    # seem to be a sensitive hyper-parameter, since we use batch norm on the
    # positional encodings.
    self._initializer = tf.keras.initializers.TruncatedNormal(stddev=0.2)
    self._kernel_regularizer = tf.keras.regularizers.l2(
        conv_kernel_weight_decay)
    self._bn_layer = bn_layer

  def build(self, input_shape):
    """Builds the layer weights whose shape depends on the 4D input shape."""
    _, height, width, channel = input_shape
    # No weights are needed when the positional encoding is disabled. This
    # also guards the .lower() calls below against a None encoding type.
    if (self._positional_encoding_type is None or
        self._positional_encoding_type.lower() == 'none'):
      return
    if self._positional_encoding_type.lower() == '2d':
      self._embeddings = self.add_weight(
          shape=(1, height, width, channel),
          initializer=self._initializer, trainable=True,
          name='embeddings',
          regularizer=self._kernel_regularizer)
      self._batch_norm = self._bn_layer(axis=-1, name='batch_norm')
    elif self._positional_encoding_type.lower() == '1d':
      # Generate separable positional encodings for the height axis and the
      # width axis.
      self._height_axis_embeddings = self.add_weight(
          shape=(1, height, 1, channel),
          initializer=self._initializer, trainable=True,
          name='height_axis_embeddings',
          regularizer=self._kernel_regularizer)
      self._height_axis_batch_norm = self._bn_layer(
          axis=-1, name='height_axis_batch_norm')
      # Bug fix: the width-axis embeddings must span the width dimension,
      # i.e. shape (1, 1, width, channel). The original (1, height, 1,
      # channel) duplicated the height axis, so column positions were never
      # actually encoded.
      self._width_axis_embeddings = self.add_weight(
          shape=(1, 1, width, channel),
          initializer=self._initializer, trainable=True,
          name='width_axis_embeddings',
          regularizer=self._kernel_regularizer)
      self._width_axis_batch_norm = self._bn_layer(
          axis=-1, name='width_axis_batch_norm')

  def call(self, features, training=False):
    """Performs a forward pass.

    Args:
      features: An input [batch, height, width, channels] tensor.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: The sum of the input feature and learnable positional encodings.
    """
    if (self._positional_encoding_type is None or
        self._positional_encoding_type.lower() == 'none'):
      return features
    elif self._positional_encoding_type.lower() == '2d':
      positional_encoding = self._batch_norm(self._embeddings,
                                             training=training)
    elif self._positional_encoding_type.lower() == '1d':
      height_axis_positional_encoding = self._height_axis_batch_norm(
          self._height_axis_embeddings, training=training)
      width_axis_positional_encoding = self._width_axis_batch_norm(
          self._width_axis_embeddings, training=training)
      # Broadcast-add the (1, height, 1, C) and (1, 1, width, C) encodings
      # into a full (1, height, width, C) grid.
      positional_encoding = (height_axis_positional_encoding +
                             width_axis_positional_encoding)
    return features + positional_encoding