|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from mmdet.registry import MODELS |
|
from .utils import weight_reduce_loss |
|
|
|
|
|
def dice_loss(pred,
              target,
              weight=None,
              eps=1e-3,
              reduction='mean',
              naive_dice=False,
              avg_factor=None):
    """Compute the dice loss between a prediction and its label.

    Two variants are supported:

    - the one proposed in `V-Net: Fully Convolutional Neural
      Networks for Volumetric Medical Image Segmentation
      <https://arxiv.org/abs/1606.04797>`_, whose denominator sums the
      squared entries of both tensors;
    - a "naive" variant whose denominator sums the entries themselves
      (first power instead of second power).

    Args:
        pred (torch.Tensor): The prediction, has a shape (n, *).
        target (torch.Tensor): The learning label of the prediction,
            shape (n, *), same shape of pred. It is cast to float
            before use.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, has a shape (n,). Defaults to None.
        eps (float): Avoid dividing by zero. Defaults to 1e-3.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        naive_dice (bool, optional): If False, use the dice loss defined
            in the V-Net paper, otherwise use the naive dice loss.
            Defaults to False.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The per-sample losses after being passed through
        ``weight_reduce_loss`` with ``weight``, ``reduction`` and
        ``avg_factor``.
    """
    # Collapse every dimension after the first so each row is one sample.
    flat_pred = pred.flatten(1)
    flat_target = target.flatten(1).float()

    overlap = (flat_pred * flat_target).sum(1)
    if naive_dice:
        # First-power denominator; eps is added to both numerator and
        # denominator.
        pred_total = flat_pred.sum(1)
        target_total = flat_target.sum(1)
        dice = (2 * overlap + eps) / (pred_total + target_total + eps)
    else:
        # V-Net form: squared terms, eps added to each denominator term.
        pred_total = (flat_pred * flat_pred).sum(1) + eps
        target_total = (flat_target * flat_target).sum(1) + eps
        dice = (2 * overlap) / (pred_total + target_total)

    per_sample_loss = 1 - dice
    if weight is not None:
        # The weight must line up one-to-one with the per-sample losses.
        assert weight.ndim == per_sample_loss.ndim
        assert len(weight) == len(pred)
    return weight_reduce_loss(per_sample_loss, weight, reduction, avg_factor)
|
|
|
|
|
@MODELS.register_module()
class DiceLoss(nn.Module):
    """Dice loss module.

    Stores the loss configuration and, on ``forward``, optionally applies
    a sigmoid to the prediction before delegating to :func:`dice_loss`
    and scaling the result by ``loss_weight``.
    """

    def __init__(self,
                 use_sigmoid=True,
                 activate=True,
                 reduction='mean',
                 naive_dice=False,
                 loss_weight=1.0,
                 eps=1e-3):
        """Compute dice loss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction is
                activated with sigmoid (True) or softmax (False). Only
                sigmoid is implemented: with ``activate=True`` and
                ``use_sigmoid=False``, ``forward`` raises
                ``NotImplementedError``. Defaults to True.
            activate (bool): Whether to apply the activation to the
                prediction inside this module. When False, ``pred`` is
                passed to the loss as given and ``use_sigmoid`` is not
                consulted. Defaults to True.
            reduction (str, optional): The method used
                to reduce the loss. Options are "none",
                "mean" and "sum". Defaults to 'mean'.
            naive_dice (bool, optional): If False, use the dice
                loss defined in the V-Net paper, otherwise, use the
                naive dice loss in which the power of the number in the
                denominator is the first power instead of the second
                power. Defaults to False.
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            eps (float): Avoid dividing by zero. Defaults to 1e-3.
        """
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.activate = activate
        self.reduction = reduction
        self.naive_dice = naive_dice
        self.loss_weight = loss_weight
        self.eps = eps

    def forward(self,
                pred,
                target,
                weight=None,
                reduction_override=None,
                avg_factor=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction, has a shape (n, *).
            target (torch.Tensor): The label of the prediction,
                shape (n, *), same shape of pred.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction, has a shape (n,). Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum". Defaults to None,
                in which case the reduction given at construction is used.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.

        Raises:
            NotImplementedError: If ``activate`` is True and
                ``use_sigmoid`` is False (softmax is not implemented).
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        if self.activate:
            if not self.use_sigmoid:
                # Fail with an actionable message instead of a bare
                # NotImplementedError.
                raise NotImplementedError(
                    'DiceLoss only implements sigmoid activation; use '
                    'use_sigmoid=True, or set activate=False to pass the '
                    'predictions through unchanged.')
            pred = pred.sigmoid()

        loss = self.loss_weight * dice_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            naive_dice=self.naive_dice,
            avg_factor=avg_factor)

        return loss
|
|