# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains loss builder classes used in the DeepLab model.""" | |
import collections | |
from typing import Any, Dict, Text, Tuple, Optional | |
import tensorflow as tf | |
from deeplab2 import common | |
from deeplab2 import config_pb2 | |
from deeplab2.model.loss import base_loss | |
from deeplab2.model.loss import max_deeplab_loss | |


def _create_loss_and_weight(
    loss_options: config_pb2.LossOptions.SingleLossOptions, gt_key: Text,
    pred_key: Text, weight_key: Text,
    **kwargs: Any) -> Tuple[Optional[tf.keras.losses.Loss], float]:
  """Creates a loss and its weight from loss options.

  Args:
    loss_options: Loss options as defined by
      config_pb2.LossOptions.SingleLossOptions or None.
    gt_key: A key to extract the ground-truth from a dictionary.
    pred_key: A key to extract the prediction from a dictionary.
    weight_key: A key to extract the per-pixel weights from a dictionary.
    **kwargs: Additional parameters to initialize the loss.

  Returns:
    A tuple of an instance of tf.keras.losses.Loss (or None if loss_options is
    None) and its corresponding scalar weight.

  Raises:
    ValueError: An error occurs when the loss name is not a valid loss.
  """
  if loss_options is None:
    return None, 0
  if loss_options.name == 'softmax_cross_entropy':
    return base_loss.TopKCrossEntropyLoss(
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent,
        **kwargs), loss_options.weight
  elif loss_options.name == 'l1':
    return base_loss.TopKGeneralLoss(
        base_loss.mean_absolute_error,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent), loss_options.weight
  elif loss_options.name == 'mse':
    return base_loss.TopKGeneralLoss(
        base_loss.mean_squared_error,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent), loss_options.weight
  raise ValueError('Loss %s is not a valid loss.' % loss_options.name)
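

# Example (illustrative sketch, not part of the original file): building a
# single weighted semantic loss term from a SingleLossOptions message. The
# literal values (num_classes=19, ignore_label=255, top_k_percent=0.2) are
# hypothetical; the field and key names are the ones used above.
#
#   semantic_options = config_pb2.LossOptions.SingleLossOptions(
#       name='softmax_cross_entropy', weight=1.0, top_k_percent=0.2)
#   loss_fn, loss_weight = _create_loss_and_weight(
#       semantic_options, common.GT_SEMANTIC_KEY,
#       common.PRED_SEMANTIC_LOGITS_KEY, common.SEMANTIC_LOSS_WEIGHT_KEY,
#       num_classes=19, ignore_label=255)
#   # loss_fn(y_true, y_pred) returns a per-sample loss tensor, which
#   # DeepLabFamilyLoss below multiplies by loss_weight.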


class DeepLabFamilyLoss(tf.keras.layers.Layer):
  """This class contains code to build and call losses for DeepLab models."""

  def __init__(
      self,
      loss_options: config_pb2.LossOptions,
      num_classes: Optional[int],
      ignore_label: Optional[int],
      thing_class_ids: Tuple[int, ...]):
    """Initializes the losses for the DeepLab model family.

    Args:
      loss_options: Loss options as defined by config_pb2.LossOptions.
      num_classes: An integer specifying the number of classes in the dataset.
      ignore_label: An optional integer specifying the ignore label or None.
      thing_class_ids: A tuple of length [N] containing N thing indices.
    """
    super(DeepLabFamilyLoss, self).__init__(name='DeepLabFamilyLoss')

    # Single-term losses are losses that have only one loss term and thus each
    # loss function directly returns a single tensor as the loss value, as
    # opposed to multi-term losses that involve multiple terms and return a
    # dictionary of loss values.
    self._single_term_loss_func_and_weight_dict = collections.OrderedDict()
    self._extra_loss_names = [common.TOTAL_LOSS]

    if loss_options.HasField(common.SEMANTIC_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.SEMANTIC_LOSS] = _create_loss_and_weight(
              loss_options.semantic_loss,
              common.GT_SEMANTIC_KEY,
              common.PRED_SEMANTIC_LOGITS_KEY,
              common.SEMANTIC_LOSS_WEIGHT_KEY,
              num_classes=num_classes,
              ignore_label=ignore_label)

    if loss_options.HasField(common.CENTER_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.CENTER_LOSS] = _create_loss_and_weight(
              loss_options.center_loss, common.GT_INSTANCE_CENTER_KEY,
              common.PRED_CENTER_HEATMAP_KEY, common.CENTER_LOSS_WEIGHT_KEY)

    if loss_options.HasField(common.REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.regression_loss, common.GT_INSTANCE_REGRESSION_KEY,
              common.PRED_OFFSET_MAP_KEY, common.REGRESSION_LOSS_WEIGHT_KEY)

    # Currently, only used for Motion-DeepLab.
    if loss_options.HasField(common.MOTION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.MOTION_LOSS] = _create_loss_and_weight(
              loss_options.motion_loss, common.GT_FRAME_OFFSET_KEY,
              common.PRED_FRAME_OFFSET_MAP_KEY,
              common.FRAME_REGRESSION_LOSS_WEIGHT_KEY)

    # Next-frame regression loss used in ViP-DeepLab.
    if loss_options.HasField(common.NEXT_REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.NEXT_REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.next_regression_loss,
              common.GT_NEXT_INSTANCE_REGRESSION_KEY,
              common.PRED_NEXT_OFFSET_MAP_KEY,
              common.NEXT_REGRESSION_LOSS_WEIGHT_KEY)

    # Multi-term losses that return dictionaries of loss terms.
    self._multi_term_losses = []
    # MaXDeepLabLoss optionally returns four loss terms in total:
    # - common.PQ_STYLE_LOSS_CLASS_TERM
    # - common.PQ_STYLE_LOSS_MASK_DICE_TERM
    # - common.MASK_ID_CROSS_ENTROPY_LOSS
    # - common.INSTANCE_DISCRIMINATION_LOSS
    if any([loss_options.HasField('pq_style_loss'),
            loss_options.HasField('mask_id_cross_entropy_loss'),
            loss_options.HasField('instance_discrimination_loss')]):
      self._multi_term_losses.append(max_deeplab_loss.MaXDeepLabLoss(
          loss_options, ignore_label, thing_class_ids))

    for multi_term_loss in self._multi_term_losses:
      self._extra_loss_names += multi_term_loss.loss_terms
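
  # Example (illustrative sketch, not part of the original file): a
  # config_pb2.LossOptions message that enables the three Panoptic-DeepLab
  # single-term losses above. The weight values are hypothetical; the field
  # names match the attributes accessed in __init__.
  #
  #   loss_options = config_pb2.LossOptions(
  #       semantic_loss=config_pb2.LossOptions.SingleLossOptions(
  #           name='softmax_cross_entropy', weight=1.0),
  #       center_loss=config_pb2.LossOptions.SingleLossOptions(
  #           name='mse', weight=200.0),
  #       regression_loss=config_pb2.LossOptions.SingleLossOptions(
  #           name='l1', weight=0.01))
  #
  # Setting pq_style_loss, mask_id_cross_entropy_loss, or
  # instance_discrimination_loss instead adds the multi-term MaXDeepLabLoss.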

  def get_loss_names(self):
    # Keep track of all the keys that will be returned in self.call().
    loss_names = list(self._single_term_loss_func_and_weight_dict.keys())
    return loss_names + self._extra_loss_names

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Performs the loss computations given ground-truth and predictions.

    The loss is computed for each sample separately. Currently, smoothed
    ground-truth labels are not supported.

    Args:
      y_true: A dictionary of tf.Tensor containing all ground-truth data to
        compute the loss. Depending on the configuration, the dict has to
        contain common.GT_SEMANTIC_KEY, and optionally
        common.GT_INSTANCE_CENTER_KEY, common.GT_INSTANCE_REGRESSION_KEY, and
        common.GT_FRAME_OFFSET_KEY.
      y_pred: A dictionary of tf.Tensor containing all predictions to compute
        the loss. Depending on the configuration, the dict has to contain
        common.PRED_SEMANTIC_LOGITS_KEY, and optionally
        common.PRED_CENTER_HEATMAP_KEY, common.PRED_OFFSET_MAP_KEY, and
        common.PRED_FRAME_OFFSET_MAP_KEY.

    Returns:
      The loss as a dict of tf.Tensor, optionally containing the following:
      - common.SEMANTIC_LOSS: [batch].
      - common.CENTER_LOSS: [batch].
      - common.REGRESSION_LOSS: [batch].
      - common.MOTION_LOSS: [batch], the frame offset regression loss.
      - common.NEXT_REGRESSION_LOSS: [batch], the next-frame regression loss.

    Raises:
      AssertionError: If the keys of the resulting_dict do not match
        self.get_loss_names().
      AssertionError: If the keys of the resulting_dict overlap with the keys
        of the loss_dict.
    """
    resulting_dict = collections.OrderedDict()

    # Single-term losses.
    for loss_name, func_and_weight in (
        self._single_term_loss_func_and_weight_dict.items()):
      loss_func, loss_weight = func_and_weight
      loss_value = loss_func(y_true, y_pred)
      resulting_dict[loss_name] = loss_value * loss_weight

    # Multi-term losses predict a dictionary, so we handle them differently.
    for multi_term_loss in self._multi_term_losses:
      loss_dict = multi_term_loss((y_true, y_pred))
      if not set(loss_dict).isdisjoint(resulting_dict):
        raise AssertionError('The keys of the resulting_dict overlap with the '
                             'keys of the loss_dict.')
      resulting_dict.update(loss_dict)

    # Also include the total loss in the resulting_dict.
    total_loss = tf.math.accumulate_n(list(resulting_dict.values()))
    resulting_dict[common.TOTAL_LOSS] = total_loss

    if sorted(resulting_dict.keys()) != sorted(self.get_loss_names()):
      raise AssertionError(
          'The keys of the resulting_dict should match self.get_loss_names().')
    return resulting_dict
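

# Example usage (illustrative sketch, not part of the original file): given a
# populated config_pb2.LossOptions message (e.g., as sketched in __init__
# above), the layer is constructed once and called on dictionaries of
# ground-truth and prediction tensors. The argument values and dict contents
# below are hypothetical.
#
#   loss_layer = DeepLabFamilyLoss(
#       loss_options, num_classes=19, ignore_label=255,
#       thing_class_ids=(11, 12, 13))
#   loss_dict = loss_layer(y_true_dict, y_pred_dict)
#   total_loss = loss_dict[common.TOTAL_LOSS]  # Per-sample tensor: [batch].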