File size: 36,382 Bytes
d1843be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains the loss functions for MaX-DeepLab models.

Reference:
  MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
    CVPR 2021. https://arxiv.org/abs/2012.00759
      Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""
from typing import Text, Dict, Tuple, List

import tensorflow as tf
from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.model import utils
from deeplab2.model.loss import base_loss
from deeplab2.model.loss import matchers_ops

# Positive and negative constants that are used to pad or mask hungarian
# matching weights.
_MATCHING_NEGATIVE_CONSTANT = -999.0
_MATCHING_POSITIVE_CONSTANT = 999.0
# A large negative constant applied before softmax. This will make the softmax
# ignore the masked logits.
_SOFTMAX_MASKING_CONSTANT = -99999.0

# Dictionary keys used internally to pass the ground truth, the prediction,
# and the per-pixel weights to the base_loss functions (see the gt_key /
# pred_key / weight_key arguments of the base losses built in MaXDeepLabLoss).
_GT_KEY = 'gt_key'
_PRED_KEY = 'pred_key'
_WEIGHT_KEY = 'weight_key'


def _generate_mask_slot_semantic_one_hot(
    matched_mask_slot_indices: tf.Tensor,
    mask_gt_semantic_map: tf.Tensor,
    num_mask_slots: int,
    thing_stuff_class_ids: List[int]):
  """Builds the (positive-only) ground truth for transformer_class_logits.

  Given the hungarian-matching result between ground truth masks and predicted
  masks, this function scatters each ground truth mask's semantic label into
  the mask slot it was matched to, producing a one hot target for the
  transformer class head. Only the positive encodings are produced here; the
  void channel is appended by the caller.

  Args:
    matched_mask_slot_indices: An int32 tf.Tensor of shape [batch_size,
      num_ground_truth_masks] that encodes the matched mask slot id for each
      ground truth mask.
    mask_gt_semantic_map: An int32 tf.Tensor of shape [batch_size,
      num_ground_truth_masks] that encodes the semantic label for each ground
      truth mask. A padded mask (or void, or no object) will have the label -1.
    num_mask_slots: An integer, the number of mask slots for the MaX-DeepLab
      model.
    thing_stuff_class_ids: A list of integers of length [num_thing_classes +
      num_stuff_classes] with the dataset class IDs of all thing classes
      followed by all stuff classes.

  Returns:
    mask_slot_semantic_one_hot: A float32 tf.Tensor with shape [batch_size,
      num_mask_slots, num_thing_classes + num_stuff_classes].
  """
  gt_shape = mask_gt_semantic_map.get_shape().as_list()
  batch_size, num_gt_masks = gt_shape[0], gt_shape[-1]

  # Assemble the [batch, mask_slot, class] scatter coordinates, one row per
  # ground truth mask.
  batch_coords = tf.reshape(
      tf.tile(tf.expand_dims(tf.range(batch_size), axis=-1),
              [1, num_gt_masks]),
      [-1, 1])
  slot_coords = tf.reshape(matched_mask_slot_indices, [-1, 1])
  # Shift semantic labels by one so that void (-1) maps to the valid index 0;
  # otherwise tf.scatter_nd raises an error when running on CPU.
  class_coords = tf.reshape(mask_gt_semantic_map, [-1, 1]) + 1
  scatter_coords = tf.concat([batch_coords, slot_coords, class_coords],
                             axis=-1)

  # Scatter ones into an all-zero tensor. The class axis has room for the
  # shifted labels, hence max(ids) + 2.
  one_hot_full = tf.scatter_nd(
      scatter_coords,
      tf.ones([batch_size * num_gt_masks], dtype=tf.float32),
      shape=[batch_size, num_mask_slots, max(thing_stuff_class_ids) + 2])

  # Gather the wanted classes in (thing, stuff) order, adding one to each id
  # to undo the shift applied above.
  gather_ids = tf.cast(thing_stuff_class_ids, tf.int32) + 1
  return tf.gather(one_hot_full, gather_ids, axis=2)


