|
|
|
import functools |
|
from typing import Callable, Optional |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
|
|
|
|
def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Collapse an element-wise loss tensor according to ``reduction``.

    Args:
        loss (Tensor): Element-wise (unreduced) loss values.
        reduction (str): One of "none", "mean" or "sum".

    Return:
        Tensor: ``loss`` unchanged for "none", otherwise its mean or sum.
    """
    # PyTorch's (private) helper turns the reduction string into an integer
    # code and rejects values it does not recognize.
    mode = F._Reduction.get_enum(reduction)

    reducers = {
        0: lambda values: values,          # "none"
        1: lambda values: values.mean(),   # "mean"
        2: lambda values: values.sum(),    # "sum"
    }
    reducer = reducers.get(mode)
    if reducer is None:
        return None
    return reducer(loss)
|
|
|
|
|
def weight_reduce_loss(loss: Tensor,
                       weight: Optional[Tensor] = None,
                       reduction: str = 'mean',
                       avg_factor: Optional[float] = None) -> Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Optional[Tensor], optional): Element-wise weights that are
            multiplied into ``loss`` before any reduction. Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch
            ("none", "mean" or "sum"). Defaults to 'mean'.
        avg_factor (Optional[float], optional): Average factor when
            computing the mean of losses. When given, "mean" is computed as
            ``loss.sum() / avg_factor`` instead of ``loss.mean()``, and
            "none" returns the (weighted) loss unreduced. Defaults to None.

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is given together with
            ``reduction="sum"`` or with an unrecognized reduction. Without
            ``avg_factor``, validation is delegated to ``reduce_loss``.
    """
    if weight is not None:
        loss = loss * weight

    if avg_factor is None:
        return reduce_loss(loss, reduction)

    if reduction == 'mean':
        # The tiny epsilon keeps the division finite when avg_factor is 0.
        eps = torch.finfo(torch.float32).eps
        return loss.sum() / (avg_factor + eps)
    if reduction == 'sum':
        raise ValueError('avg_factor can not be used with reduction="sum"')
    if reduction != 'none':
        # Previously this case fell into the "sum" error message above,
        # which misreported the actual problem (an invalid reduction).
        raise ValueError(
            f'Invalid reduction "{reduction}"; expected "none", "mean" '
            'or "sum"')
    return loss
|
|
|
|
|
def weighted_loss(loss_func: Callable) -> Callable:
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    ``loss_func(pred, target, **kwargs)``. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like ``loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)``.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()
    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])
    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                reduction: str = 'mean',
                avg_factor: Optional[float] = None,
                **kwargs) -> Tensor:
        # NOTE: functools.wraps copies loss_func.__doc__ onto this wrapper,
        # so a docstring here would never be visible; documented as comments.
        #   pred: the prediction.
        #   target: the learning target for ``pred``.
        #   weight: optional element-wise weight of the loss.
        #   reduction: "none", "mean" or "sum" (default "mean").
        #   avg_factor: optional divisor replacing the element count for
        #       the "mean" reduction.
        #   kwargs: forwarded unchanged to ``loss_func``.
        # The wrapped function computes the unreduced element-wise loss;
        # weighting and reduction are then applied uniformly here.
        loss = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(loss, weight, reduction, avg_factor)

    return wrapper
|
|