karolmajek's picture
from https://huggingface.co/spaces/akhaliq/deeplab2
d1843be
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains loss builder classes used in the DeepLab model."""
import collections
from typing import Any, Dict, Text, Tuple, Optional
import tensorflow as tf
from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.model.loss import base_loss
from deeplab2.model.loss import max_deeplab_loss
def _create_loss_and_weight(
    loss_options: Optional[config_pb2.LossOptions.SingleLossOptions],
    gt_key: Text,
    pred_key: Text,
    weight_key: Text,
    **kwargs: Any) -> Tuple[Optional[tf.keras.losses.Loss], float]:
  """Creates a loss and its weight from loss options.

  Args:
    loss_options: Loss options as defined by
      config_pb2.LossOptions.SingleLossOptions or None.
    gt_key: A key to extract the ground-truth from a dictionary.
    pred_key: A key to extract the prediction from a dictionary.
    weight_key: A key to extract the per-pixel weights from a dictionary.
    **kwargs: Additional parameters to initialize the loss. They are only
      forwarded to the 'softmax_cross_entropy' loss; the 'l1' and 'mse' losses
      ignore them.

  Returns:
    A tuple of an instance of tf.keras.losses.Loss and its corresponding
    weight, taken from `loss_options.weight`. If `loss_options` is None,
    `(None, 0)` is returned.

  Raises:
    ValueError: An error occurs when the loss name is not a valid loss.
  """
  if loss_options is None:
    return None, 0

  loss_name = loss_options.name
  if loss_name == 'softmax_cross_entropy':
    loss = base_loss.TopKCrossEntropyLoss(
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent,
        **kwargs)
  elif loss_name in ('l1', 'mse'):
    # Both regression losses share TopKGeneralLoss and differ only in the
    # element-wise error function that is plugged into it.
    error_fn = (base_loss.mean_absolute_error if loss_name == 'l1'
                else base_loss.mean_squared_error)
    loss = base_loss.TopKGeneralLoss(
        error_fn,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent)
  else:
    raise ValueError('Loss %s is not a valid loss.' % loss_name)
  return loss, loss_options.weight
class DeepLabFamilyLoss(tf.keras.layers.Layer):
  """This class contains code to build and call losses for DeepLabFamilyLoss.

  Single-term losses (semantic, center, regression, motion and next-frame
  regression) are created from the corresponding fields of the loss options;
  each one yields a single tensor that is scaled by its configured weight.
  Multi-term losses (currently only MaXDeepLabLoss) yield a dictionary of loss
  terms. `call` additionally reports the total loss, i.e. the sum of all the
  other returned entries.
  """

  def __init__(
      self,
      loss_options: config_pb2.LossOptions,
      num_classes: Optional[int],
      ignore_label: Optional[int],
      thing_class_ids: Tuple[int, ...]):
    """Initializes the losses for the DeepLab family of models.

    Args:
      loss_options: Loss options as defined by config_pb2.LossOptions.
      num_classes: An integer specifying the number of classes in the dataset.
      ignore_label: An optional integer specifying the ignore label or None.
      thing_class_ids: A tuple of length [N] containing N thing indices.
    """
    super().__init__(name='DeepLabFamilyLoss')

    # Single-term losses are losses that have only one loss term and thus each
    # loss function directly returns a single tensor as the loss value, as
    # opposed to multi-term losses that involve multiple terms and return a
    # dictionary of loss values.
    self._single_term_loss_func_and_weight_dict = collections.OrderedDict()
    self._extra_loss_names = [common.TOTAL_LOSS]

    if loss_options.HasField(common.SEMANTIC_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.SEMANTIC_LOSS] = _create_loss_and_weight(
              loss_options.semantic_loss,
              common.GT_SEMANTIC_KEY,
              common.PRED_SEMANTIC_LOGITS_KEY,
              common.SEMANTIC_LOSS_WEIGHT_KEY,
              num_classes=num_classes,
              ignore_label=ignore_label)

    if loss_options.HasField(common.CENTER_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.CENTER_LOSS] = _create_loss_and_weight(
              loss_options.center_loss,
              common.GT_INSTANCE_CENTER_KEY,
              common.PRED_CENTER_HEATMAP_KEY,
              common.CENTER_LOSS_WEIGHT_KEY)

    if loss_options.HasField(common.REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.regression_loss,
              common.GT_INSTANCE_REGRESSION_KEY,
              common.PRED_OFFSET_MAP_KEY,
              common.REGRESSION_LOSS_WEIGHT_KEY)

    # Currently, only used for Motion-DeepLab.
    if loss_options.HasField(common.MOTION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.MOTION_LOSS] = _create_loss_and_weight(
              loss_options.motion_loss,
              common.GT_FRAME_OFFSET_KEY,
              common.PRED_FRAME_OFFSET_MAP_KEY,
              common.FRAME_REGRESSION_LOSS_WEIGHT_KEY)

    # Next-frame regression loss used in ViP-DeepLab.
    if loss_options.HasField(common.NEXT_REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.NEXT_REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.next_regression_loss,
              common.GT_NEXT_INSTANCE_REGRESSION_KEY,
              common.PRED_NEXT_OFFSET_MAP_KEY,
              common.NEXT_REGRESSION_LOSS_WEIGHT_KEY)

    # Multi-term losses that return dictionaries of loss terms.
    self._multi_term_losses = []

    # MaXDeepLabLoss optionally returns four loss terms in total:
    # - common.PQ_STYLE_LOSS_CLASS_TERM
    # - common.PQ_STYLE_LOSS_MASK_DICE_TERM
    # - common.MASK_ID_CROSS_ENTROPY_LOSS
    # - common.INSTANCE_DISCRIMINATION_LOSS
    if any([loss_options.HasField('pq_style_loss'),
            loss_options.HasField('mask_id_cross_entropy_loss'),
            loss_options.HasField('instance_discrimination_loss')]):
      self._multi_term_losses.append(max_deeplab_loss.MaXDeepLabLoss(
          loss_options, ignore_label, thing_class_ids))

    for multi_term_loss in self._multi_term_losses:
      self._extra_loss_names += multi_term_loss.loss_terms

  def get_loss_names(self):
    """Returns the names of all entries in the dict produced by `call`.

    The list has the configured single-term loss names first, then
    common.TOTAL_LOSS, then the loss terms of every multi-term loss.
    """
    # Keep track of all the keys that will be returned in self.call().
    loss_names = list(self._single_term_loss_func_and_weight_dict.keys())
    return loss_names + self._extra_loss_names

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Performs the loss computations given ground-truth and predictions.

    The loss is computed for each sample separately. Currently, smoothed
    ground-truth labels are not supported.

    Args:
      y_true: A dictionary of tf.Tensor containing all ground-truth data to
        compute the loss. Depending on the configuration, the dict has to
        contain common.GT_SEMANTIC_KEY, and optionally
        common.GT_INSTANCE_CENTER_KEY, common.GT_INSTANCE_REGRESSION_KEY, and
        common.GT_FRAME_OFFSET_KEY.
      y_pred: A dictionary of tf.Tensor containing all predictions to compute
        the loss. Depending on the configuration, the dict has to contain
        common.PRED_SEMANTIC_LOGITS_KEY, and optionally
        common.PRED_CENTER_HEATMAP_KEY, common.PRED_OFFSET_MAP_KEY, and
        common.PRED_FRAME_OFFSET_MAP_KEY.

    Returns:
      The loss as a dict of tf.Tensor, optionally containing the following:
      - common.SEMANTIC_LOSS: [batch].
      - common.CENTER_LOSS: [batch].
      - common.REGRESSION_LOSS: [batch].
      - common.MOTION_LOSS: [batch], the frame offset regression loss.
      - common.NEXT_REGRESSION_LOSS: [batch], the next regression loss.
      - The loss terms returned by the multi-term losses (e.g. MaXDeepLabLoss).
      It always contains common.TOTAL_LOSS, the element-wise sum of all the
      other entries.

    Raises:
      AssertionError: If the keys of the resulting_dict do not match
        self.get_loss_names().
      AssertionError: The keys of the resulting_dict overlap with the keys of
        the loss_dict.
      ValueError: If no loss is configured, so there is nothing to sum into
        the total loss.
    """
    resulting_dict = collections.OrderedDict()

    # Single-term losses: each returns one tensor, scaled by its weight.
    for loss_name, (loss_func, loss_weight) in (
        self._single_term_loss_func_and_weight_dict.items()):
      resulting_dict[loss_name] = loss_func(y_true, y_pred) * loss_weight

    # Multi-term losses predict a dictionary, so we handle them differently.
    for multi_term_loss in self._multi_term_losses:
      loss_dict = multi_term_loss((y_true, y_pred))
      if not set(loss_dict).isdisjoint(resulting_dict):
        raise AssertionError('The keys of the resulting_dict overlap with the '
                             'keys of the loss_dict.')
      resulting_dict.update(loss_dict)

    # tf.math.add_n requires at least one input; fail with a clear message
    # instead of an opaque error raised from inside TensorFlow.
    if not resulting_dict:
      raise ValueError('No loss is configured, so the total loss cannot be '
                       'computed.')

    # Also include the total loss in the resulting_dict. tf.math.add_n replaces
    # the deprecated tf.math.accumulate_n; both require all inputs to share the
    # same shape and dtype.
    resulting_dict[common.TOTAL_LOSS] = tf.math.add_n(
        list(resulting_dict.values()))

    if sorted(resulting_dict.keys()) != sorted(self.get_loss_names()):
      raise AssertionError(
          'The keys of the resulting_dict should match self.get_loss_names().')
    return resulting_dict