def nonsquare_hungarian_matching(
    weights: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  """Hungarian matching with arbitrary shape.

  matchers_ops.hungarian_matching only accepts square weight matrices, so this
  wrapper pads a nonsquare weight matrix into a square one with a constant and
  solves the square problem. Hungarian matching guarantees that the solution
  of the padded square problem restricted to the original rows and columns is
  the solution of the original nonsquare problem.

  Args:
    weights: A [batch, shape1, shape2] float32 tf.Tensor.

  Returns:
    square_permutation: A [batch, max(shape1, shape2), max(shape1, shape2)]
      float32 tf.Tensor that is the permutation matrix that achieves the
      minimum total weight. A permutation matrix contains only 0.0 and 1.0,
      and each row and each column sums to 1.0.
    nonsquare_permutation: A [batch, shape1, shape2] float32 tf.Tensor, the
      nonsquare part of the permutation matrix.
  """
  _, num_rows, num_cols = weights.get_shape().as_list()
  side = max(num_rows, num_cols)
  # Padding a constant along one axis does not change the matching result.
  padded_weights = tf.pad(
      weights,
      paddings=[[0, 0],  # The batch dimension stays untouched.
                [0, side - num_rows],
                [0, side - num_cols]],
      constant_values=_MATCHING_NEGATIVE_CONSTANT)
  permutation = tf.cast(
      matchers_ops.hungarian_matching(padded_weights), tf.float32)
  return permutation, permutation[:, :num_rows, :num_cols]


def _mask_similarity(gt_mask: tf.Tensor, pred_mask: tf.Tensor,
                     metric: str = 'dice') -> tf.Tensor:
  """Computes mask similarity between gt_masks and pred_masks.

  Args:
    gt_mask: A [batch, height * width, num_gt_masks] float32 tf.Tensor taking
      only the values 0.0 and 1.0, where 1.0 marks a pixel as belonging to the
      ground truth mask. Panoptic segmentation enforces that ground truth
      masks do not overlap.
    pred_mask: A [batch, height * width, num_pred_masks] non-negative float32
      tf.Tensor. For each batch_id and pixel_id, the [num_pred_masks] vector
      encodes whether each pixel belongs to each mask; each vector sums to at
      most one.
    metric: A string selecting the similarity metric: 'dice' (default),
      'iou', 'intersection_over_ground_truth', or
      'intersection_over_prediction'.

  Returns:
    mask_similarity: A float32 [batch, num_gt_masks, num_pred_masks] tf.Tensor
      with the pairwise similarity between every ground truth mask and every
      predicted mask.

  Raises:
    ValueError: If the mask similarity metric is not one of 'dice', 'iou',
    'intersection_over_ground_truth', or 'intersection_over_prediction'.
  """
  epsilon = 1e-5
  metric_name = metric.lower()
  if metric_name not in ('dice', 'iou', 'intersection_over_ground_truth',
                         'intersection_over_prediction'):
    raise ValueError('The mask similarity metric is not supported.')
  # Pairwise intersection areas between all gt masks and all predicted masks.
  intersection = tf.einsum('bpi,bpj->bij', gt_mask, pred_mask)
  if metric_name == 'intersection_over_ground_truth':
    denominator = tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2)
  elif metric_name == 'intersection_over_prediction':
    denominator = tf.reduce_sum(pred_mask, axis=1, keepdims=True)
  else:
    # Both 'dice' and 'iou' combine the two mask areas.
    gt_area = tf.expand_dims(tf.reduce_sum(gt_mask, axis=1), axis=2)
    pred_area = tf.reduce_sum(pred_mask, axis=1, keepdims=True)
    if metric_name == 'dice':
      denominator = (gt_area + pred_area) / 2
    else:  # 'iou'
      denominator = gt_area + pred_area - intersection
  return intersection / (denominator + epsilon)


