File size: 22,897 Bytes
0924f30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains basic loss classes used in the DeepLab model."""

from typing import Text, Dict, Callable, Optional

import tensorflow as tf
from deeplab2.model import utils


def compute_average_top_k_loss(loss: tf.Tensor,
                               top_k_percentage: float) -> tf.Tensor:
  """Computes the avaerage top-k loss per sample.

  Args:
    loss: A tf.Tensor with 2 or more dimensions of shape [batch, ...].
    top_k_percentage: A float representing the % of pixel that should be used
      for calculating the loss.

  Returns:
    A tensor of shape [batch] containing the mean top-k loss per sample. Due to
    the use of different tf.strategy, we return the loss per sample and require
    explicit averaging by the user.
  """
  loss = tf.reshape(loss, shape=(tf.shape(loss)[0], -1))

  if top_k_percentage != 1.0:
    num_elements_per_sample = tf.shape(loss)[1]
    top_k_pixels = tf.cast(
        tf.math.round(top_k_percentage *
                      tf.cast(num_elements_per_sample, tf.float32)), tf.int32)

    def top_k_1d(inputs):
      return tf.math.top_k(inputs, top_k_pixels, sorted=False)[0]
    loss = tf.map_fn(fn=top_k_1d, elems=loss)

  # Compute mean loss over spatial dimension.
  num_non_zero = tf.reduce_sum(tf.cast(tf.not_equal(loss, 0.0), tf.float32), 1)
  loss_sum_per_sample = tf.reduce_sum(loss, 1)
  return tf.math.divide_no_nan(loss_sum_per_sample, num_non_zero)


def compute_mask_dice_loss(y_true: tf.Tensor,
                           y_pred: tf.Tensor,
                           prediction_activation='softmax') -> tf.Tensor:
  """Computes the Mask Dice loss between y_true and y_pred masks.

  Reference:
  [1] Milletari, F., Navab, N., Ahmadi, S.A.: V-net: Fully convolutional neural
      networks for volumetric medical image segmentation. In: 3DV (2016)
      https://arxiv.org/abs/1606.04797

  Args:
    y_true: A tf.Tensor of shape [batch, height, width, channels] (or [batch,
      length, channels]) containing the ground-truth. The channel dimension
      indicates the mask ID in MaX-DeepLab, instead of a "class" dimension in
      the V-net paper. In our case, for all batch, height, width, (or batch,
      length) the [batch, height, width, :] (or [batch, length, :]) should be
      one-hot encodings only, with valid pixels having one and only one 1.0, and
      with void pixels being all 0.0. The valid pixels of the masks do not and
      should not overlap because of the non-overlapping definition of panoptic
      segmentation. The output loss is computed and normalized by valid (not
      void) pixels.
    y_pred: A tf.Tensor of shape [batch, height, width, channels] (or [batch,
      length, channels]) containing the prediction.
    prediction_activation: A String indicating activation function of the
      prediction. It should be either 'sigmoid' or 'softmax'.

  Returns:
    A tf.Tensor of shape [batch, channels] with the computed dice loss value.

  Raises:
      ValueError: An error occurs when prediction_activation is not either
        'sigmoid' or 'softmax'.
  """
  tf.debugging.assert_rank_in(
      y_pred, [3, 4], message='Input tensors y_pred must have rank 3 or 4.')
  tf.debugging.assert_rank_in(
      y_true, [3, 4], message='Input tensors y_true must have rank 3 or 4.')

  shape_list = y_true.shape.as_list()
  batch, channels = shape_list[0], shape_list[-1]
  if prediction_activation == 'sigmoid':
    y_pred = tf.math.sigmoid(y_pred)
  elif prediction_activation == 'softmax':
    y_pred = tf.nn.softmax(y_pred, axis=-1)
  else:
    raise ValueError(
        "prediction_activation should be either 'sigmoid' or 'softmax'")

  y_true_flat = tf.reshape(y_true, [batch, -1, channels])
  # valid_flat indicates labeled pixels in the groudtruth. y_true is one-hot
  # encodings only, with valid pixels having one and only one 1.0, and with
  # invalid pixels having 0.0 values in all the channels. The valid pixels of
  # the masks do not overlap because of the non-overlapping definition of
  # panoptic segmentation.
  valid_flat = tf.reduce_sum(y_true_flat, axis=-1, keepdims=True)
  y_pred_flat = tf.reshape(
      y_pred, [batch, -1, channels]) * valid_flat
  # Use smooth = 1 to avoid division by zero when both y_pred and y_true are
  # zeros.
  smooth = 1.0
  intersection = 2 * tf.reduce_sum(y_pred_flat * y_true_flat, axis=1) + smooth
  denominator = (tf.reduce_sum(y_pred_flat, axis=1) +
                 tf.reduce_sum(y_true_flat, axis=1) + smooth)
  loss = 1. - tf.math.divide_no_nan(intersection, denominator)
  return loss


def mean_absolute_error(y_true: tf.Tensor,
                        y_pred: tf.Tensor,
                        force_keep_dims=False) -> tf.Tensor:
  """Computes the per-pixel mean absolute error for 3D and 4D tensors.

  Default reduction behavior: If a 3D tensor is used, no reduction is applied.
  In case of a 4D tensor, reduction is applied. This behavior can be overridden
  by force_keep_dims.
  Note: tf.keras.losses.mean_absolute_error always reduces the output by one
  dimension.

  Args:
    y_true: A tf.Tensor of shape [batch, height, width] or [batch, height,
      width, channels] containing the ground-truth.
    y_pred: A tf.Tensor of shape [batch, height, width] or [batch, height,
      width, channels] containing the prediction.
    force_keep_dims: A boolean flag specifying whether no reduction should be
      applied.

  Returns:
    A tf.Tensor with the mean absolute error.
  """
  tf.debugging.assert_rank_in(
      y_pred, [3, 4], message='Input tensors must have rank 3 or 4.')
  if len(y_pred.shape.as_list()) == 3 or force_keep_dims:
    return tf.abs(y_true - y_pred)
  else:
    return tf.reduce_mean(tf.abs(y_true - y_pred), axis=[3])


def mean_squared_error(y_true: tf.Tensor,
                       y_pred: tf.Tensor,
                       force_keep_dims=False) -> tf.Tensor:
  """Computes the per-pixel mean squared error for 3D and 4D tensors.

  Default reduction behavior: If a 3D tensor is used, no reduction is applied.
  In case of a 4D tensor, reduction is applied. This behavior can be overridden
  by force_keep_dims.
  Note: tf.keras.losses.mean_squared_error always reduces the output by one
  dimension.

  Args:
    y_true: A tf.Tensor of shape [batch, height, width] or [batch, height,
      width, channels] containing the ground-truth.
    y_pred: A tf.Tensor of shape [batch, height, width] or [batch, height,
      width, channels] containing the prediction.
    force_keep_dims: A boolean flag specifying whether no reduction should be
      applied.

  Returns:
    A tf.Tensor with the mean squared error.
  """
  tf.debugging.assert_rank_in(
      y_pred, [3, 4], message='Input tensors must have rank 3 or 4.')
  if len(y_pred.shape.as_list()) == 3 or force_keep_dims:
    return tf.square(y_true - y_pred)
  else:
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=[3])


def encode_one_hot(gt: tf.Tensor,
                   num_classes: int,
                   weights: tf.Tensor,
                   ignore_label: Optional[int]):
  """Helper function for one-hot encoding of integer labels.

  Args:
    gt: A tf.Tensor providing ground-truth information. Integer type label.
    num_classes: An integer indicating the number of classes considered in the
      ground-truth. It is used as 'depth' in tf.one_hot().
    weights: A tf.Tensor containing weights information.
    ignore_label: An integer specifying the ignore label or None.

  Returns:
    gt: A tf.Tensor of one-hot encoded gt labels.
    weights: A tf.Tensor with ignore_label considered.
  """
  if ignore_label is not None:
    keep_mask = tf.cast(tf.not_equal(gt, ignore_label), dtype=tf.float32)
  else:
    keep_mask = tf.ones_like(gt, dtype=tf.float32)
  gt = tf.stop_gradient(tf.one_hot(gt, num_classes))
  weights = tf.multiply(weights, keep_mask)
  return gt, weights


def is_one_hot(gt: tf.Tensor, pred: tf.Tensor):
  """Helper function for checking if gt tensor is one-hot encoded or not.

  Args:
    gt: A tf.Tensor providing ground-truth information.
    pred: A tf.Tensor providing prediction information.

  Returns:
    A boolean indicating whether the gt is one-hot encoded (True) or
    in integer type (False).
  """
  gt_shape = gt.get_shape().as_list()
  pred_shape = pred.get_shape().as_list()
  # If the ground truth is one-hot encoded, the rank of the ground truth should
  # match that of the prediction. In addition, we check that the first
  # dimension, batch_size, and the last dimension, channels, should also match
  # the prediction. However, we still allow spatial dimensions, e.g., height and
  # width, to be different since we will downsample the ground truth if needed.
  return (len(gt_shape) == len(pred_shape) and
          gt_shape[0] == pred_shape[0] and gt_shape[-1] == pred_shape[-1])


def _ensure_topk_value_is_percentage(top_k_percentage: float):
  """Checks if top_k_percentage is between 0.0 and 1.0.

  Args:
    top_k_percentage: The floating point value to check.
  """
  if top_k_percentage < 0.0 or top_k_percentage > 1.0:
    raise ValueError('The top-k percentage parameter must lie within 0.0 and '
                     '1.0, but %f was given' % top_k_percentage)


class TopKGeneralLoss(tf.keras.losses.Loss):
  """This class contains code to compute the top-k loss."""

  def __init__(self,
               loss_function: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
               gt_key: Text,
               pred_key: Text,
               weight_key: Text,
               top_k_percent_pixels: float = 1.0):
    """Initializes a top-k L1 loss.

    Args:
      loss_function: A callable loss function.
      gt_key: A key to extract the ground-truth tensor.
      pred_key: A key to extract the prediction tensor.
      weight_key: A key to extract the weight tensor.
      top_k_percent_pixels: An optional float specifying the percentage of
        pixels used to compute the loss. The value must lie within [0.0, 1.0].
    """
    # Implicit reduction might mess with tf.distribute.Strategy, hence we
    # explicitly reduce the loss.
    super(TopKGeneralLoss,
          self).__init__(reduction=tf.keras.losses.Reduction.NONE)

    _ensure_topk_value_is_percentage(top_k_percent_pixels)

    self._loss_function = loss_function
    self._top_k_percent_pixels = top_k_percent_pixels
    self._gt_key = gt_key
    self._pred_key = pred_key
    self._weight_key = weight_key

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Computes the top-k loss.

    Args:
      y_true: A dict of tensors providing ground-truth information.
      y_pred: A dict of tensors providing predictions.

    Returns:
      A tensor of shape [batch] containing the loss per sample.
    """
    gt = y_true[self._gt_key]
    pred = y_pred[self._pred_key]
    weights = y_true[self._weight_key]

    per_pixel_loss = self._loss_function(gt, pred)
    per_pixel_loss = tf.multiply(per_pixel_loss, weights)

    return compute_average_top_k_loss(per_pixel_loss,
                                      self._top_k_percent_pixels)


