Spaces:
Runtime error
Runtime error
File size: 9,216 Bytes
d1843be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains loss builder classes used in the DeepLab model."""
import collections
from typing import Any, Dict, Text, Tuple, Optional
import tensorflow as tf
from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.model.loss import base_loss
from deeplab2.model.loss import max_deeplab_loss
def _create_loss_and_weight(
    loss_options: Optional[config_pb2.LossOptions.SingleLossOptions],
    gt_key: Text,
    pred_key: Text,
    weight_key: Text,
    **kwargs: Any) -> Tuple[Optional[tf.keras.losses.Loss], float]:
  """Creates a loss and its weight from loss options.

  Args:
    loss_options: Loss options as defined by
      config_pb2.LossOptions.SingleLossOptions or None.
    gt_key: A key to extract the ground-truth from a dictionary.
    pred_key: A key to extract the prediction from a dictionary.
    weight_key: A key to extract the per-pixel weights from a dictionary.
    **kwargs: Additional parameters to initialize the loss. They are only
      forwarded to the 'softmax_cross_entropy' loss; the 'l1' and 'mse' losses
      do not receive them.

  Returns:
    A tuple of the loss instance and its corresponding weight, i.e., the value
    of loss_options.weight. When loss_options is None, the tuple (None, 0) is
    returned instead.

  Raises:
    ValueError: An error occurs when the loss name is not a valid loss.
  """
  if loss_options is None:
    return None, 0

  loss_name = loss_options.name
  if loss_name == 'softmax_cross_entropy':
    loss = base_loss.TopKCrossEntropyLoss(
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent,
        **kwargs)
  elif loss_name == 'l1':
    loss = base_loss.TopKGeneralLoss(
        base_loss.mean_absolute_error,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent)
  elif loss_name == 'mse':
    loss = base_loss.TopKGeneralLoss(
        base_loss.mean_squared_error,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent)
  else:
    raise ValueError('Loss %s is not a valid loss.' % loss_options.name)
  return loss, loss_options.weight
class DeepLabFamilyLoss(tf.keras.layers.Layer):
  """This class contains code to build and call losses for DeepLabFamilyLoss."""

  def __init__(
      self,
      loss_options: config_pb2.LossOptions,
      num_classes: Optional[int],
      ignore_label: Optional[int],
      thing_class_ids: Tuple[int, ...]):
    """Initializes the losses for Panoptic-DeepLab.

    Only the losses whose field is set in loss_options are created.

    Args:
      loss_options: Loss options as defined by config_pb2.LossOptions.
      num_classes: An integer specifying the number of classes in the dataset.
      ignore_label: An optional integer specifying the ignore label or None.
      thing_class_ids: A tuple of length [N] containing N thing indices.
    """
    super().__init__(name='DeepLabFamilyLoss')

    # Single-term losses are losses that have only one loss term and thus each
    # loss function directly returns a single tensor as the loss value, as
    # opposed to multi-term losses that involve multiple terms and return a
    # dictionary of loss values.
    self._single_term_loss_func_and_weight_dict = collections.OrderedDict()
    # Names returned by self.call() in addition to the single-term losses.
    self._extra_loss_names = [common.TOTAL_LOSS]

    if loss_options.HasField(common.SEMANTIC_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.SEMANTIC_LOSS] = _create_loss_and_weight(
              loss_options.semantic_loss,
              common.GT_SEMANTIC_KEY,
              common.PRED_SEMANTIC_LOGITS_KEY,
              common.SEMANTIC_LOSS_WEIGHT_KEY,
              num_classes=num_classes,
              ignore_label=ignore_label)

    if loss_options.HasField(common.CENTER_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.CENTER_LOSS] = _create_loss_and_weight(
              loss_options.center_loss,
              common.GT_INSTANCE_CENTER_KEY,
              common.PRED_CENTER_HEATMAP_KEY,
              common.CENTER_LOSS_WEIGHT_KEY)

    if loss_options.HasField(common.REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.regression_loss,
              common.GT_INSTANCE_REGRESSION_KEY,
              common.PRED_OFFSET_MAP_KEY,
              common.REGRESSION_LOSS_WEIGHT_KEY)

    # Currently, only used for Motion-DeepLab.
    if loss_options.HasField(common.MOTION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.MOTION_LOSS] = _create_loss_and_weight(
              loss_options.motion_loss,
              common.GT_FRAME_OFFSET_KEY,
              common.PRED_FRAME_OFFSET_MAP_KEY,
              common.FRAME_REGRESSION_LOSS_WEIGHT_KEY)

    # Next-frame regression loss used in ViP-DeepLab.
    if loss_options.HasField(common.NEXT_REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.NEXT_REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.next_regression_loss,
              common.GT_NEXT_INSTANCE_REGRESSION_KEY,
              common.PRED_NEXT_OFFSET_MAP_KEY,
              common.NEXT_REGRESSION_LOSS_WEIGHT_KEY)

    # Multi-term losses that return dictionaries of loss terms.
    self._multi_term_losses = []

    # MaXDeepLabLoss optionally returns four loss terms in total:
    # - common.PQ_STYLE_LOSS_CLASS_TERM
    # - common.PQ_STYLE_LOSS_MASK_DICE_TERM
    # - common.MASK_ID_CROSS_ENTROPY_LOSS
    # - common.INSTANCE_DISCRIMINATION_LOSS
    max_deeplab_fields = ('pq_style_loss',
                          'mask_id_cross_entropy_loss',
                          'instance_discrimination_loss')
    if any(loss_options.HasField(field) for field in max_deeplab_fields):
      self._multi_term_losses.append(max_deeplab_loss.MaXDeepLabLoss(
          loss_options, ignore_label, thing_class_ids))

    for multi_term_loss in self._multi_term_losses:
      self._extra_loss_names += multi_term_loss.loss_terms

  def get_loss_names(self):
    """Returns the names of all losses that self.call() will return.

    Returns:
      A list with the names of the configured single-term losses, followed by
      common.TOTAL_LOSS and the loss terms of all multi-term losses.
    """
    # Keep track of all the keys that will be returned in self.call().
    loss_names = list(self._single_term_loss_func_and_weight_dict)
    return loss_names + self._extra_loss_names

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Performs the loss computations given ground-truth and predictions.

    The loss is computed for each sample separately. Currently, smoothed
    ground-truth labels are not supported.

    Args:
      y_true: A dictionary of tf.Tensor containing all ground-truth data to
        compute the loss. Depending on the configuration, the dict has to
        contain common.GT_SEMANTIC_KEY, and optionally
        common.GT_INSTANCE_CENTER_KEY, common.GT_INSTANCE_REGRESSION_KEY,
        common.GT_FRAME_OFFSET_KEY, and
        common.GT_NEXT_INSTANCE_REGRESSION_KEY.
      y_pred: A dictionary of tf.Tensor containing all predictions to compute
        the loss. Depending on the configuration, the dict has to contain
        common.PRED_SEMANTIC_LOGITS_KEY, and optionally
        common.PRED_CENTER_HEATMAP_KEY, common.PRED_OFFSET_MAP_KEY,
        common.PRED_FRAME_OFFSET_MAP_KEY, and common.PRED_NEXT_OFFSET_MAP_KEY.

    Returns:
      The loss as a dict of tf.Tensor, optionally containing the following:
      - common.SEMANTIC_LOSS: [batch].
      - common.CENTER_LOSS: [batch].
      - common.REGRESSION_LOSS: [batch].
      - common.MOTION_LOSS: [batch], the frame offset regression loss.
      - common.NEXT_REGRESSION_LOSS: [batch], the next regression loss.
      - The loss terms returned by the configured multi-term losses.
      In addition, the dict always contains common.TOTAL_LOSS, the element-wise
      sum of all other entries. The single-term losses are already multiplied
      by their configured weights.

    Raises:
      AssertionError: If the keys of the resulting_dict do not match
        self.get_loss_names().
      AssertionError: The keys of the resulting_dict overlap with the keys of
        the loss_dict.
    """
    resulting_dict = collections.OrderedDict()

    # Single-term losses.
    for loss_name, (loss_func, loss_weight) in (
        self._single_term_loss_func_and_weight_dict.items()):
      resulting_dict[loss_name] = loss_func(y_true, y_pred) * loss_weight

    # Multi-term losses predict a dictionary, so we handle them differently.
    for multi_term_loss in self._multi_term_losses:
      loss_dict = multi_term_loss((y_true, y_pred))
      if not set(loss_dict).isdisjoint(resulting_dict):
        raise AssertionError('The keys of the resulting_dict overlap with the '
                             'keys of the loss_dict.')
      resulting_dict.update(loss_dict)

    # Also include the total loss in the resulting_dict. tf.math.add_n is used
    # since tf.math.accumulate_n is deprecated in favor of it.
    total_loss = tf.math.add_n(list(resulting_dict.values()))
    resulting_dict[common.TOTAL_LOSS] = total_loss

    if sorted(resulting_dict) != sorted(self.get_loss_names()):
      raise AssertionError(
          'The keys of the resulting_dict should match self.get_loss_names().')
    return resulting_dict
|