# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains the loss functions for MaX-DeepLab models. | |
Reference: | |
MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers", | |
CVPR 2021. https://arxiv.org/abs/2012.00759 | |
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen. | |
""" | |

from typing import Text, Dict, Tuple, List

import tensorflow as tf

from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.model import utils
from deeplab2.model.loss import base_loss
from deeplab2.model.loss import matchers_ops

# Positive and negative constants that are used to pad or mask hungarian
# matching weights.
_MATCHING_NEGATIVE_CONSTANT = -999.0
_MATCHING_POSITIVE_CONSTANT = 999.0
# A large negative constant applied before softmax. This will make the softmax
# ignore the masked logits.
_SOFTMAX_MASKING_CONSTANT = -99999.0
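# For instance (illustrative): softmax([2.0, _SOFTMAX_MASKING_CONSTANT]) is
# approximately [1.0, 0.0], so a masked logit receives (near) zero probability
# and therefore (near) zero gradient.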
_GT_KEY = 'gt_key'
_PRED_KEY = 'pred_key'
_WEIGHT_KEY = 'weight_key'
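# These keys are only used internally to pass tensors to the base_loss
# functions below, e.g. (illustrative):
#   loss_fn({_GT_KEY: one_hot_labels},
#           {_PRED_KEY: logits, _WEIGHT_KEY: per_element_weights})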


def _generate_mask_slot_semantic_one_hot(
    matched_mask_slot_indices: tf.Tensor,
    mask_gt_semantic_map: tf.Tensor,
    num_mask_slots: int,
    thing_stuff_class_ids: List[int]):
  """Generates the ground truth for transformer_class_logits.

  This function generates a pseudo ground truth that we will use to train the
  transformer class head logits. The input tensors, matched_mask_slot_indices
  and mask_gt_semantic_map, are obtained by (hungarian) matching the ground
  truth masks with the predicted masks. Note that this function generates the
  positive one hot encodings only, i.e., the void class is not included in the
  output tensor but will be generated outside the function.

  Args:
    matched_mask_slot_indices: An int32 tf.Tensor of shape [batch_size,
      num_ground_truth_masks] that encodes the matched mask slot id for each
      ground truth mask.
    mask_gt_semantic_map: An int32 tf.Tensor of shape [batch_size,
      num_ground_truth_masks] that encodes the semantic label for each ground
      truth mask. A padded mask (or void, or no object) will have the label -1.
    num_mask_slots: An integer, the number of mask slots for the MaX-DeepLab
      model.
    thing_stuff_class_ids: A list of integers of length [num_thing_classes +
      num_stuff_classes] that encodes the class IDs for all thing and stuff
      classes. It is a concatenation of the thing_class_ids list and the
      stuff_class_ids list.

  Returns:
    mask_slot_semantic_one_hot: An output tf.Tensor with shape [batch_size,
      num_mask_slots, num_thing_classes + num_stuff_classes].
  """
  semantic_map_shape = mask_gt_semantic_map.get_shape().as_list()
  batch_size = semantic_map_shape[0]
  num_ground_truth_masks = semantic_map_shape[-1]
  # Concatenate the indices in each dimension of the ground truth one hot
  # output.
  batch_indices = tf.expand_dims(tf.range(batch_size), axis=-1)
  batch_indices = tf.tile(batch_indices, [1, num_ground_truth_masks])
  batch_indices = tf.reshape(batch_indices, [-1, 1])
  matched_mask_slot_indices = tf.reshape(matched_mask_slot_indices, [-1, 1])
  # We shift the semantic map by one so that void labels (-1) will be a valid
  # index too. Otherwise, tf.scatter_nd raises an error if it runs on CPU.
  semantic_indices = tf.reshape(mask_gt_semantic_map, [-1, 1]) + 1
  indices = tf.concat([batch_indices,
                       matched_mask_slot_indices,
                       semantic_indices], axis=-1)
  # Generate mask_slot_semantic_one_hot by scattering constant ones onto a
  # constant zero tensor.
  updates = tf.ones([batch_size * num_ground_truth_masks], dtype=tf.float32)
  mask_slot_semantic_one_hot = tf.scatter_nd(
      indices, updates,
      shape=[batch_size, num_mask_slots, max(thing_stuff_class_ids) + 2])
  # Gather the wanted classes in the desired (thing + stuff) order.
  thing_stuff_tensor = tf.cast(thing_stuff_class_ids, tf.int32)
  # We also shift the thing_stuff_tensor index by one in order to revert the
  # semantic map shifting above.
  mask_slot_semantic_one_hot = tf.gather(mask_slot_semantic_one_hot,
                                         thing_stuff_tensor + 1, axis=2)
  return mask_slot_semantic_one_hot
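
# Illustrative example for _generate_mask_slot_semantic_one_hot, with
# hypothetical values: thing_stuff_class_ids=[2, 5], num_mask_slots=3, and a
# single ground truth mask of class 5 matched to mask slot 1, i.e.,
# matched_mask_slot_indices=[[1]] and mask_gt_semantic_map=[[5]]. The scatter
# writes a 1.0 at [0, 1, 5 + 1] in a [1, 3, 7] tensor, and the gather at
# indices [2 + 1, 5 + 1] keeps only the (thing, stuff) channels, yielding a
# [1, 3, 2] one hot tensor whose only nonzero entry is at [0, 1, 1]. A padded
# ground truth mask (label -1) lands in channel 0 before the gather and is
# therefore dropped.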


def nonsquare_hungarian_matching(
    weights: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  """Hungarian matching with arbitrary shape.

  The matchers_ops.hungarian_matching supports only square weight matrices.
  This function generalizes the hungarian matching to nonsquare cases by
  padding the weights to a square and running the square version matching. The
  property of hungarian matching ensures that the solutions are equivalent for
  the padded square problem and the original nonsquare problem.

  Args:
    weights: A [batch, shape1, shape2] float32 tf.Tensor.

  Returns:
    square_permutation: A [batch, max(shape1, shape2), max(shape1, shape2)]
      float32 tf.Tensor that is the permutation matrix that achieves the
      minimum total weight. Note that a permutation matrix contains only the
      values 0.0 and 1.0, with each row and each column summing to 1.0.
    nonsquare_permutation: A [batch, shape1, shape2] float32 tf.Tensor. The
      nonsquare part of the permutation matrix.
  """
  _, height, width = weights.get_shape().as_list()
  max_height_width = max(height, width)
  # Padding a constant on one axis does not affect matching results.
  weights = tf.pad(weights,
                   [[0, 0],  # Do not pad the batch dimension.
                    [0, max_height_width - height],
                    [0, max_height_width - width]],
                   constant_values=_MATCHING_NEGATIVE_CONSTANT)
  square_permutation = matchers_ops.hungarian_matching(weights)
  square_permutation = tf.cast(square_permutation, tf.float32)
  return square_permutation, square_permutation[:, :height, :width]
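
# Illustrative example: with weights of shape [1, 2, 3] (2 ground truth masks,
# 3 predicted masks), the weights are padded to [1, 3, 3] with
# _MATCHING_NEGATIVE_CONSTANT in the extra row. The padded row has the same
# constant weight for every column, so it does not affect which predictions
# the real rows are matched to. The returned square_permutation has shape
# [1, 3, 3] and nonsquare_permutation is its [:, :2, :3] slice, in which each
# ground truth row is matched to exactly one predicted column.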


def _mask_similarity(gt_mask: tf.Tensor, pred_mask: tf.Tensor,
                     metric: str = 'dice') -> tf.Tensor:
  """Computes mask similarity between gt_masks and pred_masks.

  Args:
    gt_mask: A [batch, height * width, num_gt_masks] float32 tf.Tensor that
      contains only the values 0.0 and 1.0. Each 1.0 indicates that the pixel
      belongs to the ground truth mask. Note that panoptic segmentation
      enforces that ground truth masks do not overlap.
    pred_mask: A [batch, height * width, num_pred_masks] float32 tf.Tensor
      with non-negative values. For each batch_id and pixel_id, the
      [num_pred_masks] vector encodes whether each pixel belongs to each mask.
      The sum of each vector is less than or equal to one.
    metric: A string, the mask similarity metric that we will compute. Supports
      'dice' (default), 'iou', 'intersection_over_ground_truth', and
      'intersection_over_prediction'.

  Returns:
    mask_similarity: A float32 [batch, num_gt_masks, num_pred_masks] tf.Tensor
      that contains the mask similarity between all ground truth masks and all
      predicted masks.

  Raises:
    ValueError: If the mask similarity metric is not one of 'dice', 'iou',
      'intersection_over_ground_truth', or 'intersection_over_prediction'.
  """
  denominator_epsilon = 1e-5
  intersection = tf.einsum('bpi,bpj->bij', gt_mask, pred_mask)
  if metric.lower() == 'dice':
    denominator = (tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2) +
                   tf.reduce_sum(pred_mask, axis=1, keepdims=True)) / 2
  elif metric.lower() == 'iou':
    denominator = (tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2) +
                   tf.reduce_sum(pred_mask, axis=1, keepdims=True) -
                   intersection)
  elif metric.lower() == 'intersection_over_ground_truth':
    denominator = tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2)
  elif metric.lower() == 'intersection_over_prediction':
    denominator = tf.reduce_sum(pred_mask, axis=1, keepdims=True)
  else:
    raise ValueError('The mask similarity metric is not supported.')
  return intersection / (denominator + denominator_epsilon)
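
# Illustrative example of the 'dice' metric: for a ground truth mask covering
# 4 pixels and a prediction that assigns probability 1.0 to 2 of those pixels
# (and 0.0 elsewhere), intersection = 2 and denominator = (4 + 2) / 2 = 3, so
# the dice similarity is roughly 2 / 3.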


class MaXDeepLabLoss(tf.keras.layers.Layer):
  """This class contains code for MaX-DeepLab losses."""

  def __init__(self,
               loss_options: config_pb2.LossOptions,
               ignore_label: int,
               thing_class_ids: Tuple[int],
               focal_loss_alpha: float = 0.75,
               instance_discrimination_temperature: float = 0.3):
    """Initializes a MaX-DeepLab loss.

    This class supports PQ-style loss, mask id cross entropy loss, and
    instance discrimination loss, proposed in MaX-DeepLab. The PQ-style loss
    can be further decomposed into a classification term and a mask dice term.

    Reference:
      MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.

    Args:
      loss_options: Loss options as defined by config_pb2.LossOptions.
      ignore_label: An integer specifying the ignore label.
      thing_class_ids: A tuple of length [N] containing N thing indices.
      focal_loss_alpha: An optional float specifying the coefficient that
        weights between positive (matched) and negative (unmatched) masks in
        focal loss. The positives are weighted by alpha, while the negatives
        are weighted by (1. - alpha). Note that we do not use a focal loss
        gamma here, i.e., the gamma is set to zero, which is equivalent to the
        normal cross-entropy loss, except for the alpha weighting. Defaults to
        0.75.
      instance_discrimination_temperature: An optional float specifying the
        temperature for the instance discrimination loss.
    """
    super(MaXDeepLabLoss, self).__init__(name='MaXDeepLabLoss')
    # The loss_terms will optionally include
    # - common.PQ_STYLE_LOSS_CLASS_TERM
    # - common.PQ_STYLE_LOSS_MASK_DICE_TERM
    # - common.MASK_ID_CROSS_ENTROPY_LOSS
    # - common.INSTANCE_DISCRIMINATION_LOSS
    # These loss terms will be accessed by loss_builder.py and will be used to
    # build loss metrics.
    self.loss_terms = []
    # The PQ-style loss includes two terms.
    self._pq_style_loss_weight = 0.0
    if loss_options.HasField(common.PQ_STYLE_LOSS):
      self._pq_style_loss_weight = loss_options.pq_style_loss.weight
      self.loss_terms.append(common.PQ_STYLE_LOSS_CLASS_TERM)
      self.loss_terms.append(common.PQ_STYLE_LOSS_MASK_DICE_TERM)
    # Mask-ID cross entropy loss.
    self._mask_id_cross_entropy_loss_weight = 0.0
    if loss_options.HasField(common.MASK_ID_CROSS_ENTROPY_LOSS):
      self._mask_id_cross_entropy_loss_weight = (
          loss_options.mask_id_cross_entropy_loss.weight)
      self.loss_terms.append(common.MASK_ID_CROSS_ENTROPY_LOSS)
    # Instance discrimination loss.
    self._instance_discrimination_loss_weight = 0.0
    if loss_options.HasField(common.INSTANCE_DISCRIMINATION_LOSS):
      self._instance_discrimination_loss_weight = (
          loss_options.instance_discrimination_loss.weight)
      self.loss_terms.append(common.INSTANCE_DISCRIMINATION_LOSS)
    self._ignore_label = ignore_label
    self._thing_class_ids = list(thing_class_ids)
    self._focal_loss_alpha = focal_loss_alpha
    self._instance_discrimination_temperature = (
        instance_discrimination_temperature)
    # Build the base loss functions.
    self._pq_style_loss_class_term = base_loss.FocalCrossEntropyLoss(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY,
        # Num_classes and ignore_label are not necessary since the inputs will
        # be one hot encoded already.
        num_classes=None, ignore_label=None,
        focal_loss_alpha=focal_loss_alpha,
        focal_loss_gamma=0.0, background_channel_index=-1,
        dynamic_weight=True)
    self._pq_style_loss_mask_dice_term = base_loss.MaskDiceLoss(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY,
        prediction_activation='softmax')
    self._mask_id_cross_entropy_loss = base_loss.TopKCrossEntropyLoss(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY,
        # Num_classes and ignore_label are not necessary since the inputs will
        # be one hot encoded already.
        num_classes=None, ignore_label=None,
        top_k_percent_pixels=1.0, dynamic_weight=True)
    self._instance_discrimination_loss = base_loss.TopKCrossEntropyLoss(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY,
        # Num_classes and ignore_label are not necessary since the inputs will
        # be one hot encoded already.
        num_classes=None, ignore_label=None,
        top_k_percent_pixels=1.0, dynamic_weight=True)

  def build(self,
            input_shapes: Tuple[Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]):
    """Extracts useful constants that depend on the input shapes."""
    y_true_shapes = input_shapes[0]
    self._max_thing_id = int(y_true_shapes[common.GT_THING_ID_CLASS_KEY][-1])
    y_pred_shapes = input_shapes[1]
    transformer_class_logits_shape = y_pred_shapes[
        common.PRED_TRANSFORMER_CLASS_LOGITS_KEY]
    self._num_mask_slots = int(transformer_class_logits_shape[1])
    # The transformer_class_logits contain thing classes, stuff classes, and
    # the void class, so num_thing_stuff_classes should be the total number of
    # classes minus one.
    self._num_thing_stuff_classes = int(transformer_class_logits_shape[2]) - 1
    # Since we implement the PQ-style loss with the class term plus the mask
    # dice term (Equation 10 of the paper), we need to balance the two terms
    # to have the same weight and normalizing constants. The focal loss alpha
    # is a weight on the positive class term, so we apply it to the mask dice
    # term too. The class loss is also normalized by the number of mask slots,
    # so we do the same normalization for the mask dice term.
    self._mask_dice_term_modifier = (
        self._focal_loss_alpha / self._num_mask_slots)
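    # For example (hypothetical values): with focal_loss_alpha = 0.75 and 128
    # mask slots, the modifier is 0.75 / 128, i.e., roughly 0.0059.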
    self._stuff_class_ids = utils.get_stuff_class_ids(
        self._num_thing_stuff_classes,
        self._thing_class_ids,
        self._ignore_label)
    self._num_stuff_classes = len(self._stuff_class_ids)
    self._thing_stuff_class_ids = self._thing_class_ids + self._stuff_class_ids
    self._pixel_gt_num_mask_id = self._max_thing_id + self._num_stuff_classes

  def _pre_process_ground_truth(
      self, y_true: Dict[Text, tf.Tensor], output_height: int,
      output_width: int
  ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor,
             tf.Tensor]:
    """Pre-processes the ground truth before we compute the losses.

    This function generates tensors that do not depend on the prediction of
    the model, but are useful to the calculation of the losses. The function
    mainly downsamples the pixel space ground truth to the model output
    resolution, and combines (or concatenates) the thing masks and the stuff
    masks. The output contains pixel_gt_num_mask_id = max_thing_id +
    num_stuff_classes masks, i.e., the output masks contain both thing masks
    and stuff masks.

    Args:
      y_true: A dict of tensors providing ground-truth information, containing
        - common.GT_SEMANTIC_KEY: A [batch, height, width] int32 tf.Tensor,
          the semantic label map.
        - common.GT_THING_ID_MASK_KEY: A [batch, height, width] int32
          tf.Tensor. It assigns each non-crowd thing instance a unique mask-ID
          label, starting from 0. Unassigned pixels are set to -1.
        - common.GT_THING_ID_CLASS_KEY: A [batch, max_thing_id] int32
          tf.Tensor. It contains the semantic ID of each instance assigned to
          thing_id_mask. The remaining (max_thing_id - num_things) elements
          are set to -1.
      output_height: An integer, the height of the model output.
      output_width: An integer, the width of the model output.

    Returns:
      pixel_gt_thing_mask: A [batch, output_height * output_width] float32
        tensor, with values 0.0 and 1.0 only, indicating whether a pixel
        belongs to a 'thing' class.
      pixel_gt_non_void_mask: A [batch, output_height * output_width] float32
        tensor, with values 0.0 and 1.0 only, indicating if a pixel does not
        belong to the void class.
      pixel_gt_mask_id_one_hot: A [batch, output_height * output_width,
        pixel_gt_num_mask_id] float32 tensor, with values 0.0 and 1.0 only,
        indicating the mask id each pixel belongs to.
      mask_gt_semantic_map: A [batch, pixel_gt_num_mask_id] int32 tensor, the
        semantic class of each ground truth mask.
      mask_gt_non_void_mask: A [batch, pixel_gt_num_mask_id] float32 tensor,
        with values 0.0 and 1.0 only, indicating if the ground truth mask is a
        valid mask, not a padded mask. The masks are padded because TPU does
        not support dynamic shapes except in the batch axis. We pad all ground
        truth thing masks to a large enough constant max_thing_id. Similarly,
        stuff classes that are not present in the current image will be set to
        a void mask too.
      mask_gt_semantic_one_hot: A [batch, pixel_gt_num_mask_id,
        num_thing_stuff_classes] float32 tensor, with values 0.0 and 1.0 only,
        containing the one hot encodings of the ground truth mask classes. The
        last dimension contains concatenated thing classes and stuff classes,
        which is different from the dataset class IDs in mask_gt_semantic_map.
      mask_gt_area: A [batch, pixel_gt_num_mask_id, 1] float32 tensor, the
        area of each ground truth mask. Padded masks have an area of 0.0.
    """
    # The depth of one hot encoding should be the largest id plus one. For
    # example, if we want to one-hot encode a class ID of 133 (the largest ID
    # for the COCO dataset), we will need a one-hot encoding of length 134.
    one_hot_depth = max(self._thing_stuff_class_ids) + 1
    batch_size = y_true[common.GT_SEMANTIC_KEY].get_shape().as_list()[0]
    # Compute pixel_gt_semantic_map (downsampling and reshaping to the 1D
    # representation that will be mainly used in this loss function).
    pixel_gt_semantic_map = utils.strided_downsample(
        y_true[common.GT_SEMANTIC_KEY],
        target_size=[output_height, output_width])
    pixel_gt_semantic_map = tf.reshape(
        pixel_gt_semantic_map,
        [batch_size, output_height * output_width])
    # Compute pixel_gt_non_void_mask.
    pixel_gt_non_void_mask = tf.cast(
        tf.not_equal(pixel_gt_semantic_map, self._ignore_label), tf.float32)
    pixel_gt_non_void_mask = tf.ensure_shape(
        pixel_gt_non_void_mask,
        [batch_size, output_height * output_width])
    # Compute pixel_gt_semantic_one_hot from pixel_gt_semantic_map in order to
    # gather pixel_gt_stuff_id_one_hot from pixel_gt_semantic_one_hot.
    pixel_gt_semantic_one_hot = tf.one_hot(pixel_gt_semantic_map, one_hot_depth)
    # Convert the one hot encoding from the dataset id order to the (thing,
    # stuff) order used in MaX-DeepLab.
    pixel_gt_stuff_id_one_hot = tf.gather(pixel_gt_semantic_one_hot,
                                          self._stuff_class_ids, axis=-1)
    pixel_gt_stuff_id_one_hot = tf.ensure_shape(
        pixel_gt_stuff_id_one_hot,
        [batch_size, output_height * output_width, self._num_stuff_classes])
    # Compute pixel_gt_thing_id_one_hot for thing masks.
    pixel_gt_thing_id_map = utils.strided_downsample(
        y_true[common.GT_THING_ID_MASK_KEY],
        target_size=[output_height, output_width])
    pixel_gt_thing_id_map = tf.reshape(
        pixel_gt_thing_id_map,
        shape=[batch_size, output_height * output_width])
    # Note that common.GT_THING_ID_MASK_KEY uses -1 for void masks, while 0 to
    # (max_thing_id - 1) are used for the thing masks.
    pixel_gt_thing_mask = tf.cast(
        tf.not_equal(pixel_gt_thing_id_map, -1), tf.float32)
    pixel_gt_thing_id_one_hot = tf.one_hot(pixel_gt_thing_id_map,
                                           self._max_thing_id)
    # Compute pixel_gt_mask_id_one_hot by concatenating thing masks with stuff
    # masks.
    pixel_gt_mask_id_one_hot = tf.concat([pixel_gt_thing_id_one_hot,
                                          pixel_gt_stuff_id_one_hot], axis=-1)
    pixel_gt_mask_id_one_hot = tf.ensure_shape(
        pixel_gt_mask_id_one_hot,
        [batch_size, output_height * output_width, self._pixel_gt_num_mask_id])
    # Compute mask_gt_area by summing the one hot encodings spatially.
    mask_gt_area = tf.expand_dims(
        tf.reduce_sum(pixel_gt_mask_id_one_hot, axis=1), axis=-1)
    # Generate a binary mask for ground truth masks, indicating whether each
    # ground truth mask exists in the pixel space with a non-zero area. Note
    # that a mask that exists in the original input resolution will be removed
    # if its area is zero in the output resolution, due to downsampling.
    mask_gt_area_mask = tf.reshape(mask_gt_area > 0.5,
                                   [batch_size, self._pixel_gt_num_mask_id])
    # Compute mask_gt_semantic_map and mask_gt_semantic_one_hot.
    thing_id_gt_semantic_map = tf.reshape(
        tf.cast(y_true[common.GT_THING_ID_CLASS_KEY], tf.int32),
        [batch_size, self._max_thing_id])
    # The stuff ground truth semantic map is just the stuff class IDs.
    stuff_id_gt_semantic_map = tf.tile(
        tf.reshape(
            tf.cast(self._stuff_class_ids, tf.int32),
            [1, self._num_stuff_classes]), [batch_size, 1])
    mask_gt_semantic_map = tf.concat(
        [thing_id_gt_semantic_map, stuff_id_gt_semantic_map], axis=-1)
    # Set masks with zero area to void (-1), which is consistent with the void
    # label used in common.GT_THING_ID_CLASS_KEY but is different from the
    # ignore_labels of the datasets.
    mask_gt_semantic_map = (
        (mask_gt_semantic_map + 1) * tf.cast(mask_gt_area_mask, tf.int32) - 1)
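    # For example, a mask with semantic label 7 keeps its label if its area
    # mask is 1 ((7 + 1) * 1 - 1 = 7) and becomes void if its area mask is 0
    # ((7 + 1) * 0 - 1 = -1); an already padded mask (label -1) stays -1.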
    # Void (-1) classes will automatically be ignored by tf.one_hot.
    mask_gt_semantic_one_hot = tf.one_hot(mask_gt_semantic_map, one_hot_depth)
    mask_gt_semantic_one_hot = tf.gather(
        mask_gt_semantic_one_hot, self._thing_stuff_class_ids, axis=-1)
    # Compute mask_gt_non_void_mask. Again, a mask that exists in the original
    # input resolution is set to void if its area is zero in the output
    # resolution, due to downsampling.
    mask_gt_non_void_mask = tf.cast(mask_gt_semantic_map > -1, tf.float32)
    mask_gt_non_void_mask = tf.ensure_shape(
        mask_gt_non_void_mask, [batch_size, self._pixel_gt_num_mask_id])
    return (pixel_gt_thing_mask, pixel_gt_non_void_mask,
            pixel_gt_mask_id_one_hot, mask_gt_semantic_map,
            mask_gt_non_void_mask, mask_gt_semantic_one_hot, mask_gt_area)
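
  # Illustrative shapes for _pre_process_ground_truth (hypothetical values):
  # with max_thing_id = 2 and num_stuff_classes = 3, pixel_gt_num_mask_id = 5,
  # so pixel_gt_mask_id_one_hot has shape
  # [batch, output_height * output_width, 5], where the first 2 channels are
  # thing masks and the last 3 are stuff masks.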

  def call(
      self, inputs: Tuple[Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]
  ) -> Dict[Text, tf.Tensor]:
    """Computes the MaX-DeepLab losses.

    Args:
      inputs: A tuple of two dicts (y_true, y_pred):
      - y_true: A dict of tensors providing ground-truth information,
        containing
        - common.GT_SEMANTIC_KEY: A [batch, height, width] int32 tf.Tensor,
          the semantic label map.
        - common.GT_THING_ID_MASK_KEY: A [batch, height, width] int32
          tf.Tensor. It assigns each non-crowd thing instance a unique mask-ID
          label, starting from 0. Unassigned pixels are set to -1.
        - common.GT_THING_ID_CLASS_KEY: A [batch, max_thing_id] int32
          tf.Tensor. It contains the semantic ID of each instance assigned to
          thing_id_mask. The remaining (max_thing_id - num_things) elements
          are set to -1.
      - y_pred: A dict of tensors providing predictions.
        - common.PRED_PIXEL_SPACE_NORMALIZED_FEATURE_KEY: A [batch_size,
          output_height, output_width, channels] float32 tensor.
        - common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY: A [batch_size,
          output_height, output_width, num_mask_slots] float32 tensor, the
          logits that a pixel belongs to a mask slot.
        - common.PRED_TRANSFORMER_CLASS_LOGITS_KEY: A [batch_size,
          num_mask_slots, num_thing_stuff_classes + 1] float32 tensor, the
          logits that a mask belongs to a semantic class (including thing,
          stuff, and void).

    Returns:
      The loss as a dict of tf.Tensor, optionally containing the following:
      - common.PQ_STYLE_LOSS_CLASS_TERM: [batch].
      - common.PQ_STYLE_LOSS_MASK_DICE_TERM: [batch].
      - common.MASK_ID_CROSS_ENTROPY_LOSS: [batch].
      - common.INSTANCE_DISCRIMINATION_LOSS: [batch].
    """
    y_true, y_pred = inputs
    resulting_dict = {}
    pixel_feature = y_pred[common.PRED_PIXEL_SPACE_NORMALIZED_FEATURE_KEY]
    batch_size, output_height, output_width, _ = (
        pixel_feature.get_shape().as_list())
    # Pre-process the ground truth.
    (pixel_gt_thing_mask, pixel_gt_non_void_mask, pixel_gt_mask_id_one_hot,
     mask_gt_semantic_map, mask_gt_non_void_mask, mask_gt_semantic_one_hot,
     mask_gt_area) = self._pre_process_ground_truth(
         y_true, output_height, output_width)
    pixel_gt_non_void_mask_expanded = tf.expand_dims(
        pixel_gt_non_void_mask, axis=-1)
    # Compute mask_average_feature by averaging the feature of each mask.
    pixel_feature = tf.reshape(
        pixel_feature, [batch_size, output_height * output_width, -1])
    mask_average_feature = tf.einsum(
        'bpd,bpi->bid',
        pixel_feature,
        pixel_gt_mask_id_one_hot) / tf.maximum(mask_gt_area, 1.0)
    # Normalize the mask feature as the pixel space output feature is usually
    # normalized too.
    mask_average_feature = tf.math.l2_normalize(mask_average_feature, axis=-1)
    # Compute instance_discrimination_similarity, scaled by a constant
    # temperature.
    instance_discrimination_similarity = tf.einsum(
        'bpd,bid->bpi', pixel_feature, mask_average_feature)
    instance_discrimination_similarity /= (
        self._instance_discrimination_temperature)
    mask_gt_non_void_mask_expanded_1 = tf.expand_dims(
        mask_gt_non_void_mask, axis=1)
    # Mask void masks by setting them to a large negative value, so that they
    # will be ignored by the softmax in the loss.
    instance_discrimination_similarity = (
        mask_gt_non_void_mask_expanded_1 * instance_discrimination_similarity +
        (1.0 - mask_gt_non_void_mask_expanded_1) * _SOFTMAX_MASKING_CONSTANT)
    # Auxiliary instance_discrimination_loss.
    if self._instance_discrimination_loss_weight > 0.0:
      resulting_dict[common.INSTANCE_DISCRIMINATION_LOSS] = (
          self._instance_discrimination_loss(
              {_GT_KEY: pixel_gt_mask_id_one_hot},
              {_PRED_KEY: instance_discrimination_similarity,
               _WEIGHT_KEY: pixel_gt_thing_mask}) *
          self._instance_discrimination_loss_weight)
    # Extract pixel_space_mask_logits and pixel_space_mask_probs.
    pixel_space_mask_logits = y_pred[common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY]
    pixel_space_mask_logits = tf.reshape(
        pixel_space_mask_logits,
        [batch_size, output_height * output_width, self._num_mask_slots])
    pixel_space_mask_probs = tf.nn.softmax(pixel_space_mask_logits, axis=-1)
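    # Note that the softmax is taken over the mask slot dimension, so for each
    # pixel the predicted mask probabilities sum to one, which satisfies the
    # assumption of _mask_similarity that each prediction vector sums to at
    # most one.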
    # Compute the mask similarity between all ground truth masks and all
    # predicted masks.
    mask_similarity = _mask_similarity(
        pixel_gt_mask_id_one_hot,
        pixel_space_mask_probs * pixel_gt_non_void_mask_expanded,
        metric='dice')
    # Compute the class similarity by multiplying the ground truth one hot
    # encoding with the predicted probability distribution. This is done
    # between all ground truth masks and all predicted masks.
    transformer_class_logits = y_pred[common.PRED_TRANSFORMER_CLASS_LOGITS_KEY]
    transformer_class_probs = tf.nn.softmax(
        transformer_class_logits, axis=-1)[:, :, :-1]
    class_similarity = tf.einsum(
        'bij,bkj->bik', mask_gt_semantic_one_hot, transformer_class_probs)
    # Compute hungarian matching weights. We take the negative here since the
    # hungarian matching algorithm looks for the matching with the least total
    # weight.
    hungarian_weights = - mask_similarity * class_similarity
    mask_gt_non_void_mask_expanded_2 = tf.expand_dims(
        mask_gt_non_void_mask, axis=2)
    # Mask the void ground truth masks (in the rows) so that they do not affect
    # the result of the hungarian matching.
    if self._num_mask_slots >= self._pixel_gt_num_mask_id:
      # If the number of mask slots (number of columns) is larger than the
      # constant number of ground truth masks (number of rows), the
      # nonsquare_hungarian_matching will pad the rows with
      # _MATCHING_NEGATIVE_CONSTANT. In this case, we can fill in the void mask
      # rows with _MATCHING_NEGATIVE_CONSTANT too, then the void mask rows will
      # be ignored too, according to the hungarian matching property.
      hungarian_weights = (
          hungarian_weights * mask_gt_non_void_mask_expanded_2 +
          (1 - mask_gt_non_void_mask_expanded_2) * _MATCHING_NEGATIVE_CONSTANT)
    else:
      # If the number of mask slots (number of columns) is smaller than the
      # constant number of ground truth masks (number of rows), the
      # nonsquare_hungarian_matching will pad the columns with
      # _MATCHING_NEGATIVE_CONSTANT. In this case, we should fill in the void
      # mask rows with _MATCHING_POSITIVE_CONSTANT here, then the void mask
      # rows will have a huge cost compared with existing non-void mask rows,
      # so that the predicted masks will prefer matching with existing
      # non-void masks rather than the padded void masks, according to the
      # hungarian matching property.
      hungarian_weights = (
          hungarian_weights * mask_gt_non_void_mask_expanded_2 +
          (1 - mask_gt_non_void_mask_expanded_2) * _MATCHING_POSITIVE_CONSTANT)
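    # At this point hungarian_weights has shape
    # [batch, pixel_gt_num_mask_id, num_mask_slots], so the matching below
    # returns a full_permutation of shape [batch, K, K] with
    # K = max(pixel_gt_num_mask_id, num_mask_slots), and a
    # nonsquare_permutation of the original shape.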
    # Perform the hungarian matching algorithm.
    full_permutation, nonsquare_permutation = (
        nonsquare_hungarian_matching(hungarian_weights))
    # Extract the permutation (matching) between all existing non-void ground
    # truth masks and the matched predicted masks.
    matched_permutation = (
        nonsquare_permutation * mask_gt_non_void_mask_expanded_2)
    # The matched mask dice scores for each mask slot. The scores will be used
    # as a loss weight for the PQ-style loss class term after the
    # stop_gradient.
    matched_mask_dice = tf.reduce_max(
        mask_similarity * matched_permutation, axis=-2)
    matched_mask_dice = tf.stop_gradient(matched_mask_dice)
    # The matched class probabilities for each ground truth mask. The
    # probabilities will be used as a loss weight for the PQ-style loss mask
    # dice term after the stop_gradient.
    matched_class_prob = tf.reduce_max(
        class_similarity * matched_permutation, axis=-1)
    matched_class_prob = tf.stop_gradient(matched_class_prob)
    # Extract the index of the matched mask slot for each ground truth mask.
    matched_mask_slot_indices = tf.math.argmax(
        nonsquare_permutation, axis=-1, output_type=tf.dtypes.int32)
    full_num_mask_slots = full_permutation.get_shape().as_list()[-1]
    # Pad the pixel_space_mask_logits so that it is compatible with the
    # permutation matrix.
    full_pixel_space_mask_logits = tf.pad(
        pixel_space_mask_logits,
        [[0, 0], [0, 0], [0, full_num_mask_slots - self._num_mask_slots]],
        constant_values=_SOFTMAX_MASKING_CONSTANT)
    # Permute the pixel space mask logits with the permutation matrix, which
    # converts the mask slot indices to the ground truth indices.
    permuted_full_pixel_space_mask_logits = tf.einsum(
        'bpi,bji->bpj', full_pixel_space_mask_logits, full_permutation)
    # Pad the class probabilities too.
    full_matched_class_prob = tf.pad(
        matched_class_prob,
        [[0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])
    # We only compute the dice loss term on non-void ground truth masks.
    mask_dice_term_loss_weight = tf.pad(
        mask_gt_non_void_mask,
        [[0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])
    # Use the class probabilities as the loss weight for the mask dice term. In
    # addition, we set a lower bound, 1e-5, for the mask dice term loss weight.
    # Otherwise, if a loss weight is accidentally zero, the dice loss will
    # treat it as void and use an incorrect denominator or normalizing constant
    # for the loss.
    mask_dice_term_loss_weight *= tf.maximum(full_matched_class_prob, 1e-5)
    # Pad the one hot encoding too.
    full_pixel_gt_mask_id_one_hot = tf.pad(
        pixel_gt_mask_id_one_hot,
        [[0, 0], [0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])
    if self._pq_style_loss_weight > 0.0:
      # Mask_dice_term_modifier balances the mask_dice_term and the class_term
      # of the PQ-style loss to have the same weight and normalizing constant.
      resulting_dict[common.PQ_STYLE_LOSS_MASK_DICE_TERM] = (
          self._pq_style_loss_mask_dice_term(
              {_GT_KEY: full_pixel_gt_mask_id_one_hot},
              {_PRED_KEY: permuted_full_pixel_space_mask_logits,
               _WEIGHT_KEY: mask_dice_term_loss_weight}) *
          (self._pq_style_loss_weight * self._mask_dice_term_modifier))
    # Mask-ID cross entropy loss shares the same ground truth and logits as the
    # dice loss term, but with different weights.
    if self._mask_id_cross_entropy_loss_weight > 0.0:
      resulting_dict[common.MASK_ID_CROSS_ENTROPY_LOSS] = (
          self._mask_id_cross_entropy_loss(
              {_GT_KEY: full_pixel_gt_mask_id_one_hot},
              {_PRED_KEY: permuted_full_pixel_space_mask_logits,
               _WEIGHT_KEY: pixel_gt_non_void_mask}) *
          self._mask_id_cross_entropy_loss_weight)
    # Generate a pseudo ground truth for transformer_class_logits.
    mask_slot_semantic_one_hot = _generate_mask_slot_semantic_one_hot(
        matched_mask_slot_indices, mask_gt_semantic_map,
        self._num_mask_slots, self._thing_stuff_class_ids)
    # Compute the positive mask and the negative mask.
    mask_slot_positive_mask = tf.cast(tf.equal(tf.reduce_max(
        mask_slot_semantic_one_hot, axis=-1), 1.0), tf.float32)
    mask_slot_negative_mask = 1.0 - mask_slot_positive_mask
    # Compute the overlap ratio between all predicted masks and the void
    # region. This void ratio will be used as a weight for the negative class
    # term.
    mask_void_ratio = tf.stop_gradient(_mask_similarity(
        1.0 - pixel_gt_non_void_mask_expanded,
        pixel_space_mask_probs,
        'intersection_over_prediction'))
    mask_void_ratio = tf.squeeze(mask_void_ratio, axis=1)
    # Use the matched mask dice scores as the weights for the positive class
    # terms. For the negative class terms, we reduce the penalty for a mask
    # slot class term if the mask prediction overlaps a lot with void regions.
    transformer_class_loss_weight = (
        mask_slot_positive_mask * tf.maximum(matched_mask_dice, 1e-5) +
        mask_slot_negative_mask * tf.maximum(mask_void_ratio, 1e-5))
    # Concatenate the void mask in the last channel, constructing the final
    # ground truth one hot label with (thing + stuff + void) channels.
    transformer_class_one_hot = tf.concat(
        [mask_slot_semantic_one_hot,
         tf.expand_dims(mask_slot_negative_mask, axis=-1)], axis=-1)
    # Apply the PQ-style loss class term.
    if self._pq_style_loss_weight > 0.0:
      resulting_dict[common.PQ_STYLE_LOSS_CLASS_TERM] = (
          self._pq_style_loss_class_term(
              {_GT_KEY: transformer_class_one_hot},
              {_PRED_KEY: transformer_class_logits,
               _WEIGHT_KEY: transformer_class_loss_weight}) *
          self._pq_style_loss_weight)
    return resulting_dict