class TopKCrossEntropyLoss(tf.keras.losses.Loss):
  """This class contains code for top-k cross-entropy."""

  def __init__(self,
               gt_key: Text,
               pred_key: Text,
               weight_key: Text,
               num_classes: Optional[int],
               ignore_label: Optional[int],
               top_k_percent_pixels: float = 1.0,
               dynamic_weight: bool = False):
    """Initializes a top-k cross entropy loss.

    Args:
      gt_key: A key to extract the ground-truth tensor.
      pred_key: A key to extract the prediction tensor.
      weight_key: A key to extract the weight tensor.
      num_classes: An integer specifying the number of classes in the dataset.
      ignore_label: An optional integer specifying the ignore label or None.
      top_k_percent_pixels: An optional float specifying the percentage of
        pixels used to compute the loss. The value must lie within [0.0, 1.0].
      dynamic_weight: A boolean indicating whether the weights are determined
        dynamically w.r.t. the class confidence of each predicted mask.

    Raises:
      ValueError: An error occurs when top_k_percent_pixels is not between 0.0
        and 1.0.
    """
    # Implicit reduction might mess with tf.distribute.Strategy, hence we
    # explicitly reduce the loss.
    super(TopKCrossEntropyLoss,
          self).__init__(reduction=tf.keras.losses.Reduction.NONE)

    _ensure_topk_value_is_percentage(top_k_percent_pixels)

    self._num_classes = num_classes
    self._ignore_label = ignore_label
    self._top_k_percent_pixels = top_k_percent_pixels
    self._gt_key = gt_key
    self._pred_key = pred_key
    self._weight_key = weight_key
    self._dynamic_weight = dynamic_weight

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Computes the top-k cross-entropy loss.

    Args:
      y_true: A dict of tensors providing ground-truth information. The tensors
        can be either integer type or one-hot encoded. When is integer type, the
        shape can be either [batch, num_elements] or [batch, height, width].
        When one-hot encoded, the shape can be [batch, num_elements, channels]
        or [batch, height, width, channels].
      y_pred: A dict of tensors providing predictions. The tensors are of shape
        [batch, num_elements, channels] or [batch, height, width, channels]. If
        the prediction is 2D (with height and width), we allow the spatial
        dimension to be strided_height and strided_width. In this case, we
        downsample the ground truth accordingly.

    Returns:
      A tensor of shape [batch] containing the loss per image.

    Raises:
      ValueError: If the prediction is 1D (with the length dimension) but its
        length does not match that of the ground truth.
    """
    gt = y_true[self._gt_key]
    pred = y_pred[self._pred_key]
    gt_shape = gt.get_shape().as_list()
    pred_shape = pred.get_shape().as_list()
    if self._dynamic_weight:
      weights = y_pred[self._weight_key]
    else:
      weights = y_true[self._weight_key]

    # Downsample the ground truth for 2D prediction cases.
    if len(pred_shape) == 4 and gt_shape[1:3] != pred_shape[1:3]:
      gt = utils.strided_downsample(gt, pred_shape[1:3])
      weights = utils.strided_downsample(weights, pred_shape[1:3])
    elif len(pred_shape) == 3 and gt_shape[1] != pred_shape[1]:
      # We don't support downsampling for 1D predictions.
      raise ValueError('The shape of gt does not match the shape of pred.')

    if is_one_hot(gt, pred):
      gt = tf.cast(gt, tf.float32)
    else:
      gt = tf.cast(gt, tf.int32)
      gt, weights = encode_one_hot(gt, self._num_classes, weights,
                                   self._ignore_label)
    pixel_losses = tf.keras.backend.categorical_crossentropy(
        gt, pred, from_logits=True)
    weighted_pixel_losses = tf.multiply(pixel_losses, weights)

    return compute_average_top_k_loss(weighted_pixel_losses,
                                      self._top_k_percent_pixels)


class FocalCrossEntropyLoss(tf.keras.losses.Loss):
  """This class contains code for focal cross-entropy."""

  def __init__(self,
               gt_key: Text,
               pred_key: Text,
               weight_key: Text,
               num_classes: Optional[int],
               ignore_label: Optional[int],
               focal_loss_alpha: float = 0.75,
               focal_loss_gamma: float = 0.0,
               background_channel_index: int = -1,
               dynamic_weight: bool = True):
    """Initializes a focal cross entropy loss.

    FocalCrossEntropyLoss supports focal-loss mode with integer
    or one-hot ground-truth labels.
    Reference:
    [1] Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. Focal loss for
        dense object detection. In Proceedings of the IEEE International
        Conference on Computer Vision (ICCV). (2017)
        https://arxiv.org/abs/1708.02002

    Args:
      gt_key: A key to extract the ground-truth tensor.
      pred_key: A key to extract the prediction tensor.
      weight_key: A key to extract the weight tensor.
      num_classes: An integer specifying the number of classes in the dataset.
      ignore_label: An optional integer specifying the ignore label or None.
        Only effective when ground truth labels are in integer mode.
      focal_loss_alpha: An optional float specifying the coefficient that
        weights between positive (matched) and negative (unmatched) masks in
        focal loss. The positives are weighted by alpha, while the negatives
        are weighted by (1. - alpha). Default to 0.75.
      focal_loss_gamma: An optional float specifying the coefficient that
        weights probability (pt) term in focal loss. Focal loss = - ((1 - pt) ^
        gamma) * log(pt). Default to 0.0.
      background_channel_index: The index for background channel. When alpha
        is used, we assume the last channel is background and others are
        foreground. Default to -1.
      dynamic_weight: A boolean indicating whether the weights are determined
        dynamically w.r.t. the class confidence of each predicted mask.
    """
    # Implicit reduction might mess with tf.distribute.Strategy, hence we
    # explicitly reduce the loss.
    super(FocalCrossEntropyLoss,
          self).__init__(reduction=tf.keras.losses.Reduction.NONE)

    self._num_classes = num_classes
    self._ignore_label = ignore_label
    self._focal_loss_alpha = focal_loss_alpha
    self._focal_loss_gamma = focal_loss_gamma
    self._background_channel_index = background_channel_index
    self._gt_key = gt_key
    self._pred_key = pred_key
    self._weight_key = weight_key
    self._dynamic_weight = dynamic_weight

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Computes the focal cross-entropy loss.

    Args:
      y_true: A dict of tensors providing ground-truth information. The tensors
        can be either integer type or one-hot encoded. When is integer type, the
        shape can be either [batch, num_elements] or [batch, height, width].
        When one-hot encoded, the shape can be [batch, num_elements, channels]
        or [batch, height, width, channels].
      y_pred: A dict of tensors providing predictions. The tensors are of shape
        [batch, num_elements, channels] or [batch, height, width, channels].

    Returns:
      A tensor of shape [batch] containing the loss per image.

    """

    gt = y_true[self._gt_key]
    pred = y_pred[self._pred_key]
    if self._dynamic_weight:
      # Dynamic weights w.r.t. the class confidence of each predicted mask.
      weights = y_pred[self._weight_key]
    else:
      weights = y_true[self._weight_key]

    if is_one_hot(gt, pred):
      gt = tf.cast(gt, tf.float32)
    else:
      gt = tf.cast(gt, tf.int32)
      gt, weights = encode_one_hot(gt, self._num_classes, weights,
                                   self._ignore_label)
    pixel_losses = tf.nn.softmax_cross_entropy_with_logits(gt, pred)
    # Focal loss
    if self._focal_loss_gamma == 0.0:
      pixel_focal_losses = pixel_losses
    else:
      predictions = tf.nn.softmax(pred, axis=-1)
      pt = tf.reduce_sum(predictions * gt, axis=-1)
      pixel_focal_losses = tf.multiply(
          tf.pow(1.0 - pt, self._focal_loss_gamma), pixel_losses)

    if self._focal_loss_alpha >= 0:
      # alpha_weights = alpha * positive masks + (1 - alpha) * negative masks.
      alpha = self._focal_loss_alpha
      alpha_weights = (
          alpha * (1.0 - gt[..., self._background_channel_index])
          + (1 - alpha) * gt[..., self._background_channel_index])
      pixel_focal_losses = alpha_weights * pixel_focal_losses
    weighted_pixel_losses = tf.multiply(pixel_focal_losses, weights)
    weighted_pixel_losses = tf.reshape(
        weighted_pixel_losses, shape=(tf.shape(weighted_pixel_losses)[0], -1))
    # Compute mean loss over spatial dimension.
    num_non_zero = tf.reduce_sum(
        tf.cast(tf.not_equal(weighted_pixel_losses, 0.0), tf.float32), 1)
    loss_sum_per_sample = tf.reduce_sum(weighted_pixel_losses, 1)
    return tf.math.divide_no_nan(loss_sum_per_sample, num_non_zero)