class MaXDeepLabLoss(tf.keras.layers.Layer):
  """This class contains code for MaX-DeepLab losses."""

  def __init__(self,
               loss_options: config_pb2.LossOptions,
               ignore_label: int,
               thing_class_ids: Tuple[int],
               focal_loss_alpha: float = 0.75,
               instance_discrimination_temperature: float = 0.3):
    """Initializes a MaX-DeepLab loss.

    Supports the PQ-style loss (decomposed into a classification term and a
    mask dice term), the mask id cross entropy loss, and the instance
    discrimination loss proposed in MaX-DeepLab.

    Reference:
      MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
      CVPR 2021. https://arxiv.org/abs/2012.00759
        Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.

    Args:
      loss_options: Loss options as defined by config_pb2.LossOptions.
      ignore_label: An integer specifying the ignore label.
      thing_class_ids: A tuple of length [N] containing N thing indices.
      focal_loss_alpha: An optional float weighting positive (matched) masks
        against negative (unmatched) masks in the focal loss: positives are
        weighted by alpha and negatives by (1. - alpha). The focal loss gamma
        is fixed to zero, i.e. the loss reduces to alpha-weighted
        cross-entropy. Default to 0.75.
      instance_discrimination_temperature: An optional float specifying the
        temperature for the instance discrimination loss.
    """
    super().__init__(name='MaXDeepLabLoss')
    # The loss_terms will optionally include
    #  - common.PQ_STYLE_LOSS_CLASS_TERM
    #  - common.PQ_STYLE_LOSS_MASK_DICE_TERM
    #  - common.MASK_ID_CROSS_ENTROPY_LOSS
    #  - common.INSTANCE_DISCRIMINATION_LOSS
    # These loss terms will be accessed by loss_builder.py and will be used to
    # build loss metrics.
    self.loss_terms = []

    # The PQ-style loss contributes two terms: a class term and a dice term.
    self._pq_style_loss_weight = 0.0
    if loss_options.HasField(common.PQ_STYLE_LOSS):
      self._pq_style_loss_weight = loss_options.pq_style_loss.weight
      self.loss_terms.extend([common.PQ_STYLE_LOSS_CLASS_TERM,
                              common.PQ_STYLE_LOSS_MASK_DICE_TERM])

    # Mask-ID cross entropy loss.
    self._mask_id_cross_entropy_loss_weight = 0.0
    if loss_options.HasField(common.MASK_ID_CROSS_ENTROPY_LOSS):
      self._mask_id_cross_entropy_loss_weight = (
          loss_options.mask_id_cross_entropy_loss.weight)
      self.loss_terms.append(common.MASK_ID_CROSS_ENTROPY_LOSS)

    # Instance discrimination loss.
    self._instance_discrimination_loss_weight = 0.0
    if loss_options.HasField(common.INSTANCE_DISCRIMINATION_LOSS):
      self._instance_discrimination_loss_weight = (
          loss_options.instance_discrimination_loss.weight)
      self.loss_terms.append(common.INSTANCE_DISCRIMINATION_LOSS)

    self._ignore_label = ignore_label
    self._thing_class_ids = list(thing_class_ids)
    self._focal_loss_alpha = focal_loss_alpha
    self._instance_discrimination_temperature = (
        instance_discrimination_temperature)

    # All base losses read their inputs from the same dict keys.
    loss_dict_keys = dict(
        gt_key=_GT_KEY, pred_key=_PRED_KEY, weight_key=_WEIGHT_KEY)
    # Build the base loss functions. Num_classes and ignore_label are not
    # necessary for any of them since the inputs will be one hot encoded
    # already.
    self._pq_style_loss_class_term = base_loss.FocalCrossEntropyLoss(
        num_classes=None, ignore_label=None,
        focal_loss_alpha=focal_loss_alpha,
        focal_loss_gamma=0.0, background_channel_index=-1,
        dynamic_weight=True, **loss_dict_keys)
    self._pq_style_loss_mask_dice_term = base_loss.MaskDiceLoss(
        prediction_activation='softmax', **loss_dict_keys)
    self._mask_id_cross_entropy_loss = base_loss.TopKCrossEntropyLoss(
        num_classes=None, ignore_label=None,
        top_k_percent_pixels=1.0, dynamic_weight=True, **loss_dict_keys)
    self._instance_discrimination_loss = base_loss.TopKCrossEntropyLoss(
        num_classes=None, ignore_label=None,
        top_k_percent_pixels=1.0, dynamic_weight=True, **loss_dict_keys)

  def build(self,
            input_shapes: Tuple[Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]):
    """Extracts useful constants that depend on the input shapes."""
    y_true_shapes, y_pred_shapes = input_shapes
    self._max_thing_id = int(y_true_shapes[common.GT_THING_ID_CLASS_KEY][-1])
    class_logits_shape = y_pred_shapes[
        common.PRED_TRANSFORMER_CLASS_LOGITS_KEY]
    self._num_mask_slots = int(class_logits_shape[1])
    # The transformer_class_logits cover thing classes, stuff classes, and one
    # void class, so the number of thing + stuff classes is the channel count
    # minus one.
    self._num_thing_stuff_classes = int(class_logits_shape[2]) - 1
    # The PQ-style loss is implemented as the class term plus the mask dice
    # term (Equation 10 of the paper), so both terms must share the same
    # weighting and normalization constants. The focal loss alpha weights the
    # positive class term, so it is applied to the mask dice term as well; the
    # class loss is normalized by the number of mask slots, so the dice term
    # gets the same normalization.
    self._mask_dice_term_modifier = (
        self._focal_loss_alpha / self._num_mask_slots)

    self._stuff_class_ids = utils.get_stuff_class_ids(
        self._num_thing_stuff_classes,
        self._thing_class_ids,
        self._ignore_label)
    self._num_stuff_classes = len(self._stuff_class_ids)
    self._thing_stuff_class_ids = self._thing_class_ids + self._stuff_class_ids
    # Ground truth masks are the padded thing masks plus one mask per stuff
    # class.
    self._pixel_gt_num_mask_id = self._max_thing_id + self._num_stuff_classes

  def _pre_process_ground_truth(
      self, y_true: Dict[Text, tf.Tensor], output_height: int, output_width: int
  ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor,
             tf.Tensor]:
    """Pre-processes the ground truth before we compute the losses.

    This function generates tensors that do not depend on the prediction of the
    model, but are useful to the calculation of the losses. The function mainly
    downsamples the pixel space ground truth to the model output resolution, and
    combines (or concatenates) the thing masks and the stuff masks. The output
    shape pixel_gt_num_mask_id = max_thing_id + num_stuff_classes, which means
    the output masks contain both thing masks and stuff masks.

    Args:
      y_true: A dict of tensors providing ground-truth information, containing
       - common.GT_SEMANTIC_KEY: A [batch, height, width] int32 tf.Tensor, the
         semantic label map.
       - common.GT_THING_ID_MASK_KEY: A [batch, height, width] int32 tf.Tensor.
         It assigns each non-crowd thing instance a unique mask-ID label,
         starting from 0. Unassigned pixels are set to -1.
       - common.GT_THING_ID_CLASS_KEY: A [batch, max_thing_id] int32 tf.Tensor.
         It contains semantic ID of each instance assigned to thing_id_mask. The
         remaining (max_thing_id - num_things) elements are set to -1.
      output_height: An integer, the height of the model output.
      output_width: An integer, the width of the model output.

    Returns:
      pixel_gt_thing_mask: A [batch, output_height * output_width] float32
        tensor, with value 0.0 and 1.0 only, indicating whether a pixel belongs
        to a 'thing' class.
      pixel_gt_non_void_mask: A [batch, output_height * output_width] float32
        tensor, with value 0.0 and 1.0 only, indicating if a pixel does not
        belong to the void class.
      pixel_gt_mask_id_one_hot: A [batch, output_height * output_width,
        pixel_gt_num_mask_id] float32 tensor, with value 0.0 and 1.0 only,
        indicating the mask id each pixel belongs to.
      mask_gt_semantic_map: A [batch, pixel_gt_num_mask_id] int32 tensor, the
        semantic class of each ground truth mask. Void (padded or zero-area)
        masks are labeled -1.
      mask_gt_non_void_mask: A [batch, pixel_gt_num_mask_id] float32 tensor,
        with value 0.0 and 1.0 only, indicating if the ground truth mask is a
        valid mask, not a padded mask. The masks are padded because TPU does not
        support dynamic shapes except in the batch axis. We pad all ground truth
        thing masks to a large enough constant max_thing_id. Similarly, stuff
        classes that do not present in the current image will be set to a void
        mask too.
      mask_gt_semantic_one_hot: A [batch, pixel_gt_num_mask_id,
        num_thing_stuff_classes] float32 tensor, with value 0.0 and 1.0 only,
        containing the one hot encodings of the ground truth mask classes. The
        last dimension contains concatenated thing classes and stuff classes,
        which is different from the dataset class IDs in mask_gt_semantic_map.
      mask_gt_area: A [batch, pixel_gt_num_mask_id, 1] float32 tensor, the area
        of each ground truth mask. Padded masks have an area of 0.0. The
        trailing singleton dimension makes the tensor broadcast against
        per-mask feature tensors.
    """
    # The depth of one hot encoding should be the largest id plus one. For
    # example, if we want to one-hot encode a class ID of 133 (the largest ID
    # for the COCO dataset), we will need a one-hot encoding of length 134.
    one_hot_depth = max(self._thing_stuff_class_ids) + 1
    batch_size = y_true[common.GT_SEMANTIC_KEY].get_shape().as_list()[0]

    # Compute pixel_gt_semantic_map (downsampling and reshaping to the 1D
    # representation that will be mainly used in this loss function).
    pixel_gt_semantic_map = utils.strided_downsample(
        y_true[common.GT_SEMANTIC_KEY],
        target_size=[output_height, output_width])
    pixel_gt_semantic_map = tf.reshape(
        pixel_gt_semantic_map,
        [batch_size, output_height * output_width])

    # Compute pixel_gt_non_void_mask.
    pixel_gt_non_void_mask = tf.cast(
        tf.not_equal(pixel_gt_semantic_map, self._ignore_label), tf.float32)
    pixel_gt_non_void_mask = tf.ensure_shape(
        pixel_gt_non_void_mask,
        [batch_size, output_height * output_width])

    # Compute pixel_gt_semantic_one_hot from pixel_gt_semantic_map in order to
    # gather pixel_gt_stuff_id_one_hot from pixel_gt_semantic_one_hot.
    pixel_gt_semantic_one_hot = tf.one_hot(pixel_gt_semantic_map, one_hot_depth)
    # Convert the one hot encoding from the dataset id order to (thing, stuff)
    # order used in MaX-DeepLab.
    pixel_gt_stuff_id_one_hot = tf.gather(pixel_gt_semantic_one_hot,
                                          self._stuff_class_ids, axis=-1)
    pixel_gt_stuff_id_one_hot = tf.ensure_shape(
        pixel_gt_stuff_id_one_hot,
        [batch_size, output_height * output_width, self._num_stuff_classes])

    # Compute pixel_gt_thing_id_one_hot for thing masks.
    pixel_gt_thing_id_map = utils.strided_downsample(
        y_true[common.GT_THING_ID_MASK_KEY],
        target_size=[output_height, output_width])
    pixel_gt_thing_id_map = tf.reshape(
        pixel_gt_thing_id_map, shape=[batch_size, output_height * output_width])
    # Note that common.GT_THING_ID_MASK_KEY uses -1 for void masks. And 0 to
    # (num_mask_slots - 1) are used for num_mask_slots mask slots.
    pixel_gt_thing_mask = tf.cast(
        tf.not_equal(pixel_gt_thing_id_map, -1), tf.float32)
    # tf.one_hot maps the void id (-1) to an all-zero vector, so void pixels
    # do not contribute to any thing mask.
    pixel_gt_thing_id_one_hot = tf.one_hot(pixel_gt_thing_id_map,
                                           self._max_thing_id)
    # Compute pixel_gt_mask_id_one_hot by concatenating thing masks with stuff
    # masks.
    pixel_gt_mask_id_one_hot = tf.concat([pixel_gt_thing_id_one_hot,
                                          pixel_gt_stuff_id_one_hot], axis=-1)
    pixel_gt_mask_id_one_hot = tf.ensure_shape(
        pixel_gt_mask_id_one_hot,
        [batch_size, output_height * output_width, self._pixel_gt_num_mask_id])

    # Compute mask_gt_area by summing the one hot encodings spatially. The
    # expand_dims yields shape [batch, pixel_gt_num_mask_id, 1].
    mask_gt_area = tf.expand_dims(
        tf.reduce_sum(pixel_gt_mask_id_one_hot, axis=1), axis=-1)
    # Generate a binary mask for ground truth masks, indicating whether each
    # ground truth mask exists in the pixel space with a non-zero area. Note
    # that a mask that exists in the original input resolution will be removed
    # if its area is zero in the output resolution, due to downsampling.
    mask_gt_area_mask = tf.reshape(mask_gt_area > 0.5,
                                   [batch_size, self._pixel_gt_num_mask_id])

    # Compute mask_gt_semantic_map and mask_gt_semantic_one_hot.
    thing_id_gt_semantic_map = tf.reshape(
        tf.cast(y_true[common.GT_THING_ID_CLASS_KEY], tf.int32),
        [batch_size, self._max_thing_id])
    # The stuff ground truth semantic map is just the stuff class IDs.
    stuff_id_gt_semantic_map = tf.tile(
        tf.reshape(
            tf.cast(self._stuff_class_ids, tf.int32),
            [1, self._num_stuff_classes]), [batch_size, 1])
    mask_gt_semantic_map = tf.concat(
        [thing_id_gt_semantic_map, stuff_id_gt_semantic_map], axis=-1)
    # Set masks with zero area to void (-1), which is consistent with the void
    # label used in common.GT_THING_ID_CLASS_KEY but is different from the
    # ignore_labels of the datasets.
    mask_gt_semantic_map = (
        (mask_gt_semantic_map + 1) * tf.cast(mask_gt_area_mask, tf.int32) - 1)
    # Void (-1) classes will automatically be ignored by tf.one_hot.
    mask_gt_semantic_one_hot = tf.one_hot(mask_gt_semantic_map, one_hot_depth)
    mask_gt_semantic_one_hot = tf.gather(
        mask_gt_semantic_one_hot, self._thing_stuff_class_ids, axis=-1)

    # Compute mask_gt_non_void_mask. Again, a mask that exists in the original
    # input resolution is set to void if its area is zero in the output
    # resolution, due to downsampling.
    mask_gt_non_void_mask = tf.cast(mask_gt_semantic_map > -1, tf.float32)
    mask_gt_non_void_mask = tf.ensure_shape(
        mask_gt_non_void_mask, [batch_size, self._pixel_gt_num_mask_id])

    return (pixel_gt_thing_mask, pixel_gt_non_void_mask,
            pixel_gt_mask_id_one_hot, mask_gt_semantic_map,
            mask_gt_non_void_mask, mask_gt_semantic_one_hot, mask_gt_area)

  def call(
      self, inputs: Tuple[Dict[Text, tf.Tensor], Dict[Text, tf.Tensor]]
  ) -> Dict[Text, tf.Tensor]:
    """Computes the MaX-DeepLab losses.

    Args:
      inputs: A tuple of two dicts (y_true, y_pred):
      - y_true: A dict of tensors providing ground-truth information, containing
         - common.GT_SEMANTIC_KEY: A [batch, height, width] int32 tf.Tensor, the
           semantic label map.
         - common.GT_THING_ID_MASK_KEY: A [batch, height, width] int32
           tf.Tensor. It assigns each non-crowd thing instance a unique mask-ID
           label, starting from 0. Unassigned pixels are set to -1.
         - common.GT_THING_ID_CLASS_KEY: A [batch, max_thing_id] int32
           tf.Tensor. It contains semantic ID of each instance assigned to
           thing_id_mask. The remaining (max_thing_id - num_things) elements are
           set to -1.
      - y_pred: A dict of tensors providing predictions.
         - common.PRED_PIXEL_SPACE_NORMALIZED_FEATURE_KEY: A [batch_size,
           output_height, output_width, channels] float32 tensor.
         - common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY: A [batch_size,
           output_height, output_width, num_mask_slots] float32 tensor, the
           logits that a pixel belongs to a mask slot.
         - common.PRED_TRANSFORMER_CLASS_LOGITS_KEY: A [batch_size,
           num_mask_slots, num_thing_stuff_classes + 1] float32 tensor, the
           logits that a mask belongs to a semantic class (including thing,
           stuff, and void)

    Returns:
      The loss as a dict of tf.Tensor, optionally containing the following:
      - common.PQ_STYLE_LOSS_CLASS_TERM: [batch].
      - common.PQ_STYLE_LOSS_MASK_DICE_TERM: [batch].
      - common.MASK_ID_CROSS_ENTROPY_LOSS: [batch].
      - common.INSTANCE_DISCRIMINATION_LOSS: [batch].
    """
    y_true, y_pred = inputs
    resulting_dict = {}

    pixel_feature = y_pred[common.PRED_PIXEL_SPACE_NORMALIZED_FEATURE_KEY]
    batch_size, output_height, output_width, _ = (
        pixel_feature.get_shape().as_list())

    # Pre-process the ground truth.
    (pixel_gt_thing_mask, pixel_gt_non_void_mask, pixel_gt_mask_id_one_hot,
     mask_gt_semantic_map, mask_gt_non_void_mask, mask_gt_semantic_one_hot,
     mask_gt_area) = self._pre_process_ground_truth(y_true,
                                                    output_height, output_width)
    pixel_gt_non_void_mask_expanded = tf.expand_dims(
        pixel_gt_non_void_mask, axis=-1)

    # Compute mask_average_feature by averaging the feature of each mask.
    # Flatten the spatial dimensions so pixels can be indexed with a single
    # axis: [batch, output_height * output_width, channels].
    pixel_feature = tf.reshape(
        pixel_feature, [batch_size, output_height * output_width, -1])
    # Average-pool the pixel features within each ground truth mask: the
    # einsum sums features over the pixels assigned to each mask id, and the
    # division by the (clamped) mask area turns the sum into a mean.
    # Result shape: [batch, num_gt_mask_ids, channels].
    mask_average_feature = tf.einsum(
        'bpd,bpi->bid',
        pixel_feature,
        pixel_gt_mask_id_one_hot) / tf.maximum(mask_gt_area, 1.0)
    # Normalize the mask feature as the pixel space output feature is usually
    # normalized too.
    mask_average_feature = tf.math.l2_normalize(mask_average_feature, axis=-1)

    # Compute instance_discrimination_similarity, scaled by a constant
    # temperature. Shape: [batch, output_height * output_width,
    # num_gt_mask_ids].
    instance_discrimination_similarity = tf.einsum(
        'bpd,bid->bpi', pixel_feature, mask_average_feature)
    instance_discrimination_similarity /= (
        self._instance_discrimination_temperature)
    mask_gt_non_void_mask_expanded_1 = tf.expand_dims(
        mask_gt_non_void_mask, axis=1)
    # Mask void masks by setting them to a large negative value, so that they
    # will be ignored by the softmax in the loss.
    instance_discrimination_similarity = (
        mask_gt_non_void_mask_expanded_1 * instance_discrimination_similarity +
        (1.0 - mask_gt_non_void_mask_expanded_1) * _SOFTMAX_MASKING_CONSTANT)

    # Auxiliary instance_discrimination_loss, weighted per-pixel by
    # pixel_gt_thing_mask (so only "thing" pixels contribute).
    if self._instance_discrimination_loss_weight > 0.0:
      resulting_dict[common.INSTANCE_DISCRIMINATION_LOSS] = (
          self._instance_discrimination_loss(
              {_GT_KEY: pixel_gt_mask_id_one_hot},
              {_PRED_KEY: instance_discrimination_similarity,
               _WEIGHT_KEY: pixel_gt_thing_mask}) *
          self._instance_discrimination_loss_weight)

    # Extract pixel_space_mask_logits and pixel_space_mask_probs. The softmax
    # is taken over the mask slots, giving each pixel a distribution over the
    # predicted masks: [batch, output_height * output_width, num_mask_slots].
    pixel_space_mask_logits = y_pred[common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY]
    pixel_space_mask_logits = tf.reshape(
        pixel_space_mask_logits,
        [batch_size, output_height * output_width, self._num_mask_slots])
    pixel_space_mask_probs = tf.nn.softmax(pixel_space_mask_logits, axis=-1)

    # Compute the mask similarity between all ground truth masks and all
    # predicted masks. Predicted probabilities are zeroed on void pixels
    # before the dice computation.
    # mask_similarity shape: [batch, num_gt_mask_ids, num_mask_slots]
    # (it is multiplied elementwise with the permutation matrix below).
    mask_similarity = _mask_similarity(
        pixel_gt_mask_id_one_hot,
        pixel_space_mask_probs * pixel_gt_non_void_mask_expanded,
        metric='dice')

    # Compute the class similarity by multiplying the ground truth one hot
    # encoding with the predicted probability distribution. This is done between
    # all ground truth masks and all predicted masks. The last channel of the
    # class prediction is dropped ([:, :, :-1]); it is the void class, as
    # constructed for the class ground truth near the end of this function.
    transformer_class_logits = y_pred[common.PRED_TRANSFORMER_CLASS_LOGITS_KEY]
    transformer_class_probs = tf.nn.softmax(
        transformer_class_logits, axis=-1)[:, :, :-1]
    # class_similarity shape: [batch, num_gt_mask_ids, num_mask_slots].
    class_similarity = tf.einsum(
        'bij,bkj->bik', mask_gt_semantic_one_hot, transformer_class_probs)

    # Compute hungarian matching weights. We take the negative here since the
    # hungarian matching algorithm looks for the matching with the least total
    # weight.
    hungarian_weights = - mask_similarity * class_similarity
    mask_gt_non_void_mask_expanded_2 = tf.expand_dims(
        mask_gt_non_void_mask, axis=2)

    # Mask the void ground truth masks (in the rows) so that they do not affect
    # the result of the hungarian matching.
    if self._num_mask_slots >= self._pixel_gt_num_mask_id:
      # If the number of mask slots (number of columns) is larger than the
      # constant number of ground truth masks (number of rows), the
      # nonsquare_hungarian_matching will pad the rows with
      # _MATCHING_NEGATIVE_CONSTANT. In this case, we can fill in the void mask
      # rows with _MATCHING_NEGATIVE_CONSTANT too, then the void mask rows will
      # be ignored too, according to the hungarian matching property.
      hungarian_weights = (
          hungarian_weights * mask_gt_non_void_mask_expanded_2 +
          (1 - mask_gt_non_void_mask_expanded_2) * _MATCHING_NEGATIVE_CONSTANT)
    else:
      # If the number of mask slots (number of columns) is smaller than the
      # constant number of ground truth masks (number of rows), the
      # nonsquare_hungarian_matching will pad the columns with
      # _MATCHING_NEGATIVE_CONSTANT. In this case, we should fill in the void
      # mask rows with _MATCHING_POSITIVE_CONSTANT here, then the void mask rows
      # will have a huge cost compared with existing non-void mask rows, so that
      # the predicted masks will prefer matching with existing non-void masks
      # rather than the padded void masks, according to the hungarian matching
      # property.
      hungarian_weights = (
          hungarian_weights * mask_gt_non_void_mask_expanded_2 +
          (1 - mask_gt_non_void_mask_expanded_2) * _MATCHING_POSITIVE_CONSTANT)

    # Perform the hungarian matching algorithm. full_permutation is the
    # padded (square) permutation; nonsquare_permutation keeps the original
    # [batch, num_gt_mask_ids, num_mask_slots] shape.
    full_permutation, nonsquare_permutation = (
        nonsquare_hungarian_matching(hungarian_weights))

    # Extract the permutation (matching) between all existing non-void ground
    # truth masks and the matched predicted masks.
    matched_permutation = (
        nonsquare_permutation * mask_gt_non_void_mask_expanded_2)
    # The matched mask dice scores for each mask slot. The scores will be used
    # as a loss weight for the PQ-style loss class term after the stop_gradient.
    matched_mask_dice = tf.reduce_max(
        mask_similarity * matched_permutation, axis=-2)
    matched_mask_dice = tf.stop_gradient(matched_mask_dice)

    # The matched class probabilities for each ground truth mask. The
    # probabilities will be used as a loss weight for the PQ-style loss mask
    # dice term after the stop_gradient.
    matched_class_prob = tf.reduce_max(
        class_similarity * matched_permutation, axis=-1)
    matched_class_prob = tf.stop_gradient(matched_class_prob)

    # Extract the index of the matched mask slot for each ground truth mask.
    # Shape: [batch, num_gt_mask_ids], int32.
    matched_mask_slot_indices = tf.math.argmax(
        nonsquare_permutation, axis=-1, output_type=tf.dtypes.int32)

    # Size of the square matching problem after padding; at least as large as
    # both num_mask_slots and the constant number of ground truth mask ids
    # (both tf.pad amounts below must be non-negative).
    full_num_mask_slots = full_permutation.get_shape().as_list()[-1]
    # Pad the pixel_space_mask_logits so that it is compatible with the
    # permutation matrix. Padded slots get _SOFTMAX_MASKING_CONSTANT so they
    # are ignored by downstream softmax-based losses.
    full_pixel_space_mask_logits = tf.pad(
        pixel_space_mask_logits,
        [[0, 0], [0, 0], [0, full_num_mask_slots - self._num_mask_slots]],
        constant_values=_SOFTMAX_MASKING_CONSTANT)

    # Permute the pixel space mask logits with the permutation matrix, which
    # converts the mask slot indices to the ground truth indices.
    permuted_full_pixel_space_mask_logits = tf.einsum(
        'bpi,bji->bpj', full_pixel_space_mask_logits, full_permutation)

    # Pad the class probabilities too (zero-padded, so padded entries carry no
    # loss weight).
    full_matched_class_prob = tf.pad(
        matched_class_prob,
        [[0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])
    # We only compute dice loss term on non-void ground truth masks.
    mask_dice_term_loss_weight = tf.pad(
        mask_gt_non_void_mask,
        [[0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])
    # Use the class probabilities as the loss weight for the mask dice term. In
    # addition, we set a lower bound, 1e-5, for the mask dice term loss weight.
    # Otherwise, if a loss weight is accidentally zero, the dice loss will treat
    # it as void and use an incorrect denominator or normalizing constant for
    # the loss.
    mask_dice_term_loss_weight *= tf.maximum(full_matched_class_prob, 1e-5)

    # Pad the one hot encoding too.
    full_pixel_gt_mask_id_one_hot = tf.pad(
        pixel_gt_mask_id_one_hot,
        [[0, 0], [0, 0], [0, full_num_mask_slots - self._pixel_gt_num_mask_id]])

    if self._pq_style_loss_weight > 0.0:
      # Mask_dice_term_modifier balances the mask_dice_term and the class_term
      # of the PQ-style loss to have the same weight and normalizing constant.
      resulting_dict[common.PQ_STYLE_LOSS_MASK_DICE_TERM] = (
          self._pq_style_loss_mask_dice_term(
              {_GT_KEY: full_pixel_gt_mask_id_one_hot},
              {_PRED_KEY: permuted_full_pixel_space_mask_logits,
               _WEIGHT_KEY: mask_dice_term_loss_weight}) *
          (self._pq_style_loss_weight * self._mask_dice_term_modifier))

    # Mask-ID cross entropy loss shares the same ground truth and logits as the
    # dice loss term, but with different weights (per-pixel non-void weights
    # here, rather than per-mask class probabilities).
    if self._mask_id_cross_entropy_loss_weight > 0.0:
      resulting_dict[common.MASK_ID_CROSS_ENTROPY_LOSS] = (
          self._mask_id_cross_entropy_loss(
              {_GT_KEY: full_pixel_gt_mask_id_one_hot},
              {_PRED_KEY: permuted_full_pixel_space_mask_logits,
               _WEIGHT_KEY: pixel_gt_non_void_mask}) *
          self._mask_id_cross_entropy_loss_weight)

    # Generate a pseudo ground truth for transformer_class_logits: a one-hot
    # semantic label per mask slot, derived from the matched ground truth mask.
    mask_slot_semantic_one_hot = _generate_mask_slot_semantic_one_hot(
        matched_mask_slot_indices, mask_gt_semantic_map,
        self._num_mask_slots, self._thing_stuff_class_ids)

    # Compute the positive mask (slots matched to a ground truth class) and
    # the negative mask (unmatched slots, supervised as void below).
    mask_slot_positive_mask = tf.cast(tf.equal(tf.reduce_max(
        mask_slot_semantic_one_hot, axis=-1), 1.0), tf.float32)
    mask_slot_negative_mask = 1.0 - mask_slot_positive_mask

    # Compute the overlap ratio between all predicted masks and the void region.
    # This void ratio will be used as a weight for the negative class term.
    mask_void_ratio = tf.stop_gradient(_mask_similarity(
        1.0 - pixel_gt_non_void_mask_expanded,
        pixel_space_mask_probs,
        'intersection_over_prediction'))
    mask_void_ratio = tf.squeeze(mask_void_ratio, axis=1)

    # Use the matched mask dice scores as the weights for the positive class
    # terms. For the negative class terms, we reduce the penalty for a mask slot
    # class term if the mask prediction overlaps a lot with void regions.
    # The 1e-5 floor keeps every slot's weight strictly positive.
    transformer_class_loss_weight = (
        mask_slot_positive_mask * tf.maximum(matched_mask_dice, 1e-5) +
        mask_slot_negative_mask * tf.maximum(mask_void_ratio, 1e-5))

    # Concatenate the void mask in the last channel, constructing the final
    # ground truth one hot label with (thing + stuff + void) channels.
    transformer_class_one_hot = tf.concat(
        [mask_slot_semantic_one_hot,
         tf.expand_dims(mask_slot_negative_mask, axis=-1)], axis=-1)

    # Apply the PQ-style loss class term.
    if self._pq_style_loss_weight > 0.0:
      resulting_dict[common.PQ_STYLE_LOSS_CLASS_TERM] = (
          self._pq_style_loss_class_term(
              {_GT_KEY: transformer_class_one_hot},
              {_PRED_KEY: transformer_class_logits,
               _WEIGHT_KEY: transformer_class_loss_weight}) *
          self._pq_style_loss_weight)

    return resulting_dict