# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains loss builder classes used in the DeepLab model."""

import collections
from typing import Any, Dict, Text, Tuple, Optional

import tensorflow as tf

from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.model.loss import base_loss
from deeplab2.model.loss import max_deeplab_loss


def _create_loss_and_weight(
    loss_options: config_pb2.LossOptions.SingleLossOptions, gt_key: Text,
    pred_key: Text, weight_key: Text,
    **kwargs: Any) -> Tuple[Optional[tf.keras.losses.Loss], float]:
  """Creates a loss and its weight from loss options.

  Args:
    loss_options: Loss options as defined by
      config_pb2.LossOptions.SingleLossOptions or None.
    gt_key: A key to extract the ground-truth from a dictionary.
    pred_key: A key to extract the prediction from a dictionary.
    weight_key: A key to extract the per-pixel weights from a dictionary.
    **kwargs: Additional parameters to initialize the loss.

  Returns:
    A tuple of an instance of tf.keras.losses.Loss and its corresponding loss
    weight, or (None, 0) if loss_options is None.

  Raises:
    ValueError: An error occurs when the loss name is not a valid loss.
  """
  if loss_options is None:
    return None, 0
  if loss_options.name == 'softmax_cross_entropy':
    return base_loss.TopKCrossEntropyLoss(
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent,
        **kwargs), loss_options.weight
  elif loss_options.name == 'l1':
    return base_loss.TopKGeneralLoss(
        base_loss.mean_absolute_error,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent), loss_options.weight
  elif loss_options.name == 'mse':
    return base_loss.TopKGeneralLoss(
        base_loss.mean_squared_error,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent), loss_options.weight

  raise ValueError('Loss %s is not a valid loss.' % loss_options.name)
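
# Example (illustrative sketch, not part of the library API): building a
# semantic cross-entropy loss from SingleLossOptions. The option values and
# the num_classes / ignore_label constants below are assumptions used for
# demonstration only:
#
#   options = config_pb2.LossOptions.SingleLossOptions(
#       name='softmax_cross_entropy', weight=1, top_k_percent=1.0)
#   loss_fn, loss_weight = _create_loss_and_weight(
#       options,
#       common.GT_SEMANTIC_KEY,
#       common.PRED_SEMANTIC_LOGITS_KEY,
#       common.SEMANTIC_LOSS_WEIGHT_KEY,
#       num_classes=19,
#       ignore_label=255)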


class DeepLabFamilyLoss(tf.keras.layers.Layer):
  """This class contains code to build and call losses for DeepLabFamilyLoss."""

  def __init__(
      self,
      loss_options: config_pb2.LossOptions,
      num_classes: Optional[int],
      ignore_label: Optional[int],
      thing_class_ids: Tuple[int, ...]):
    """Initializes the losses for Panoptic-DeepLab.

    Args:
      loss_options: Loss options as defined by config_pb2.LossOptions.
      num_classes: An optional integer specifying the number of classes in the
        dataset, or None.
      ignore_label: An optional integer specifying the ignore label, or None.
      thing_class_ids: A tuple of N integers containing the indices of the
        'thing' classes.
    """
    super(DeepLabFamilyLoss, self).__init__(name='DeepLabFamilyLoss')

    # Single-term losses are losses that have only one loss term and thus each
    # loss function directly returns a single tensor as the loss value, as
    # opposed to multi-term losses that involve multiple terms and return a
    # dictionary of loss values.
    self._single_term_loss_func_and_weight_dict = collections.OrderedDict()
    self._extra_loss_names = [common.TOTAL_LOSS]

    if loss_options.HasField(common.SEMANTIC_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.SEMANTIC_LOSS] = _create_loss_and_weight(
              loss_options.semantic_loss,
              common.GT_SEMANTIC_KEY,
              common.PRED_SEMANTIC_LOGITS_KEY,
              common.SEMANTIC_LOSS_WEIGHT_KEY,
              num_classes=num_classes,
              ignore_label=ignore_label)

    if loss_options.HasField(common.CENTER_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.CENTER_LOSS] = _create_loss_and_weight(
              loss_options.center_loss, common.GT_INSTANCE_CENTER_KEY,
              common.PRED_CENTER_HEATMAP_KEY, common.CENTER_LOSS_WEIGHT_KEY)

    if loss_options.HasField(common.REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.regression_loss, common.GT_INSTANCE_REGRESSION_KEY,
              common.PRED_OFFSET_MAP_KEY, common.REGRESSION_LOSS_WEIGHT_KEY)

    # Currently, only used for Motion-DeepLab.
    if loss_options.HasField(common.MOTION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.MOTION_LOSS] = _create_loss_and_weight(
              loss_options.motion_loss, common.GT_FRAME_OFFSET_KEY,
              common.PRED_FRAME_OFFSET_MAP_KEY,
              common.FRAME_REGRESSION_LOSS_WEIGHT_KEY)

    # Next-frame regression loss used in ViP-DeepLab.
    if loss_options.HasField(common.NEXT_REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.NEXT_REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.next_regression_loss,
              common.GT_NEXT_INSTANCE_REGRESSION_KEY,
              common.PRED_NEXT_OFFSET_MAP_KEY,
              common.NEXT_REGRESSION_LOSS_WEIGHT_KEY)

    # Multi-term losses that return dictionaries of loss terms.
    self._multi_term_losses = []

    # MaXDeepLabLoss optionally returns four loss terms in total:
    # - common.PQ_STYLE_LOSS_CLASS_TERM
    # - common.PQ_STYLE_LOSS_MASK_DICE_TERM
    # - common.MASK_ID_CROSS_ENTROPY_LOSS
    # - common.INSTANCE_DISCRIMINATION_LOSS
    if any([loss_options.HasField('pq_style_loss'),
            loss_options.HasField('mask_id_cross_entropy_loss'),
            loss_options.HasField('instance_discrimination_loss')]):
      self._multi_term_losses.append(max_deeplab_loss.MaXDeepLabLoss(
          loss_options, ignore_label, thing_class_ids))

    for multi_term_loss in self._multi_term_losses:
      self._extra_loss_names += multi_term_loss.loss_terms

  def get_loss_names(self):
    """Returns the names of all loss keys that self.call() will return."""
    loss_names = list(self._single_term_loss_func_and_weight_dict.keys())
    return loss_names + self._extra_loss_names

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Performs the loss computations given ground-truth and predictions.

    The loss is computed for each sample separately. Currently, smoothed
    ground-truth labels are not supported.

    Args:
      y_true: A dictionary of tf.Tensor containing all ground-truth data to
        compute the loss. Depending on the configuration, the dict has to
        contain common.GT_SEMANTIC_KEY, and optionally
        common.GT_INSTANCE_CENTER_KEY, common.GT_INSTANCE_REGRESSION_KEY,
        common.GT_FRAME_OFFSET_KEY, and common.GT_NEXT_INSTANCE_REGRESSION_KEY.
      y_pred: A dictionary of tf.Tensor containing all predictions to compute
        the loss. Depending on the configuration, the dict has to contain
        common.PRED_SEMANTIC_LOGITS_KEY, and optionally
        common.PRED_CENTER_HEATMAP_KEY, common.PRED_OFFSET_MAP_KEY,
        common.PRED_FRAME_OFFSET_MAP_KEY, and common.PRED_NEXT_OFFSET_MAP_KEY.

    Returns:
      The loss as a dict of tf.Tensor, optionally containing the following:
      - common.SEMANTIC_LOSS: [batch].
      - common.CENTER_LOSS: [batch].
      - common.REGRESSION_LOSS: [batch].
      - common.MOTION_LOSS: [batch], the frame offset regression loss.
      - common.NEXT_REGRESSION_LOSS: [batch], the next-frame regression loss.
      - The loss terms of any configured multi-term losses, for example
        MaXDeepLabLoss.
      - common.TOTAL_LOSS: [batch], the sum of all the above loss terms.

    Raises:
      AssertionError: If the keys of the resulting_dict do not match
        self.get_loss_names().
      AssertionError: The keys of the resulting_dict overlap with the keys of
        the loss_dict.
    """
    resulting_dict = collections.OrderedDict()

    # Single-term losses.
    for loss_name, func_and_weight in (
        self._single_term_loss_func_and_weight_dict.items()):
      loss_func, loss_weight = func_and_weight
      loss_value = loss_func(y_true, y_pred)
      resulting_dict[loss_name] = loss_value * loss_weight

    # Multi-term losses return a dictionary of loss terms, so we handle them
    # differently.
    for multi_term_loss in self._multi_term_losses:
      loss_dict = multi_term_loss((y_true, y_pred))
      if not set(loss_dict).isdisjoint(resulting_dict):
        raise AssertionError('The keys of the resulting_dict overlap with the '
                             'keys of the loss_dict.')
      resulting_dict.update(loss_dict)

    # Also include the total loss in the resulting_dict.
    total_loss = tf.math.accumulate_n(list(resulting_dict.values()))
    resulting_dict[common.TOTAL_LOSS] = total_loss

    if sorted(resulting_dict.keys()) != sorted(self.get_loss_names()):
      raise AssertionError(
          'The keys of the resulting_dict should match self.get_loss_names().')
    return resulting_dict
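

if __name__ == '__main__':
  # Minimal usage sketch (illustrative, not part of the library API): a config
  # with Panoptic-DeepLab-style semantic, center, and offset regression losses.
  # The loss weights, top-k values, and dataset constants below (19 classes,
  # ignore label 255, thing classes 11..18) are assumptions for demonstration.
  example_options = config_pb2.LossOptions(
      semantic_loss=config_pb2.LossOptions.SingleLossOptions(
          name='softmax_cross_entropy', weight=1, top_k_percent=1.0),
      center_loss=config_pb2.LossOptions.SingleLossOptions(
          name='mse', weight=200, top_k_percent=1.0),
      regression_loss=config_pb2.LossOptions.SingleLossOptions(
          name='l1', weight=1, top_k_percent=1.0))
  example_loss_layer = DeepLabFamilyLoss(
      example_options,
      num_classes=19,
      ignore_label=255,
      thing_class_ids=tuple(range(11, 19)))
  # Prints the loss keys self.call() will return, i.e. the semantic, center,
  # regression, and total loss names defined in deeplab2.common.
  print(example_loss_layer.get_loss_names())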