# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains basic loss classes used in the DeepLab model."""

from typing import Text, Dict, Callable, Optional

import tensorflow as tf

from deeplab2.model import utils


def compute_average_top_k_loss(loss: tf.Tensor,
                               top_k_percentage: float) -> tf.Tensor:
  """Computes the average top-k loss per sample.

  Args:
    loss: A tf.Tensor with 2 or more dimensions of shape [batch, ...].
    top_k_percentage: A float representing the % of pixels that should be used
      for calculating the loss.

  Returns:
    A tensor of shape [batch] containing the mean top-k loss per sample. Due to
    the use of different tf.strategy, we return the loss per sample and require
    explicit averaging by the user.
  """
  loss = tf.reshape(loss, shape=(tf.shape(loss)[0], -1))

  if top_k_percentage != 1.0:
    num_elements_per_sample = tf.shape(loss)[1]
    top_k_pixels = tf.cast(
        tf.math.round(top_k_percentage *
                      tf.cast(num_elements_per_sample, tf.float32)), tf.int32)

    def top_k_1d(inputs):
      return tf.math.top_k(inputs, top_k_pixels, sorted=False)[0]
    loss = tf.map_fn(fn=top_k_1d, elems=loss)

  # Compute mean loss over spatial dimension.
  num_non_zero = tf.reduce_sum(tf.cast(tf.not_equal(loss, 0.0), tf.float32), 1)
  loss_sum_per_sample = tf.reduce_sum(loss, 1)
  return tf.math.divide_no_nan(loss_sum_per_sample, num_non_zero)
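

# Usage sketch for compute_average_top_k_loss (illustrative only, not part of
# the original module; the shapes below are arbitrary assumptions):
#
#   per_pixel_loss = tf.random.uniform([2, 65, 65])  # [batch, height, width]
#   per_sample_loss = compute_average_top_k_loss(per_pixel_loss, 0.2)
#   # per_sample_loss has shape [2]; explicit averaging is left to the caller,
#   # e.g., total_loss = tf.reduce_mean(per_sample_loss).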


def compute_mask_dice_loss(y_true: tf.Tensor,
                           y_pred: tf.Tensor,
                           prediction_activation='softmax') -> tf.Tensor:
  """Computes the Mask Dice loss between y_true and y_pred masks.

  Reference:
    [1] Milletari, F., Navab, N., Ahmadi, S.A.: V-net: Fully convolutional
      neural networks for volumetric medical image segmentation. In: 3DV (2016)
      https://arxiv.org/abs/1606.04797

  Args:
    y_true: A tf.Tensor of shape [batch, height, width, channels] (or [batch,
      length, channels]) containing the ground-truth. The channel dimension
      indicates the mask ID in MaX-DeepLab, instead of a "class" dimension in
      the V-net paper. In our case, at every position in (batch, height, width)
      (or (batch, length)), the channel vector [batch, height, width, :] (or
      [batch, length, :]) should be a one-hot encoding: valid pixels have one
      and only one 1.0, while void pixels are all 0.0. The valid pixels of the
      masks do not and should not overlap because of the non-overlapping
      definition of panoptic segmentation. The output loss is computed and
      normalized by valid (not void) pixels.
    y_pred: A tf.Tensor of shape [batch, height, width, channels] (or [batch,
      length, channels]) containing the prediction.
    prediction_activation: A String indicating the activation function of the
      prediction. It should be either 'sigmoid' or 'softmax'.

  Returns:
    A tf.Tensor of shape [batch, channels] with the computed dice loss value.

  Raises:
    ValueError: An error occurs when prediction_activation is neither 'sigmoid'
      nor 'softmax'.
  """
  tf.debugging.assert_rank_in(
      y_pred, [3, 4], message='Input tensors y_pred must have rank 3 or 4.')
  tf.debugging.assert_rank_in(
      y_true, [3, 4], message='Input tensors y_true must have rank 3 or 4.')

  shape_list = y_true.shape.as_list()
  batch, channels = shape_list[0], shape_list[-1]
  if prediction_activation == 'sigmoid':
    y_pred = tf.math.sigmoid(y_pred)
  elif prediction_activation == 'softmax':
    y_pred = tf.nn.softmax(y_pred, axis=-1)
  else:
    raise ValueError(
        "prediction_activation should be either 'sigmoid' or 'softmax'")

  y_true_flat = tf.reshape(y_true, [batch, -1, channels])
  # valid_flat indicates labeled pixels in the ground truth. y_true contains
  # one-hot encodings only, with valid pixels having one and only one 1.0, and
  # with invalid pixels having 0.0 values in all the channels. The valid pixels
  # of the masks do not overlap because of the non-overlapping definition of
  # panoptic segmentation.
  valid_flat = tf.reduce_sum(y_true_flat, axis=-1, keepdims=True)
  y_pred_flat = tf.reshape(
      y_pred, [batch, -1, channels]) * valid_flat
  # Use smooth = 1 to avoid division by zero when both y_pred and y_true are
  # zeros.
  smooth = 1.0
  intersection = 2 * tf.reduce_sum(y_pred_flat * y_true_flat, axis=1) + smooth
  denominator = (tf.reduce_sum(y_pred_flat, axis=1) +
                 tf.reduce_sum(y_true_flat, axis=1) + smooth)
  loss = 1. - tf.math.divide_no_nan(intersection, denominator)
  return loss
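

# Usage sketch for compute_mask_dice_loss (illustrative only; the shapes and
# number of masks below are arbitrary assumptions). y_true must be one-hot
# over the channel (mask ID) dimension, with void pixels all zero:
#
#   num_masks = 8
#   gt_ids = tf.random.uniform([2, 33, 33], maxval=num_masks, dtype=tf.int32)
#   y_true = tf.one_hot(gt_ids, num_masks)             # [2, 33, 33, num_masks]
#   y_pred = tf.random.normal([2, 33, 33, num_masks])  # logits
#   dice = compute_mask_dice_loss(y_true, y_pred)      # shape [2, num_masks]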


def mean_absolute_error(y_true: tf.Tensor,
                        y_pred: tf.Tensor,
                        force_keep_dims=False) -> tf.Tensor:
  """Computes the per-pixel mean absolute error for 3D and 4D tensors.

  Default reduction behavior: If a 3D tensor is used, no reduction is applied.
  In case of a 4D tensor, reduction is applied. This behavior can be overridden
  by force_keep_dims.

  Note: tf.keras.losses.mean_absolute_error always reduces the output by one
  dimension.

  Args:
    y_true: A tf.Tensor of shape [batch, height, width] or [batch, height,
      width, channels] containing the ground-truth.
    y_pred: A tf.Tensor of shape [batch, height, width] or [batch, height,
      width, channels] containing the prediction.
    force_keep_dims: A boolean flag specifying whether no reduction should be
      applied.

  Returns:
    A tf.Tensor with the mean absolute error.
  """
  tf.debugging.assert_rank_in(
      y_pred, [3, 4], message='Input tensors must have rank 3 or 4.')
  if len(y_pred.shape.as_list()) == 3 or force_keep_dims:
    return tf.abs(y_true - y_pred)
  else:
    return tf.reduce_mean(tf.abs(y_true - y_pred), axis=[3])


def mean_squared_error(y_true: tf.Tensor,
                       y_pred: tf.Tensor,
                       force_keep_dims=False) -> tf.Tensor:
  """Computes the per-pixel mean squared error for 3D and 4D tensors.

  Default reduction behavior: If a 3D tensor is used, no reduction is applied.
  In case of a 4D tensor, reduction is applied. This behavior can be overridden
  by force_keep_dims.

  Note: tf.keras.losses.mean_squared_error always reduces the output by one
  dimension.

  Args:
    y_true: A tf.Tensor of shape [batch, height, width] or [batch, height,
      width, channels] containing the ground-truth.
    y_pred: A tf.Tensor of shape [batch, height, width] or [batch, height,
      width, channels] containing the prediction.
    force_keep_dims: A boolean flag specifying whether no reduction should be
      applied.

  Returns:
    A tf.Tensor with the mean squared error.
  """
  tf.debugging.assert_rank_in(
      y_pred, [3, 4], message='Input tensors must have rank 3 or 4.')
  if len(y_pred.shape.as_list()) == 3 or force_keep_dims:
    return tf.square(y_true - y_pred)
  else:
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=[3])
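

# Usage sketch for the per-pixel regression losses above (illustrative only;
# the shapes are arbitrary assumptions). With a 4D input, the channel dimension
# is reduced unless force_keep_dims=True:
#
#   y_true = tf.random.uniform([2, 17, 17, 2])
#   y_pred = tf.random.uniform([2, 17, 17, 2])
#   mae = mean_absolute_error(y_true, y_pred)  # shape [2, 17, 17]
#   mse = mean_squared_error(y_true, y_pred, force_keep_dims=True)
#   # mse keeps shape [2, 17, 17, 2].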


def encode_one_hot(gt: tf.Tensor,
                   num_classes: int,
                   weights: tf.Tensor,
                   ignore_label: Optional[int]):
  """Helper function for one-hot encoding of integer labels.

  Args:
    gt: A tf.Tensor providing ground-truth information. Integer type label.
    num_classes: An integer indicating the number of classes considered in the
      ground-truth. It is used as 'depth' in tf.one_hot().
    weights: A tf.Tensor containing weights information.
    ignore_label: An integer specifying the ignore label or None.

  Returns:
    gt: A tf.Tensor of one-hot encoded gt labels.
    weights: A tf.Tensor with ignore_label considered.
  """
  if ignore_label is not None:
    keep_mask = tf.cast(tf.not_equal(gt, ignore_label), dtype=tf.float32)
  else:
    keep_mask = tf.ones_like(gt, dtype=tf.float32)
  gt = tf.stop_gradient(tf.one_hot(gt, num_classes))
  weights = tf.multiply(weights, keep_mask)
  return gt, weights
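

# Usage sketch for encode_one_hot (illustrative only; the label values and the
# ignore label below are arbitrary assumptions):
#
#   gt = tf.constant([[1, 2, 255]], dtype=tf.int32)  # 255 marks ignored pixels
#   weights = tf.ones_like(gt, dtype=tf.float32)
#   one_hot_gt, weights = encode_one_hot(gt, num_classes=3, weights=weights,
#                                        ignore_label=255)
#   # one_hot_gt has shape [1, 3, 3] (the ignored position is all zeros), and
#   # weights becomes [[1., 1., 0.]].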


def is_one_hot(gt: tf.Tensor, pred: tf.Tensor):
  """Helper function for checking if gt tensor is one-hot encoded or not.

  Args:
    gt: A tf.Tensor providing ground-truth information.
    pred: A tf.Tensor providing prediction information.

  Returns:
    A boolean indicating whether the gt is one-hot encoded (True) or
    in integer type (False).
  """
  gt_shape = gt.get_shape().as_list()
  pred_shape = pred.get_shape().as_list()
  # If the ground truth is one-hot encoded, the rank of the ground truth should
  # match that of the prediction. In addition, we check that the first
  # dimension, batch_size, and the last dimension, channels, should also match
  # the prediction. However, we still allow spatial dimensions, e.g., height
  # and width, to be different since we will downsample the ground truth if
  # needed.
  return (len(gt_shape) == len(pred_shape) and
          gt_shape[0] == pred_shape[0] and gt_shape[-1] == pred_shape[-1])


def _ensure_topk_value_is_percentage(top_k_percentage: float):
  """Checks if top_k_percentage is between 0.0 and 1.0.

  Args:
    top_k_percentage: The floating point value to check.

  Raises:
    ValueError: If top_k_percentage does not lie within [0.0, 1.0].
  """
  if top_k_percentage < 0.0 or top_k_percentage > 1.0:
    raise ValueError('The top-k percentage parameter must lie within 0.0 and '
                     '1.0, but %f was given' % top_k_percentage)


class TopKGeneralLoss(tf.keras.losses.Loss):
  """This class contains code to compute the top-k loss."""

  def __init__(self,
               loss_function: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
               gt_key: Text,
               pred_key: Text,
               weight_key: Text,
               top_k_percent_pixels: float = 1.0):
    """Initializes a top-k general loss.

    Args:
      loss_function: A callable loss function.
      gt_key: A key to extract the ground-truth tensor.
      pred_key: A key to extract the prediction tensor.
      weight_key: A key to extract the weight tensor.
      top_k_percent_pixels: An optional float specifying the percentage of
        pixels used to compute the loss. The value must lie within [0.0, 1.0].

    Raises:
      ValueError: An error occurs when top_k_percent_pixels is not between 0.0
        and 1.0.
    """
    # Implicit reduction might mess with tf.distribute.Strategy, hence we
    # explicitly reduce the loss.
    super(TopKGeneralLoss,
          self).__init__(reduction=tf.keras.losses.Reduction.NONE)

    _ensure_topk_value_is_percentage(top_k_percent_pixels)

    self._loss_function = loss_function
    self._top_k_percent_pixels = top_k_percent_pixels
    self._gt_key = gt_key
    self._pred_key = pred_key
    self._weight_key = weight_key

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Computes the top-k loss.

    Args:
      y_true: A dict of tensors providing ground-truth information.
      y_pred: A dict of tensors providing predictions.

    Returns:
      A tensor of shape [batch] containing the loss per sample.
    """
    gt = y_true[self._gt_key]
    pred = y_pred[self._pred_key]
    weights = y_true[self._weight_key]

    per_pixel_loss = self._loss_function(gt, pred)
    per_pixel_loss = tf.multiply(per_pixel_loss, weights)
    return compute_average_top_k_loss(per_pixel_loss,
                                      self._top_k_percent_pixels)
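

# Usage sketch for TopKGeneralLoss (illustrative only; the dict keys, shapes,
# and the choice of mean_squared_error are assumptions, not requirements of
# this module):
#
#   loss_fn = TopKGeneralLoss(mean_squared_error,
#                             gt_key='center_gt',
#                             pred_key='center_pred',
#                             weight_key='center_weights',
#                             top_k_percent_pixels=0.5)
#   y_true = {'center_gt': tf.random.uniform([2, 65, 65]),
#             'center_weights': tf.ones([2, 65, 65])}
#   y_pred = {'center_pred': tf.random.uniform([2, 65, 65])}
#   per_sample_loss = loss_fn(y_true, y_pred)  # shape [2]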


class TopKCrossEntropyLoss(tf.keras.losses.Loss):
  """This class contains code for top-k cross-entropy."""

  def __init__(self,
               gt_key: Text,
               pred_key: Text,
               weight_key: Text,
               num_classes: Optional[int],
               ignore_label: Optional[int],
               top_k_percent_pixels: float = 1.0,
               dynamic_weight: bool = False):
    """Initializes a top-k cross entropy loss.

    Args:
      gt_key: A key to extract the ground-truth tensor.
      pred_key: A key to extract the prediction tensor.
      weight_key: A key to extract the weight tensor.
      num_classes: An integer specifying the number of classes in the dataset.
      ignore_label: An optional integer specifying the ignore label or None.
      top_k_percent_pixels: An optional float specifying the percentage of
        pixels used to compute the loss. The value must lie within [0.0, 1.0].
      dynamic_weight: A boolean indicating whether the weights are determined
        dynamically w.r.t. the class confidence of each predicted mask.

    Raises:
      ValueError: An error occurs when top_k_percent_pixels is not between 0.0
        and 1.0.
    """
    # Implicit reduction might mess with tf.distribute.Strategy, hence we
    # explicitly reduce the loss.
    super(TopKCrossEntropyLoss,
          self).__init__(reduction=tf.keras.losses.Reduction.NONE)

    _ensure_topk_value_is_percentage(top_k_percent_pixels)

    self._num_classes = num_classes
    self._ignore_label = ignore_label
    self._top_k_percent_pixels = top_k_percent_pixels
    self._gt_key = gt_key
    self._pred_key = pred_key
    self._weight_key = weight_key
    self._dynamic_weight = dynamic_weight

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Computes the top-k cross-entropy loss.

    Args:
      y_true: A dict of tensors providing ground-truth information. The tensors
        can be either integer type or one-hot encoded. When of integer type,
        the shape can be either [batch, num_elements] or [batch, height,
        width]. When one-hot encoded, the shape can be [batch, num_elements,
        channels] or [batch, height, width, channels].
      y_pred: A dict of tensors providing predictions. The tensors are of shape
        [batch, num_elements, channels] or [batch, height, width, channels]. If
        the prediction is 2D (with height and width), we allow the spatial
        dimensions to be strided_height and strided_width. In this case, we
        downsample the ground truth accordingly.

    Returns:
      A tensor of shape [batch] containing the loss per image.

    Raises:
      ValueError: If the prediction is 1D (with the length dimension) but its
        length does not match that of the ground truth.
    """
    gt = y_true[self._gt_key]
    pred = y_pred[self._pred_key]
    gt_shape = gt.get_shape().as_list()
    pred_shape = pred.get_shape().as_list()
    if self._dynamic_weight:
      weights = y_pred[self._weight_key]
    else:
      weights = y_true[self._weight_key]

    # Downsample the ground truth for 2D prediction cases.
    if len(pred_shape) == 4 and gt_shape[1:3] != pred_shape[1:3]:
      gt = utils.strided_downsample(gt, pred_shape[1:3])
      weights = utils.strided_downsample(weights, pred_shape[1:3])
    elif len(pred_shape) == 3 and gt_shape[1] != pred_shape[1]:
      # We don't support downsampling for 1D predictions.
      raise ValueError('The shape of gt does not match the shape of pred.')

    if is_one_hot(gt, pred):
      gt = tf.cast(gt, tf.float32)
    else:
      gt = tf.cast(gt, tf.int32)
      gt, weights = encode_one_hot(gt, self._num_classes, weights,
                                   self._ignore_label)
    pixel_losses = tf.keras.backend.categorical_crossentropy(
        gt, pred, from_logits=True)
    weighted_pixel_losses = tf.multiply(pixel_losses, weights)
    return compute_average_top_k_loss(weighted_pixel_losses,
                                      self._top_k_percent_pixels)
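

# Usage sketch for TopKCrossEntropyLoss (illustrative only; the dict keys,
# number of classes, and shapes are assumptions):
#
#   loss_fn = TopKCrossEntropyLoss(gt_key='semantic_gt',
#                                  pred_key='semantic_logits',
#                                  weight_key='semantic_weights',
#                                  num_classes=19,
#                                  ignore_label=255,
#                                  top_k_percent_pixels=0.2)
#   y_true = {'semantic_gt': tf.random.uniform([2, 65, 65], maxval=19,
#                                              dtype=tf.int32),
#             'semantic_weights': tf.ones([2, 65, 65])}
#   y_pred = {'semantic_logits': tf.random.normal([2, 65, 65, 19])}
#   per_image_loss = loss_fn(y_true, y_pred)  # shape [2]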


class FocalCrossEntropyLoss(tf.keras.losses.Loss):
  """This class contains code for focal cross-entropy."""

  def __init__(self,
               gt_key: Text,
               pred_key: Text,
               weight_key: Text,
               num_classes: Optional[int],
               ignore_label: Optional[int],
               focal_loss_alpha: float = 0.75,
               focal_loss_gamma: float = 0.0,
               background_channel_index: int = -1,
               dynamic_weight: bool = True):
    """Initializes a focal cross entropy loss.

    FocalCrossEntropyLoss supports focal-loss mode with integer
    or one-hot ground-truth labels.

    Reference:
      [1] Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. Focal loss
        for dense object detection. In Proceedings of the IEEE International
        Conference on Computer Vision (ICCV). (2017)
        https://arxiv.org/abs/1708.02002

    Args:
      gt_key: A key to extract the ground-truth tensor.
      pred_key: A key to extract the prediction tensor.
      weight_key: A key to extract the weight tensor.
      num_classes: An integer specifying the number of classes in the dataset.
      ignore_label: An optional integer specifying the ignore label or None.
        Only effective when ground truth labels are in integer mode.
      focal_loss_alpha: An optional float specifying the coefficient that
        weights between positive (matched) and negative (unmatched) masks in
        focal loss. The positives are weighted by alpha, while the negatives
        are weighted by (1. - alpha). Defaults to 0.75.
      focal_loss_gamma: An optional float specifying the coefficient that
        weights the probability (pt) term in focal loss. Focal loss =
        - ((1 - pt) ^ gamma) * log(pt). Defaults to 0.0.
      background_channel_index: The index of the background channel. When alpha
        is used, we assume the last channel is background and the others are
        foreground. Defaults to -1.
      dynamic_weight: A boolean indicating whether the weights are determined
        dynamically w.r.t. the class confidence of each predicted mask.
    """
    # Implicit reduction might mess with tf.distribute.Strategy, hence we
    # explicitly reduce the loss.
    super(FocalCrossEntropyLoss,
          self).__init__(reduction=tf.keras.losses.Reduction.NONE)

    self._num_classes = num_classes
    self._ignore_label = ignore_label
    self._focal_loss_alpha = focal_loss_alpha
    self._focal_loss_gamma = focal_loss_gamma
    self._background_channel_index = background_channel_index
    self._gt_key = gt_key
    self._pred_key = pred_key
    self._weight_key = weight_key
    self._dynamic_weight = dynamic_weight

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Computes the focal cross-entropy loss.

    Args:
      y_true: A dict of tensors providing ground-truth information. The tensors
        can be either integer type or one-hot encoded. When of integer type,
        the shape can be either [batch, num_elements] or [batch, height,
        width]. When one-hot encoded, the shape can be [batch, num_elements,
        channels] or [batch, height, width, channels].
      y_pred: A dict of tensors providing predictions. The tensors are of shape
        [batch, num_elements, channels] or [batch, height, width, channels].

    Returns:
      A tensor of shape [batch] containing the loss per image.
    """
    gt = y_true[self._gt_key]
    pred = y_pred[self._pred_key]
    if self._dynamic_weight:
      # Dynamic weights w.r.t. the class confidence of each predicted mask.
      weights = y_pred[self._weight_key]
    else:
      weights = y_true[self._weight_key]

    if is_one_hot(gt, pred):
      gt = tf.cast(gt, tf.float32)
    else:
      gt = tf.cast(gt, tf.int32)
      gt, weights = encode_one_hot(gt, self._num_classes, weights,
                                   self._ignore_label)
    pixel_losses = tf.nn.softmax_cross_entropy_with_logits(gt, pred)
    # Focal loss.
    if self._focal_loss_gamma == 0.0:
      pixel_focal_losses = pixel_losses
    else:
      predictions = tf.nn.softmax(pred, axis=-1)
      pt = tf.reduce_sum(predictions * gt, axis=-1)
      pixel_focal_losses = tf.multiply(
          tf.pow(1.0 - pt, self._focal_loss_gamma), pixel_losses)

    if self._focal_loss_alpha >= 0:
      # alpha_weights = alpha * positive masks + (1 - alpha) * negative masks.
      alpha = self._focal_loss_alpha
      alpha_weights = (
          alpha * (1.0 - gt[..., self._background_channel_index])
          + (1 - alpha) * gt[..., self._background_channel_index])
      pixel_focal_losses = alpha_weights * pixel_focal_losses
    weighted_pixel_losses = tf.multiply(pixel_focal_losses, weights)
    weighted_pixel_losses = tf.reshape(
        weighted_pixel_losses, shape=(tf.shape(weighted_pixel_losses)[0], -1))
    # Compute mean loss over spatial dimension.
    num_non_zero = tf.reduce_sum(
        tf.cast(tf.not_equal(weighted_pixel_losses, 0.0), tf.float32), 1)
    loss_sum_per_sample = tf.reduce_sum(weighted_pixel_losses, 1)
    return tf.math.divide_no_nan(loss_sum_per_sample, num_non_zero)
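

# Usage sketch for FocalCrossEntropyLoss (illustrative only; the keys, shapes,
# and number of mask slots are assumptions). With dynamic_weight=True, the
# per-pixel weights are read from y_pred, e.g., the class confidence of each
# predicted mask:
#
#   num_slots = 16
#   loss_fn = FocalCrossEntropyLoss(gt_key='mask_id_gt',
#                                   pred_key='mask_id_logits',
#                                   weight_key='mask_weights',
#                                   num_classes=num_slots,
#                                   ignore_label=None,
#                                   focal_loss_alpha=0.75,
#                                   focal_loss_gamma=2.0,
#                                   dynamic_weight=True)
#   y_true = {'mask_id_gt': tf.random.uniform([2, 33, 33], maxval=num_slots,
#                                             dtype=tf.int32)}
#   y_pred = {'mask_id_logits': tf.random.normal([2, 33, 33, num_slots]),
#             'mask_weights': tf.ones([2, 33, 33])}
#   per_image_loss = loss_fn(y_true, y_pred)  # shape [2]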


class MaskDiceLoss(tf.keras.losses.Loss):
  """This class contains code to compute the Mask Dice loss.

  The channel dimension in Mask Dice loss indicates the mask ID in MaX-DeepLab,
  instead of a "class" dimension in the original Dice loss.
  """

  def __init__(self,
               gt_key: Text,
               pred_key: Text,
               weight_key: Text,
               prediction_activation='softmax'):
    """Initializes a Mask Dice loss.

    Args:
      gt_key: A key to extract the ground-truth tensor.
      pred_key: A key to extract the prediction tensor.
      weight_key: A key to extract the weight tensor.
      prediction_activation: A String indicating the activation function of the
        prediction. It should be either 'sigmoid' or 'softmax'.
    """
    # Implicit reduction might mess with tf.distribute.Strategy, hence we
    # explicitly reduce the loss.
    super(MaskDiceLoss,
          self).__init__(reduction=tf.keras.losses.Reduction.NONE)

    self._gt_key = gt_key
    self._pred_key = pred_key
    self._weight_key = weight_key
    self._prediction_activation = prediction_activation

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Computes the Mask Dice loss.

    Args:
      y_true: A dict of tensors providing ground-truth information.
      y_pred: A dict of tensors providing predictions.

    Returns:
      A tensor of shape [batch] containing the loss per sample.
    """
    gt = y_true[self._gt_key]
    pred = y_pred[self._pred_key]
    # Dynamic weights w.r.t. the class confidence of each predicted mask.
    weights = y_pred[self._weight_key]
    weighted_dice_losses = tf.multiply(
        compute_mask_dice_loss(gt, pred, self._prediction_activation),
        weights)
    # Reduce sum over the channels (i.e., the number of masks).
    return tf.reduce_sum(weighted_dice_losses, axis=-1)
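

# Usage sketch for MaskDiceLoss (illustrative only; the keys, shapes, and
# number of mask slots are assumptions). The weights are read from y_pred,
# e.g., the class confidence of each predicted mask:
#
#   num_slots = 16
#   loss_fn = MaskDiceLoss(gt_key='mask_id_gt_one_hot',
#                          pred_key='mask_logits',
#                          weight_key='mask_confidence')
#   y_true = {'mask_id_gt_one_hot': tf.one_hot(
#       tf.random.uniform([2, 33, 33], maxval=num_slots, dtype=tf.int32),
#       num_slots)}
#   y_pred = {'mask_logits': tf.random.normal([2, 33, 33, num_slots]),
#             'mask_confidence': tf.ones([2, num_slots])}
#   per_sample_loss = loss_fn(y_true, y_pred)  # shape [2]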