from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmdet.registry import MODELS
from .accuracy import accuracy
from .cross_entropy_loss import cross_entropy
from .utils import weight_reduce_loss
|
|
def seesaw_ce_loss(cls_score: Tensor,
                   labels: Tensor,
                   label_weights: Tensor,
                   cum_samples: Tensor,
                   num_classes: int,
                   p: float,
                   q: float,
                   eps: float,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:
    """Calculate the Seesaw CrossEntropy loss.

    Args:
        cls_score (Tensor): The prediction with shape (N, C),
            C is the number of classes.
        labels (Tensor): The learning label of the prediction.
        label_weights (Tensor): Sample-wise loss weight.
        cum_samples (Tensor): Cumulative samples for each category.
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor.
        q (float): The ``q`` in the compensation factor.
        eps (float): The minimal value of divisor to smooth
            the computation of the compensation factor.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        Tensor: The calculated loss.
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor: down-weight the penalty on negative classes that
    # have fewer cumulative samples than the ground-truth class
    if p > 0:
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
        mitigation_factor = sample_weights[labels.long(), :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor: re-amplify the penalty on negative classes whose
    # predicted score exceeds that of the ground-truth class
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(0, len(scores)).to(scores.device).long(),
            labels.long()]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

    # apply the seesaw weights to the logits of the negative classes only
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if label_weights is not None:
        label_weights = label_weights.float()
    loss = weight_reduce_loss(
        loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
    return loss
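
# A minimal usage sketch (illustrative only; the batch size, class count and
# long-tailed sample counts below are hypothetical, not part of the mmdet
# API):
#
#     cls_score = torch.randn(8, 4)                     # (N, C) logits
#     labels = torch.randint(0, 4, (8,))                # ground-truth classes
#     label_weights = cls_score.new_ones(8)             # sample-wise weights
#     cum_samples = torch.tensor([1000., 100., 10., 1.])
#     loss = seesaw_ce_loss(cls_score, labels, label_weights, cum_samples,
#                           num_classes=4, p=0.8, q=2.0, eps=1e-2)
#
# In the paper's notation, for a sample of class i the weight applied to
# negative class j is S_ij = M_ij * C_ij, with
#     M_ij = (N_j / N_i)^p  if N_j < N_i else 1   (mitigation factor)
#     C_ij = (s_j / s_i)^q  if s_j > s_i else 1   (compensation factor)
# where N_* are cumulative sample counts and s_* are detached softmax scores.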
|
|
@MODELS.register_module()
class SeesawLoss(nn.Module):
    """Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021).

    arXiv: https://arxiv.org/abs/2008.10032

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Only False is supported.
        p (float, optional): The ``p`` in the mitigation factor.
            Defaults to 0.8.
        q (float, optional): The ``q`` in the compensation factor.
            Defaults to 2.0.
        num_classes (int, optional): The number of classes.
            Defaults to 1203 for the LVIS v1 dataset.
        eps (float, optional): The minimal value of divisor to smooth
            the computation of the compensation factor. Defaults to 1e-2.
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of the loss. Defaults to 1.0.
        return_dict (bool, optional): Whether to return the losses as a dict.
            Defaults to True.
    """

    def __init__(self,
                 use_sigmoid: bool = False,
                 p: float = 0.8,
                 q: float = 2.0,
                 num_classes: int = 1203,
                 eps: float = 1e-2,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 return_dict: bool = True) -> None:
        super().__init__()
        assert not use_sigmoid
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_dict = return_dict

        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category (the extra slot is for the
        # background class)
        self.register_buffer(
            'cum_samples',
            torch.zeros(self.num_classes + 1, dtype=torch.float))

        # custom output channels of the classifier
        self.custom_cls_channels = True
        # custom activation of cls_score
        self.custom_activation = True
        # custom accuracy of the classifier
        self.custom_accuracy = True

    def _split_cls_score(self, cls_score: Tensor) -> Tuple[Tensor, Tensor]:
        """Split cls_score.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).

        Returns:
            Tuple[Tensor, Tensor]: The score for classes and objectness,
            respectively.
        """
        assert cls_score.size(-1) == self.num_classes + 2
        cls_score_classes = cls_score[..., :-2]
        cls_score_objectness = cls_score[..., -2:]
        return cls_score_classes, cls_score_objectness
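
    # For example (hypothetical values): with num_classes == 3, a row of
    # cls_score laid out as [c0, c1, c2, obj_pos, obj_neg] is split into the
    # class logits [c0, c1, c2] and the objectness logits [obj_pos, obj_neg].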
|
    def get_cls_channels(self, num_classes: int) -> int:
        """Get custom classification channels.

        Args:
            num_classes (int): The number of classes.

        Returns:
            int: The custom classification channels.
        """
        assert num_classes == self.num_classes
        return num_classes + 2
|
    def get_activation(self, cls_score: Tensor) -> Tensor:
        """Get custom activation of cls_score.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).

        Returns:
            Tensor: The custom activation of cls_score with shape
            (N, C + 1).
        """
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        score_classes = F.softmax(cls_score_classes, dim=-1)
        score_objectness = F.softmax(cls_score_objectness, dim=-1)
        score_pos = score_objectness[..., [0]]
        score_neg = score_objectness[..., [1]]
        # scale the class probabilities by the foreground probability and
        # append the background probability as the last channel
        score_classes = score_classes * score_pos
        scores = torch.cat([score_classes, score_neg], dim=-1)
        return scores
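
    # The activation factors the posterior as
    # P(class j) = P(class j | foreground) * P(foreground), with P(background)
    # appended as the last channel, so each output row sums to 1. A
    # hypothetical shape check (illustrative only):
    #
    #     loss_fn = SeesawLoss(num_classes=3)
    #     scores = loss_fn.get_activation(torch.randn(8, 5))  # (N, C + 2) in
    #     assert scores.shape == (8, 4)                       # (N, C + 1) out
    #     assert torch.allclose(scores.sum(-1), torch.ones(8))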
|
    def get_accuracy(self, cls_score: Tensor,
                     labels: Tensor) -> Dict[str, Tensor]:
        """Get custom accuracy w.r.t. cls_score and labels.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).
            labels (Tensor): The learning label of the prediction.

        Returns:
            Dict[str, Tensor]: The accuracy for objectness and classes,
            respectively.
        """
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        acc_objectness = accuracy(cls_score_objectness, obj_labels)
        acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_objectness'] = acc_objectness
        acc['acc_classes'] = acc_classes
        return acc
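
    # Hypothetical call (illustrative only), reusing loss_fn from the sketch
    # above:
    #
    #     acc = loss_fn.get_accuracy(torch.randn(8, 5),
    #                                torch.randint(0, 4, (8,)))
    #     # -> {'acc_objectness': ..., 'acc_classes': ...}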
|
    def forward(
        self,
        cls_score: Tensor,
        labels: Tensor,
        label_weights: Optional[Tensor] = None,
        avg_factor: Optional[int] = None,
        reduction_override: Optional[str] = None
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Forward function.

        Args:
            cls_score (Tensor): The prediction with shape (N, C + 2).
            labels (Tensor): The learning label of the prediction.
            label_weights (Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to override
                the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            Tensor | Dict[str, Tensor]:
                If return_dict == False, the calculated loss;
                if return_dict == True, the dict of calculated losses
                for objectness and classes, respectively.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes + 2
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()

        # accumulate the samples for each category
        unique_labels = labels.unique()
        for u_l in unique_labels:
            inds_ = labels == u_l.item()
            self.cum_samples[u_l] += inds_.sum()

        if label_weights is not None:
            label_weights = label_weights.float()
        else:
            label_weights = labels.new_ones(labels.size(), dtype=torch.float)

        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        # calculate loss_cls_classes (only need pos samples)
        if pos_inds.sum() > 0:
            loss_cls_classes = self.loss_weight * self.cls_criterion(
                cls_score_classes[pos_inds], labels[pos_inds],
                label_weights[pos_inds], self.cum_samples[:self.num_classes],
                self.num_classes, self.p, self.q, self.eps, reduction,
                avg_factor)
        else:
            loss_cls_classes = cls_score_classes[pos_inds].sum()
        # calculate loss_cls_objectness
        loss_cls_objectness = self.loss_weight * cross_entropy(
            cls_score_objectness, obj_labels, label_weights, reduction,
            avg_factor)

        if self.return_dict:
            loss_cls = dict()
            loss_cls['loss_cls_objectness'] = loss_cls_objectness
            loss_cls['loss_cls_classes'] = loss_cls_classes
        else:
            loss_cls = loss_cls_classes + loss_cls_objectness
        return loss_cls
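
# A minimal end-to-end sketch (illustrative only; the shapes and class count
# are hypothetical). Labels in [0, num_classes) are foreground and the value
# num_classes denotes background; logits have shape (N, num_classes + 2):
#
#     loss_fn = SeesawLoss(num_classes=3)
#     cls_score = torch.randn(8, 5)         # (N, C + 2)
#     labels = torch.randint(0, 4, (8,))    # 3 == background
#     losses = loss_fn(cls_score, labels)
#     # -> {'loss_cls_objectness': ..., 'loss_cls_classes': ...}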