# maxdeeplab/model/loss/max_deeplab_loss.py
# Copied from https://huggingface.co/spaces/akhaliq/deeplab2
# (commit 0924f30, karolmajek).
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains the loss functions for MaX-DeepLab models.
Reference:
MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
CVPR 2021. https://arxiv.org/abs/2012.00759
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""
from typing import Text, Dict, Tuple, List
import tensorflow as tf
from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.model import utils
from deeplab2.model.loss import base_loss
from deeplab2.model.loss import matchers_ops
# Sentinel weights used to pad or mask the Hungarian matching weights. The
# matching looks for the assignment with the least total weight, so these large
# magnitudes dominate any real (bounded) matching weight.
_MATCHING_POSITIVE_CONSTANT = 999.0
_MATCHING_NEGATIVE_CONSTANT = -_MATCHING_POSITIVE_CONSTANT
# Large negative logit added before a softmax so that the masked entries end up
# with (numerically) zero probability.
_SOFTMAX_MASKING_CONSTANT = -99999.0
# Dictionary keys used when feeding tensors into the base loss functions.
_GT_KEY = 'gt_key'
_PRED_KEY = 'pred_key'
_WEIGHT_KEY = 'weight_key'
def _generate_mask_slot_semantic_one_hot(
    matched_mask_slot_indices: tf.Tensor,
    mask_gt_semantic_map: tf.Tensor,
    num_mask_slots: int,
    thing_stuff_class_ids: List[int]):
  """Builds the positive class targets for the transformer class head.

  Every ground truth mask has already been (Hungarian-)matched to a mask slot.
  This function writes a one hot encoding of each ground truth mask's semantic
  class into the row of its matched mask slot, yielding a pseudo ground truth
  for transformer_class_logits. Only the positive (thing + stuff) channels are
  produced; the void channel is appended by the caller.

  Args:
    matched_mask_slot_indices: An int32 tf.Tensor of shape [batch_size,
      num_ground_truth_masks] holding, for each ground truth mask, the index of
      the mask slot it was matched to.
    mask_gt_semantic_map: An int32 tf.Tensor of shape [batch_size,
      num_ground_truth_masks] holding the semantic label of each ground truth
      mask. Padded (void / no object) masks carry the label -1; they are routed
      to a scratch channel that is dropped from the output (assuming all class
      IDs are non-negative).
    num_mask_slots: An integer, the number of mask slots of the MaX-DeepLab
      model.
    thing_stuff_class_ids: A list of integers of length [num_thing_classes +
      num_stuff_classes]: the thing class IDs followed by the stuff class IDs.
      It fixes the order of the output channels.

  Returns:
    mask_slot_semantic_one_hot: A float32 tf.Tensor with shape [batch_size,
      num_mask_slots, num_thing_classes + num_stuff_classes].
  """
  static_shape = mask_gt_semantic_map.get_shape().as_list()
  batch_size, num_ground_truth_masks = static_shape[0], static_shape[-1]
  num_entries = batch_size * num_ground_truth_masks

  # One (batch, mask_slot, shifted_class) coordinate per ground truth mask.
  # Labels are shifted by +1 so that void (-1) becomes the valid channel 0;
  # otherwise tf.scatter_nd raises an error when it runs on CPU.
  batch_coords = tf.repeat(tf.range(batch_size), num_ground_truth_masks)
  slot_coords = tf.reshape(matched_mask_slot_indices, [-1])
  class_coords = tf.reshape(mask_gt_semantic_map, [-1]) + 1
  scatter_indices = tf.stack([batch_coords, slot_coords, class_coords], axis=1)

  # Scatter ones onto an all-zero canvas wide enough for every shifted class
  # ID: channel 0 for void, up to max(class ID) + 1.
  canvas_shape = [batch_size, num_mask_slots, max(thing_stuff_class_ids) + 2]
  shifted_one_hot = tf.scatter_nd(
      scatter_indices,
      tf.ones([num_entries], dtype=tf.float32),
      shape=canvas_shape)

  # Keep only the (thing + stuff) channels in that order. Adding one undoes
  # the shift above, which also discards the void channel 0.
  shifted_class_ids = [class_id + 1 for class_id in thing_stuff_class_ids]
  return tf.gather(shifted_one_hot, shifted_class_ids, axis=2)
