|
|
|
import warnings |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from mmdet.registry import MODELS |
|
from .utils import weight_reduce_loss |
|
|
|
|
|
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Compute the softmax cross-entropy loss.

    The per-element loss is obtained from ``F.cross_entropy`` with
    ``reduction='none'``; sample weighting and the final reduction are
    delegated to ``weight_reduce_loss``.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is
            cast to float before being applied.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class,
            forwarded as ``weight`` to ``F.cross_entropy``.
        ignore_index (int | None): The label index to be ignored.
            ``None`` is replaced by -100. Default: -100.
        avg_non_ignore (bool): If True, ``reduction`` is ``'mean'`` and no
            ``avg_factor`` is given, the average factor becomes the number
            of labels that differ from ``ignore_index``. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # ``None`` means "use the default", which is -100.
    if ignore_index is None:
        ignore_index = -100

    # Unreduced loss; the reduction happens in weight_reduce_loss below.
    per_element_loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # Average only over the non-ignored labels when requested and no
    # explicit factor was supplied by the caller.
    if avg_factor is None and avg_non_ignore and reduction == 'mean':
        num_ignored = (label == ignore_index).sum().item()
        avg_factor = label.numel() - num_ignored

    sample_weight = None if weight is None else weight.float()
    return weight_reduce_loss(
        per_element_loss,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
|
|
|
|
|
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index): |
|
"""Expand onehot labels to match the size of prediction.""" |
|
bin_labels = labels.new_full((labels.size(0), label_channels), 0) |
|
valid_mask = (labels >= 0) & (labels != ignore_index) |
|
inds = torch.nonzero( |
|
valid_mask & (labels < label_channels), as_tuple=False) |
|
|
|
if inds.numel() > 0: |
|
bin_labels[inds, labels[inds]] = 1 |
|
|
|
valid_mask = valid_mask.view(-1, 1).expand(labels.size(0), |
|
label_channels).float() |
|
if label_weights is None: |
|
bin_label_weights = valid_mask |
|
else: |
|
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels) |
|
bin_label_weights *= valid_mask |
|
|
|
return bin_labels, bin_label_weights, valid_mask |
|
|
|
|
|
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False):
    """Compute the binary (sigmoid) cross-entropy loss on logits.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
            When the number of dimensions of ``pred`` differs from that of
            ``label``, ``label`` is expanded to one-hot format; otherwise
            it is used as given.
        label (torch.Tensor): The learning label of the prediction,
            with shape (N, ).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class,
            forwarded as ``pos_weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (int | None): The label index to be ignored.
            ``None`` is replaced by -100. Default: -100.
        avg_non_ignore (bool): If True, ``reduction`` is ``'mean'`` and no
            ``avg_factor`` is given, the average factor becomes the sum of
            the validity mask. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # ``None`` means "use the default", which is -100.
    if ignore_index is None:
        ignore_index = -100

    if pred.dim() == label.dim():
        # Label already has the rank of the prediction: only build the mask
        # that zeroes out negative / ignored entries.
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        # Out-of-place product on purpose: an in-place write would fail
        # with a broadcast-shape error when weight and mask shapes differ,
        # e.g. (B, N, 1) versus (B, N, C).
        weight = valid_mask if weight is None else weight * valid_mask
    else:
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.size(-1), ignore_index)

    # Average only over the non-ignored elements when requested and no
    # explicit factor was supplied by the caller.
    if avg_factor is None and avg_non_ignore and reduction == 'mean':
        avg_factor = valid_mask.sum().item()

    # Unreduced loss; weighting and reduction happen in weight_reduce_loss.
    elementwise_loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')

    return weight_reduce_loss(
        elementwise_loss,
        weight.float(),
        reduction=reduction,
        avg_factor=avg_factor)
|
|
|
|
|
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    For every sample the logits of its own class (``pred[i, label[i]]``) are
    selected and compared against ``target`` with
    ``F.binary_cross_entropy_with_logits`` using mean reduction.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C, *), C is the
            number of classes. The trailing * indicates arbitrary shape.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask
            corresponding object. This will be used to select the mask of
            the class which the object belongs to when the mask prediction
            is not class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Only "mean" is accepted by this function.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Must be None. Defaults to None.
        class_weight (list[float], optional): The weight for each class,
            forwarded as ``weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Must be None. Default: None.
        **kwargs: Accepted for interface compatibility and not used.

    Returns:
        torch.Tensor: The calculated loss, with shape (1, ).

    Raises:
        AssertionError: If ``ignore_index`` is not None, ``reduction`` is
            not "mean", or ``avg_factor`` is not None.

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> reduction = 'mean'
        >>> avg_factor = None
        >>> class_weights = None
        >>> loss = mask_cross_entropy(pred, target, label, reduction,
        >>>                           avg_factor, class_weights)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # Only the plain mean over all mask elements is implemented here.
    assert reduction == 'mean' and avg_factor is None

    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    # Pick, for each sample, the logits of its own class; this already drops
    # the class dimension.
    pred_slice = pred[inds, label]
    # Drop a leftover singleton dim 1 only when the selected logits have a
    # different rank than the target. Squeezing unconditionally would also
    # remove a genuine size-1 spatial dim (e.g. H == 1) and make the shapes
    # disagree with ``target``.
    if pred_slice.dim() != target.dim():
        pred_slice = pred_slice.squeeze(1)

    # ``[None]`` turns the scalar mean into a tensor of shape (1, ).
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]
|
|
|
|
|
@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """Configurable cross-entropy loss module.

    Depending on the constructor flags the module dispatches to
    ``binary_cross_entropy``, ``mask_cross_entropy`` or ``cross_entropy``.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """Initialize the loss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                or softmax. Defaults to False.
            use_mask (bool, optional): Whether to use mask cross entropy
                loss. Defaults to False.
            reduction (str, optional): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Weight of the loss.
                Defaults to 1.0.
            avg_non_ignore (bool): The flag decides whether the loss is
                only averaged over non-ignored targets. Default: False.
        """
        super().__init__()
        # The sigmoid and mask variants are mutually exclusive.
        assert (use_sigmoid is False) or (use_mask is False)

        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore

        # Point out that, with an ignore index and mean reduction, the
        # default averaging still counts the ignored labels.
        averages_over_ignored = (
            not self.avg_non_ignore and self.reduction == 'mean')
        if ignore_index is not None and averages_over_ignored:
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        # Select the functional criterion; sigmoid is checked before mask.
        if self.use_sigmoid:
            criterion = binary_cross_entropy
        elif self.use_mask:
            criterion = mask_cross_entropy
        else:
            criterion = cross_entropy
        self.cls_criterion = criterion

    def extra_repr(self):
        """Return the extra information shown in the module repr."""
        return f'avg_non_ignore={self.avg_non_ignore}'

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Compute the weighted loss.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce
                the loss. Options are "none", "mean" and "sum". When not
                given, the reduction set at construction is used.
            ignore_index (int | None): The label index to be ignored.
                If not None, it will override the default value.
                Default: None.
            **kwargs: Extra keyword arguments forwarded to the criterion.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # A per-call ignore_index takes precedence over the configured one.
        if ignore_index is None:
            ignore_index = self.ignore_index

        # Materialise the configured class weights as a tensor created from
        # cls_score, on cls_score's device.
        class_weight = None
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(
                self.class_weight, device=cls_score.device)

        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            ignore_index=ignore_index,
            avg_non_ignore=self.avg_non_ignore,
            **kwargs)
        loss_cls = self.loss_weight * raw_loss
        return loss_cls
|
|