class MaskDiceLoss(tf.keras.losses.Loss):
  """This class contains code to compute Mask Dice loss.

  The channel dimension in Mask Dice loss indicates the mask ID in MaX-DeepLab,
  instead of a "class" dimension in the original Dice loss.
  """

  def __init__(self,
               gt_key: Text,
               pred_key: Text,
               weight_key: Text,
               prediction_activation='softmax'):
    """Initializes a Mask Dice loss.

    Args:
      gt_key: A key to extract the ground-truth tensor.
      pred_key: A key to extract the pred tensor.
      weight_key: A key to extract the weight tensor.
      prediction_activation: A String indicating activation function of the
        prediction. It should be either 'sigmoid' or 'softmax'.
    """
    # Implicit reduction might mess with tf.distribute.Strategy, hence we
    # explicitly reduce the loss.
    super(MaskDiceLoss, self).__init__(reduction=tf.keras.losses.Reduction.NONE)

    self._gt_key = gt_key
    self._pred_key = pred_key
    self._weight_key = weight_key
    self._prediction_activation = prediction_activation

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> tf.Tensor:
    """Computes the Mask Dice loss.

    Args:
      y_true: A dict of tensors providing ground-truth information.
      y_pred: A dict of tensors providing predictions.

    Returns:
      A tensor of shape [batch] containing the loss per sample.
    """
    gt = y_true[self._gt_key]
    pred = y_pred[self._pred_key]
    # Dynamic weights w.r.t. the class confidence of each predicted mask.
    weights = y_pred[self._weight_key]
    weighted_dice_losses = tf.multiply(
        compute_mask_dice_loss(gt, pred, self._prediction_activation),
        weights)
    # Reduce_sum over the channels (i.e., number of masks).
    return tf.reduce_sum(weighted_dice_losses, axis=-1)