def nonsquare_hungarian_matching(
    weights: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  """Runs Hungarian matching on a possibly non-square weight matrix.

  matchers_ops.hungarian_matching only accepts square weight matrices, so the
  [shape1, shape2] weights are padded with _MATCHING_NEGATIVE_CONSTANT up to a
  [max(shape1, shape2), max(shape1, shape2)] square. All padded rows (or
  columns) hold the same constant, so they add the same total weight to every
  complete matching. The optimal matching of the original entries is
  therefore unchanged.

  Args:
    weights: A [batch, shape1, shape2] float32 tf.Tensor.

  Returns:
    square_permutation: A [batch, max(shape1, shape2), max(shape1, shape2)]
      float32 tf.Tensor, the permutation matrix that achieves the minimum
      total weight. It contains only 0.0 and 1.0, and every row and every
      column sums to 1.0.
    nonsquare_permutation: A [batch, shape1, shape2] float32 tf.Tensor, the
      top-left (original, unpadded) part of square_permutation.
  """
  num_rows, num_cols = weights.get_shape().as_list()[1:]
  side = max(num_rows, num_cols)
  # Pad only the trailing rows/columns; the batch axis is left untouched.
  square_weights = tf.pad(
      weights,
      [[0, 0], [0, side - num_rows], [0, side - num_cols]],
      constant_values=_MATCHING_NEGATIVE_CONSTANT)
  square_permutation = tf.cast(
      matchers_ops.hungarian_matching(square_weights), tf.float32)
  nonsquare_permutation = square_permutation[:, :num_rows, :num_cols]
  return square_permutation, nonsquare_permutation
def _mask_similarity(gt_mask: tf.Tensor, pred_mask: tf.Tensor,
                     metric: str = 'dice') -> tf.Tensor:
  """Measures how similar every ground truth mask is to every predicted mask.

  Each metric divides the soft intersection
  sum_p gt_mask[b, p, i] * pred_mask[b, p, j] by a metric-specific denominator
  plus an epsilon of 1e-5:
    - 'dice': the mean of the ground truth and predicted mask areas.
    - 'iou': the union, i.e. the sum of both areas minus the intersection.
    - 'intersection_over_ground_truth': the ground truth mask area.
    - 'intersection_over_prediction': the predicted mask area.
  The metric name is matched case-insensitively.

  Args:
    gt_mask: A [batch, height * width, num_gt_masks] float32 tf.Tensor
      containing only 0.0 and 1.0, where 1.0 marks a pixel that belongs to the
      ground truth mask. Panoptic segmentation ground truth masks do not
      overlap.
    pred_mask: A [batch, height * width, num_pred_masks] positive float32
      tf.Tensor. For each batch_id and pixel_id, the [num_pred_masks] vector
      encodes how much the pixel belongs to each mask, and sums to at most one.
    metric: A string naming the similarity metric: 'dice' (default), 'iou',
      'intersection_over_ground_truth' or 'intersection_over_prediction'.

  Returns:
    mask_similarity: A float32 [batch, num_gt_masks, num_pred_masks] tf.Tensor
      with the similarity between every ground truth mask and every predicted
      mask.

  Raises:
    ValueError: If metric is not one of the supported names above.
  """
  intersection = tf.einsum('bpi,bpj->bij', gt_mask, pred_mask)

  def gt_area():
    # Ground truth areas laid out as [batch, num_gt_masks, 1] for broadcasting.
    return tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2)

  def pred_area():
    # Predicted areas laid out as [batch, 1, num_pred_masks] for broadcasting.
    return tf.reduce_sum(pred_mask, axis=1, keepdims=True)

  # Denominators are built lazily so only the requested one is computed.
  denominator_builders = {
      'dice': lambda: (gt_area() + pred_area()) / 2,
      'iou': lambda: gt_area() + pred_area() - intersection,
      'intersection_over_ground_truth': gt_area,
      'intersection_over_prediction': pred_area,
  }
  build_denominator = denominator_builders.get(metric.lower())
  if build_denominator is None:
    raise ValueError('The mask similarity metric is not supported.')
  return intersection / (build_denominator() + 1e-5)
class MaXDeepLabLoss(tf.keras.layers.Layer):
  """This class contains code for MaX-DeepLab losses.

  The layer is called with a (y_true, y_pred) tuple and returns a dict that
  maps each enabled loss name (listed in self.loss_terms) to its weighted loss.
  """

  def __init__(self,
               loss_options: config_pb2.LossOptions,
               ignore_label: int,
               thing_class_ids: Tuple[int, ...],
               focal_loss_alpha: float = 0.75,
               instance_discrimination_temperature: float = 0.3):
    """Initializes a MaX-DeepLab loss.

    This class supports PQ-style loss, mask id cross entropy loss, and instance
    discrimination loss, proposed in MaX-DeepLab. The PQ-style loss can be
    further decomposed into a classification term and a mask dice term.

    Reference:
      MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.

    Args:
      loss_options: Loss options as defined by config_pb2.LossOptions. A loss
        is enabled (and its terms appended to self.loss_terms) only if the
        corresponding field is set; its weight is read from that field.
        Unset losses keep a weight of 0.0 and are skipped in call().
      ignore_label: An integer specifying the ignore label.
      thing_class_ids: A tuple of length [N] containing N thing indices.
      focal_loss_alpha: An optional float specifying the coefficient that
        weights between positive (matched) and negative (unmatched) masks in
        focal loss. The positives are weighted by alpha, while the negatives
        are weighted by (1. - alpha). Note that we do not use a focal loss
        gamma here, i.e., the gamma is set to zero which is equivalent to the
        normal cross-entropy loss, except for the alpha weighting. The same
        alpha also scales the PQ-style mask dice term. Defaults to 0.75.
      instance_discrimination_temperature: An optional float specifying the
        temperature for the instance discrimination loss. The similarity
        logits are divided by it. Defaults to 0.3.
    """
    super(MaXDeepLabLoss, self).__init__(name='MaXDeepLabLoss')
    # The loss_terms will optionally include
    # - common.PQ_STYLE_LOSS_CLASS_TERM
    # - common.PQ_STYLE_LOSS_MASK_DICE_TERM
    # - common.MASK_ID_CROSS_ENTROPY_LOSS
    # - common.INSTANCE_DISCRIMINATION_LOSS
    # These loss terms will be accessed by loss_builder.py and will be used to
    # build loss metrics.
    self.loss_terms = []
    # The PQ-style loss includes two terms.
    self._pq_style_loss_weight = 0.0
    if loss_options.HasField(common.PQ_STYLE_LOSS):
      self._pq_style_loss_weight = loss_options.pq_style_loss.weight
      self.loss_terms.append(common.PQ_STYLE_LOSS_CLASS_TERM)
      self.loss_terms.append(common.PQ_STYLE_LOSS_MASK_DICE_TERM)
    # Mask-ID cross entropy loss.
    self._mask_id_cross_entropy_loss_weight = 0.0
    if loss_options.HasField(common.MASK_ID_CROSS_ENTROPY_LOSS):
      self._mask_id_cross_entropy_loss_weight = (
          loss_options.mask_id_cross_entropy_loss.weight)
      self.loss_terms.append(common.MASK_ID_CROSS_ENTROPY_LOSS)
    # Instance discrimination loss.
    self._instance_discrimination_loss_weight = 0.0
    if loss_options.HasField(common.INSTANCE_DISCRIMINATION_LOSS):
      self._instance_discrimination_loss_weight = (
          loss_options.instance_discrimination_loss.weight)
      self.loss_terms.append(common.INSTANCE_DISCRIMINATION_LOSS)
    self._ignore_label = ignore_label
    self._thing_class_ids = list(thing_class_ids)
    self._focal_loss_alpha = focal_loss_alpha
    self._instance_discrimination_temperature = (
        instance_discrimination_temperature)
    # Build the base loss functions. All of them read their inputs from the
    # _GT_KEY / _PRED_KEY / _WEIGHT_KEY entries of the dicts passed in call().
    self._pq_style_loss_class_term = base_loss.FocalCrossEntropyLoss(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY,
        # Num_classes and ignore_label are not necessary since the inputs will
        # be one hot encoded already.
        num_classes=None, ignore_label=None,
        focal_loss_alpha=focal_loss_alpha,
        focal_loss_gamma=0.0, background_channel_index=-1,
        dynamic_weight=True)
    self._pq_style_loss_mask_dice_term = base_loss.MaskDiceLoss(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY,
        prediction_activation='softmax')
    self._mask_id_cross_entropy_loss = base_loss.TopKCrossEntropyLoss(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY,
        # Num_classes and ignore_label are not necessary since the inputs will
        # be one hot encoded already.
        num_classes=None, ignore_label=None,
        top_k_percent_pixels=1.0, dynamic_weight=True)
    self._instance_discrimination_loss = base_loss.TopKCrossEntropyLoss(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY,
        # Num_classes and ignore_label are not necessary since the inputs will
        # be one hot encoded already.
        num_classes=None, ignore_label=None,
        top_k_percent_pixels=1.0, dynamic_weight=True)

  def build(self,
            input_shapes: Tuple[Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]):
    """Extracts useful constants that depend on the input shapes.

    Args:
      input_shapes: A tuple (y_true_shapes, y_pred_shapes) of dicts keyed like
        the (y_true, y_pred) inputs of call(). The values are indexed as
        shapes here (NOTE(review): presumably static shapes rather than
        tensors, despite the annotation -- confirm).
    """
    y_true_shapes = input_shapes[0]
    # GT_THING_ID_CLASS_KEY is [batch, max_thing_id]: the padded maximum number
    # of thing instances per image.
    self._max_thing_id = int(y_true_shapes[common.GT_THING_ID_CLASS_KEY][-1])
    y_pred_shapes = input_shapes[1]
    transformer_class_logits_shape = y_pred_shapes[
        common.PRED_TRANSFORMER_CLASS_LOGITS_KEY]
    # transformer_class_logits is [batch, num_mask_slots,
    # num_thing_stuff_classes + 1].
    self._num_mask_slots = int(transformer_class_logits_shape[1])
    # The transformer_class_logits contain thing classes, stuff classes, and the
    # void class, so num_thing_stuff_classes should be the total number of
    # classes minus one.
    self._num_thing_stuff_classes = int(transformer_class_logits_shape[2]) - 1
    # Since we implement the PQ-style loss with the class term plus the mask
    # dice term (Equation 10 of the paper), we need to balance the two terms to
    # have the same weight and normalizing constants. The focal loss alpha is
    # a weight on the positive class term, so we apply it to the mask dice term
    # too. The class loss is also normalized by the number of mask slots, so we
    # do the same normalization for the mask dice term.
    self._mask_dice_term_modifier = (
        self._focal_loss_alpha / self._num_mask_slots)
    # Derive the stuff class IDs from the total class count, the thing class
    # IDs and the ignore label (see utils.get_stuff_class_ids).
    self._stuff_class_ids = utils.get_stuff_class_ids(
        self._num_thing_stuff_classes,
        self._thing_class_ids,
        self._ignore_label)
    self._num_stuff_classes = len(self._stuff_class_ids)
    # Class channel order used throughout this loss: things first, then stuff.
    self._thing_stuff_class_ids = self._thing_class_ids + self._stuff_class_ids
    # Number of ground truth masks per image: max_thing_id (padded) thing masks
    # followed by one mask per stuff class.
    self._pixel_gt_num_mask_id = self._max_thing_id + self._num_stuff_classes

  def _pre_process_ground_truth(
      self, y_true: Dict[Text, tf.Tensor], output_height: int, output_width: int
  ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor,
             tf.Tensor]:
    """Pre-processes the ground truth before we compute the losses.

    This function generates tensors that do not depend on the prediction of the
    model, but are useful to the calculation of the losses. The function mainly
    downsamples the pixel space ground truth to the model output resolution, and
    combines (or concatenates) the thing masks and the stuff masks. The output
    shape pixel_gt_num_mask_id = max_thing_id + num_stuff_classes, which means
    the output masks contain both thing masks and stuff masks.

    Args:
      y_true: A dict of tensors providing ground-truth information, containing
       - common.GT_SEMANTIC_KEY: A [batch, height, width] int32 tf.Tensor, the
         semantic label map.
       - common.GT_THING_ID_MASK_KEY: A [batch, height, width] int32 tf.Tensor.
         It assigns each non-crowd thing instance a unique mask-ID label,
         starting from 0. Unassigned pixels are set to -1.
       - common.GT_THING_ID_CLASS_KEY: A [batch, max_thing_id] int32 tf.Tensor.
         It contains semantic ID of each instance assigned to thing_id_mask. The
         remaining (max_thing_id - num_things) elements are set to -1.
      output_height: An integer, the height of the model output.
      output_width: An integer, the width of the model output.

    Returns:
      pixel_gt_thing_mask: A [batch, output_height * output_width] float32
        tensor, with value 0.0 and 1.0 only, indicating whether a pixel belongs
        to a 'thing' class.
      pixel_gt_non_void_mask: A [batch, output_height * output_width] float32
        tensor, with value 0.0 and 1.0 only, indicating if a pixel does not
        belong to the void class.
      pixel_gt_mask_id_one_hot: A [batch, output_height * output_width,
        pixel_gt_num_mask_id] float32 tensor, with value 0.0 and 1.0 only,
        indicating the mask id each pixel belongs to.
      mask_gt_semantic_map: A [batch, pixel_gt_num_mask_id] int32 tensor, the
        semantic class of each ground truth mask, or -1 for void masks (padded
        masks and masks with zero area at the output resolution).
      mask_gt_non_void_mask: A [batch, pixel_gt_num_mask_id] float32 tensor,
        with value 0.0 and 1.0 only, indicating if the ground truth mask is a
        valid mask, not a padded mask. The masks are padded because TPU does not
        support dynamic shapes except in the batch axis. We pad all ground truth
        thing masks to a large enough constant max_thing_id. Similarly, stuff
        classes that are not present in the current image will be set to a void
        mask too.
      mask_gt_semantic_one_hot: A [batch, pixel_gt_num_mask_id,
        num_thing_stuff_classes] float32 tensor, with value 0.0 and 1.0 only,
        containing the one hot encodings of the ground truth mask classes. The
        last dimension contains concatenated thing classes and stuff classes,
        which is different from the dataset class IDs in mask_gt_semantic_map.
      mask_gt_area: A [batch, pixel_gt_num_mask_id, 1] float32 tensor, the area
        (pixel count at the output resolution) of each ground truth mask.
        Padded masks have an area of 0.0.
    """
    # The depth of one hot encoding should be the largest id plus one. For
    # example, if we want to one-hot encode a class ID of 133 (the largest ID
    # for the COCO dataset), we will need a one-hot encoding of length 134.
    one_hot_depth = max(self._thing_stuff_class_ids) + 1
    batch_size = y_true[common.GT_SEMANTIC_KEY].get_shape().as_list()[0]
    # Compute pixel_gt_semantic_map (downsampling and reshaping to the 1D
    # representation that will be mainly used in this loss function).
    pixel_gt_semantic_map = utils.strided_downsample(
        y_true[common.GT_SEMANTIC_KEY],
        target_size=[output_height, output_width])
    pixel_gt_semantic_map = tf.reshape(
        pixel_gt_semantic_map,
        [batch_size, output_height * output_width])
    # Compute pixel_gt_non_void_mask.
    pixel_gt_non_void_mask = tf.cast(
        tf.not_equal(pixel_gt_semantic_map, self._ignore_label), tf.float32)
    pixel_gt_non_void_mask = tf.ensure_shape(
        pixel_gt_non_void_mask,
        [batch_size, output_height * output_width])
    # Compute pixel_gt_semantic_one_hot from pixel_gt_semantic_map in order to
    # gather pixel_gt_stuff_id_one_hot from pixel_gt_semantic_one_hot.
    pixel_gt_semantic_one_hot = tf.one_hot(pixel_gt_semantic_map, one_hot_depth)
    # Convert the one hot encoding from the dataset id order to (thing, stuff)
    # order used in MaX-DeepLab.
    pixel_gt_stuff_id_one_hot = tf.gather(pixel_gt_semantic_one_hot,
                                          self._stuff_class_ids, axis=-1)
    pixel_gt_stuff_id_one_hot = tf.ensure_shape(
        pixel_gt_stuff_id_one_hot,
        [batch_size, output_height * output_width, self._num_stuff_classes])
    # Compute pixel_gt_thing_id_one_hot for thing masks.
    pixel_gt_thing_id_map = utils.strided_downsample(
        y_true[common.GT_THING_ID_MASK_KEY],
        target_size=[output_height, output_width])
    pixel_gt_thing_id_map = tf.reshape(
        pixel_gt_thing_id_map, shape=[batch_size, output_height * output_width])
    # Note that common.GT_THING_ID_MASK_KEY uses -1 for void masks, and 0 to
    # (max_thing_id - 1) for the thing masks (max_thing_id is the one hot depth
    # used below).
    pixel_gt_thing_mask = tf.cast(
        tf.not_equal(pixel_gt_thing_id_map, -1), tf.float32)
    # tf.one_hot maps the -1 (void) entries to all-zero vectors.
    pixel_gt_thing_id_one_hot = tf.one_hot(pixel_gt_thing_id_map,
                                           self._max_thing_id)
    # Compute pixel_gt_mask_id_one_hot by concatenating thing masks with stuff
    # masks.
    pixel_gt_mask_id_one_hot = tf.concat([pixel_gt_thing_id_one_hot,
                                          pixel_gt_stuff_id_one_hot], axis=-1)
    pixel_gt_mask_id_one_hot = tf.ensure_shape(
        pixel_gt_mask_id_one_hot,
        [batch_size, output_height * output_width, self._pixel_gt_num_mask_id])
    # Compute mask_gt_area by summing the one hot encodings spatially. The
    # result is [batch, pixel_gt_num_mask_id, 1].
    mask_gt_area = tf.expand_dims(
        tf.reduce_sum(pixel_gt_mask_id_one_hot, axis=1), axis=-1)
    # Generate a binary mask for ground truth masks, indicating whether each
    # ground truth mask exists in the pixel space with a non-zero area. Note
    # that a mask that exists in the original input resolution will be removed
    # if its area is zero in the output resolution, due to downsampling.
    mask_gt_area_mask = tf.reshape(mask_gt_area > 0.5,
                                   [batch_size, self._pixel_gt_num_mask_id])
    # Compute mask_gt_semantic_map and mask_gt_semantic_one_hot.
    thing_id_gt_semantic_map = tf.reshape(
        tf.cast(y_true[common.GT_THING_ID_CLASS_KEY], tf.int32),
        [batch_size, self._max_thing_id])
    # The stuff ground truth semantic map is just the stuff class IDs.
    stuff_id_gt_semantic_map = tf.tile(
        tf.reshape(
            tf.cast(self._stuff_class_ids, tf.int32),
            [1, self._num_stuff_classes]), [batch_size, 1])
    mask_gt_semantic_map = tf.concat(
        [thing_id_gt_semantic_map, stuff_id_gt_semantic_map], axis=-1)
    # Set masks with zero area to void (-1), which is consistent with the void
    # label used in common.GT_THING_ID_CLASS_KEY but is different from the
    # ignore_labels of the datasets. With the 0/1 area mask m:
    # (label + 1) * m - 1 is label when m == 1 and -1 when m == 0.
    mask_gt_semantic_map = (
        (mask_gt_semantic_map + 1) * tf.cast(mask_gt_area_mask, tf.int32) - 1)
    # Void (-1) classes will automatically be ignored by tf.one_hot.
    mask_gt_semantic_one_hot = tf.one_hot(mask_gt_semantic_map, one_hot_depth)
    mask_gt_semantic_one_hot = tf.gather(
        mask_gt_semantic_one_hot, self._thing_stuff_class_ids, axis=-1)
    # Compute mask_gt_non_void_mask. Again, a mask that exists in the original
    # input resolution is set to void if its area is zero in the output
    # resolution, due to downsampling.
    mask_gt_non_void_mask = tf.cast(mask_gt_semantic_map > -1, tf.float32)
    mask_gt_non_void_mask = tf.ensure_shape(
        mask_gt_non_void_mask, [batch_size, self._pixel_gt_num_mask_id])
    return (pixel_gt_thing_mask, pixel_gt_non_void_mask,
            pixel_gt_mask_id_one_hot, mask_gt_semantic_map,
            mask_gt_non_void_mask, mask_gt_semantic_one_hot, mask_gt_area)

  def call(
      self, inputs: Tuple[Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]
  ) -> Dict[Text, tf.Tensor]:
    """Computes the MaX-DeepLab losses.

    Args:
      inputs: A tuple of two dicts (y_true, y_pred):
      - y_true: A dict of tensors providing ground-truth information, containing
         - common.GT_SEMANTIC_KEY: A [batch, height, width] int32 tf.Tensor, the
           semantic label map.
         - common.GT_THING_ID_MASK_KEY: A [batch, height, width] int32
           tf.Tensor. It assigns each non-crowd thing instance a unique mask-ID
           label, starting from 0. Unassigned pixels are set to -1.
         - common.GT_THING_ID_CLASS_KEY: A [batch, max_thing_id] int32
           tf.Tensor. It contains semantic ID of each instance assigned to
           thing_id_mask. The remaining (max_thing_id - num_things) elements are
           set to -1.
      - y_pred: A dict of tensors providing predictions.
        - common.PRED_PIXEL_SPACE_NORMALIZED_FEATURE_KEY: A [batch_size,
          output_height, output_width, channels] float32 tensor.
        - common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY: A [batch_size,
          output_height, output_width, num_mask_slots] float32 tensor, the
          logits that a pixel belongs to a mask slot.
        - common.PRED_TRANSFORMER_CLASS_LOGITS_KEY: A [batch_size,
          num_mask_slots, num_thing_stuff_classes + 1] float32 tensor, the
          logits that a mask belongs to a semantic class (including thing,
          stuff, and void)

    Returns:
      The loss as a dict of tf.Tensor, optionally containing the following:
      - common.PQ_STYLE_LOSS_CLASS_TERM: [batch].
      - common.PQ_STYLE_LOSS_MASK_DICE_TERM: [batch].
      - common.MASK_ID_CROSS_ENTROPY_LOSS: [batch].
      - common.INSTANCE_DISCRIMINATION_LOSS: [batch].
      A term is present only if its loss weight is positive.
    """
    y_true, y_pred = inputs
    resulting_dict = {}
    pixel_feature = y_pred[common.PRED_PIXEL_SPACE_NORMALIZED_FEATURE_KEY]
    batch_size, output_height, output_width, _ = (
        pixel_feature.get_shape().as_list())
    # Pre-process the ground truth.
    (pixel_gt_thing_mask, pixel_gt_non_void_mask, pixel_gt_mask_id_one_hot,
     mask_gt_semantic_map, mask_gt_non_void_mask, mask_gt_semantic_one_hot,
     mask_gt_area) = self._pre_process_ground_truth(y_true,
                                                    output_height, output_width)
    pixel_gt_non_void_mask_expanded = tf.expand_dims(
        pixel_gt_non_void_mask, axis=-1)
    # Compute mask_average_feature by averaging the feature of each mask. The
    # area is clipped to at least 1.0 so that empty (padded) masks do not
    # divide by zero.
    pixel_feature = tf.reshape(
        pixel_feature, [batch_size, output_height * output_width, -1])
    mask_average_feature = tf.einsum(
        'bpd,bpi->bid',
        pixel_feature,
        pixel_gt_mask_id_one_hot) / tf.maximum(mask_gt_area, 1.0)
    # Normalize the mask feature as the pixel space output feature is usually
    # normalized too.
    mask_average_feature = tf.math.l2_normalize(mask_average_feature, axis=-1)
    # Compute instance_discrimination_similarity, scaled by a constant
    # temperature.
    instance_discrimination_similarity = tf.einsum(
        'bpd,bid->bpi', pixel_feature, mask_average_feature)
    instance_discrimination_similarity /= (
        self._instance_discrimination_temperature)
    mask_gt_non_void_mask_expanded_1 = tf.expand_dims(
        mask_gt_non_void_mask, axis=1)
    # Mask void masks by setting them to a large negative value, so that they
    # will be ignored by the softmax in the loss.
    instance_discrimination_similarity = (
        mask_gt_non_void_mask_expanded_1 * instance_discrimination_similarity +
        (1.0 - mask_gt_non_void_mask_expanded_1) * _SOFTMAX_MASKING_CONSTANT)
    # Auxiliary instance_discrimination_loss. The per-pixel weight is
    # pixel_gt_thing_mask (1.0 on thing pixels, 0.0 elsewhere).
    if self._instance_discrimination_loss_weight > 0.0:
      resulting_dict[common.INSTANCE_DISCRIMINATION_LOSS] = (
          self._instance_discrimination_loss(
              {_GT_KEY: pixel_gt_mask_id_one_hot},
              {_PRED_KEY: instance_discrimination_similarity,
               _WEIGHT_KEY: pixel_gt_thing_mask}) *
          self._instance_discrimination_loss_weight)
    # Extract pixel_space_mask_logits and pixel_space_mask_probs.
    pixel_space_mask_logits = y_pred[common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY]
    pixel_space_mask_logits = tf.reshape(
        pixel_space_mask_logits,
        [batch_size, output_height * output_width, self._num_mask_slots])
    pixel_space_mask_probs = tf.nn.softmax(pixel_space_mask_logits, axis=-1)
    # Compute the mask similarity between all ground truth masks and all
    # predicted masks. Predicted probabilities on void pixels are zeroed first.
    mask_similarity = _mask_similarity(
        pixel_gt_mask_id_one_hot,
        pixel_space_mask_probs * pixel_gt_non_void_mask_expanded,
        metric='dice')
    # Compute the class similarity by multiplying the ground truth one hot
    # encoding with the predicted probability distribution. This is done between
    # all ground truth masks and all predicted masks. The last (void) channel is
    # dropped after the softmax.
    transformer_class_logits = y_pred[common.PRED_TRANSFORMER_CLASS_LOGITS_KEY]
    transformer_class_probs = tf.nn.softmax(
        transformer_class_logits, axis=-1)[:, :, :-1]
    class_similarity = tf.einsum(
        'bij,bkj->bik', mask_gt_semantic_one_hot, transformer_class_probs)
    # Compute hungarian matching weights. We take the negative here since the
    # hungarian matching algorithm looks for the matching with the least total
    # weight.
    hungarian_weights = - mask_similarity * class_similarity
    mask_gt_non_void_mask_expanded_2 = tf.expand_dims(
        mask_gt_non_void_mask, axis=2)
    # Mask the void ground truth masks (in the rows) so that they do not affect
    # the result of the hungarian matching.
    if self._num_mask_slots >= self._pixel_gt_num_mask_id:
      # If the number of mask slots (number of columns) is larger than the
      # constant number of ground truth masks (number of rows), the
      # nonsquare_hungarian_matching will pad the rows with
      # _MATCHING_NEGATIVE_CONSTANT. In this case, we can fill in the void mask
      # rows with _MATCHING_NEGATIVE_CONSTANT too, then the void mask rows will
      # be ignored too, according to the hungarian matching property.
      hungarian_weights = (
          hungarian_weights * mask_gt_non_void_mask_expanded_2 +
          (1 - mask_gt_non_void_mask_expanded_2) * _MATCHING_NEGATIVE_CONSTANT)
    else:
      # If the number of mask slots (number of columns) is smaller than the
      # constant number of ground truth masks (number of rows), the
      # nonsquare_hungarian_matching will pad the columns with
      # _MATCHING_NEGATIVE_CONSTANT. In this case, we should fill in the void
      # mask rows with _MATCHING_POSITIVE_CONSTANT here, then the void mask rows
      # will have a huge cost compared with existing non-void mask rows, so that
      # the predicted masks will prefer matching with existing non-void masks
      # rather than the padded void masks, according to the hungarian matching
      # property.
      hungarian_weights = (
          hungarian_weights * mask_gt_non_void_mask_expanded_2 +
          (1 - mask_gt_non_void_mask_expanded_2) * _MATCHING_POSITIVE_CONSTANT)
    # Perform the hungarian matching algorithm.
    full_permutation, nonsquare_permutation = (
        nonsquare_hungarian_matching(hungarian_weights))
    # Extract the permutation (matching) between all existing non-void ground
    # truth masks and the matched predicted masks.
    matched_permutation = (
        nonsquare_permutation * mask_gt_non_void_mask_expanded_2)
    # The matched mask dice scores for each mask slot ([batch, num_mask_slots]).
    # The scores will be used as a loss weight for the PQ-style loss class term
    # after the stop_gradient.
    matched_mask_dice = tf.reduce_max(
        mask_similarity * matched_permutation, axis=-2)
    matched_mask_dice = tf.stop_gradient(matched_mask_dice)
    # The matched class probabilities for each ground truth mask
    # ([batch, pixel_gt_num_mask_id]). The probabilities will be used as a loss
    # weight for the PQ-style loss mask dice term after the stop_gradient.
    matched_class_prob = tf.reduce_max(
        class_similarity * matched_permutation, axis=-1)
    matched_class_prob = tf.stop_gradient(matched_class_prob)
    # Extract the index of the matched mask slot for each ground truth mask.
    # NOTE(review): a ground truth row matched to a padded column has an
    # all-zero row in nonsquare_permutation, so argmax is a tie and returns an
    # unspecified slot. For void rows this is harmless (their label is -1 and
    # is dropped by _generate_mask_slot_semantic_one_hot), but if an image has
    # more non-void ground truth masks than mask slots, a non-void row could
    # land on a padded column and its class would be written onto an unrelated
    # slot -- confirm this case cannot occur.
    matched_mask_slot_indices = tf.math.argmax(
        nonsquare_permutation, axis=-1, output_type=tf.dtypes.int32)
    full_num_mask_slots = full_permutation.get_shape().as_list()[-1]
    # Pad the pixel_space_mask_logits so that it is compatible with the
    # permutation matrix. Padded slots get a large negative logit.
    full_pixel_space_mask_logits = tf.pad(
        pixel_space_mask_logits,
        [[0, 0], [0, 0], [0, full_num_mask_slots - self._num_mask_slots]],
        constant_values=_SOFTMAX_MASKING_CONSTANT)
    # Permute the pixel space mask logits with the permutation matrix, which
    # converts the mask slot indices to the ground truth indices.
    permuted_full_pixel_space_mask_logits = tf.einsum(
        'bpi,bji->bpj', full_pixel_space_mask_logits, full_permutation)
    # Pad the class probabilities too.
    full_matched_class_prob = tf.pad(
        matched_class_prob,
        [[0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])
    # We only compute dice loss term on non-void ground truth masks.
    mask_dice_term_loss_weight = tf.pad(
        mask_gt_non_void_mask,
        [[0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])
    # Use the class probabilities as the loss weight for the mask dice term. In
    # addition, we set a lower bound, 1e-5, for the mask dice term loss weight.
    # Otherwise, if a loss weight is accidentally zero, the dice loss will treat
    # it as void and use an incorrect denominator or normalizing constant for
    # the loss.
    mask_dice_term_loss_weight *= tf.maximum(full_matched_class_prob, 1e-5)
    # Pad the one hot encoding too.
    full_pixel_gt_mask_id_one_hot = tf.pad(
        pixel_gt_mask_id_one_hot,
        [[0, 0], [0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])
    if self._pq_style_loss_weight > 0.0:
      # Mask_dice_term_modifier balances the mask_dice_term and the class_term
      # of the PQ-style loss to have the same weight and normalizing constant.
      resulting_dict[common.PQ_STYLE_LOSS_MASK_DICE_TERM] = (
          self._pq_style_loss_mask_dice_term(
              {_GT_KEY: full_pixel_gt_mask_id_one_hot},
              {_PRED_KEY: permuted_full_pixel_space_mask_logits,
               _WEIGHT_KEY: mask_dice_term_loss_weight}) *
          (self._pq_style_loss_weight * self._mask_dice_term_modifier))
    # Mask-ID cross entropy loss shares the same ground truth and logits as the
    # dice loss term, but with different weights.
    if self._mask_id_cross_entropy_loss_weight > 0.0:
      resulting_dict[common.MASK_ID_CROSS_ENTROPY_LOSS] = (
          self._mask_id_cross_entropy_loss(
              {_GT_KEY: full_pixel_gt_mask_id_one_hot},
              {_PRED_KEY: permuted_full_pixel_space_mask_logits,
               _WEIGHT_KEY: pixel_gt_non_void_mask}) *
          self._mask_id_cross_entropy_loss_weight)
    # Generate a pseudo ground truth for transformer_class_logits.
    mask_slot_semantic_one_hot = _generate_mask_slot_semantic_one_hot(
        matched_mask_slot_indices, mask_gt_semantic_map,
        self._num_mask_slots, self._thing_stuff_class_ids)
    # Compute the positive mask and the negative mask. A mask slot is positive
    # if some (thing or stuff) class channel of its target equals 1.0.
    mask_slot_positive_mask = tf.cast(tf.equal(tf.reduce_max(
        mask_slot_semantic_one_hot, axis=-1), 1.0), tf.float32)
    mask_slot_negative_mask = 1.0 - mask_slot_positive_mask
    # Compute the overlap ratio between all predicted masks and the void region.
    # This void ratio will be used as a weight for the negative class term.
    mask_void_ratio = tf.stop_gradient(_mask_similarity(
        1.0 - pixel_gt_non_void_mask_expanded,
        pixel_space_mask_probs,
        'intersection_over_prediction'))
    # [batch, 1, num_mask_slots] -> [batch, num_mask_slots].
    mask_void_ratio = tf.squeeze(mask_void_ratio, axis=1)
    # Use the matched mask dice scores as the weights for the positive class
    # terms. For the negative class terms, we reduce the penalty for a mask slot
    # class term if the mask prediction overlaps a lot with void regions.
    transformer_class_loss_weight = (
        mask_slot_positive_mask * tf.maximum(matched_mask_dice, 1e-5) +
        mask_slot_negative_mask * tf.maximum(mask_void_ratio, 1e-5))
    # Concatenate the void mask in the last channel, constructing the final
    # ground truth one hot label with (thing + stuff + void) channels.
    transformer_class_one_hot = tf.concat(
        [mask_slot_semantic_one_hot,
         tf.expand_dims(mask_slot_negative_mask, axis=-1)], axis=-1)
    # Apply the PQ-style loss class term.
    if self._pq_style_loss_weight > 0.0:
      resulting_dict[common.PQ_STYLE_LOSS_CLASS_TERM] = (
          self._pq_style_loss_class_term(
              {_GT_KEY: transformer_class_one_hot},
              {_PRED_KEY: transformer_class_logits,
               _WEIGHT_KEY: transformer_class_loss_weight}) *
          self._pq_style_loss_weight)
    return resulting_dict