import torch
import torch.nn as nn


class SmoothL1Loss(nn.Module):
    """Smooth L1 loss on the Euclidean distance between predicted and ground-truth 2D points."""

    def __init__(self, scale=0.01):
        super(SmoothL1Loss, self).__init__()
        # Transition point between the quadratic and the linear regimes.
        self.scale = scale
        self.EPSILON = 1e-10

    def __repr__(self):
        return "SmoothL1Loss()"

    def forward(self, output: torch.Tensor, groundtruth: torch.Tensor, reduction='mean'):
        """
        output, groundtruth: b x n x 2
        returns: a b x n per-point loss, reduced to a scalar for 'mean' or 'sum'
        """
        if output.dim() == 4:
            # Broadcast the ground truth against multiple predictions per point.
            shape = output.shape
            groundtruth = groundtruth.reshape(shape[0], shape[1], 1, shape[3])
        # Squared Euclidean distance per point.
        delta_2 = (output - groundtruth).pow(2).sum(dim=-1, keepdim=False)
        # Clamp before the sqrt to keep the gradient finite near zero.
        delta = delta_2.clamp(min=1e-6).sqrt()
        # delta = torch.sqrt(delta_2 + self.EPSILON)
        loss = torch.where(
            delta_2 < self.scale * self.scale,
            0.5 / self.scale * delta_2,    # quadratic branch for small errors
            delta - 0.5 * self.scale)      # linear branch for large errors
        if reduction == 'mean':
            loss = loss.mean()
        elif reduction == 'sum':
            loss = loss.sum()
        return loss
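

# Minimal usage sketch (not part of the original module): exercises SmoothL1Loss
# on random b x n x 2 predictions and targets. The batch size, point count, and
# scale value below are illustrative assumptions, not values prescribed above.
if __name__ == "__main__":
    criterion = SmoothL1Loss(scale=0.01)
    pred = torch.randn(4, 17, 2)       # b x n x 2 predicted 2D points
    target = torch.randn(4, 17, 2)     # b x n x 2 ground-truth 2D points
    print(criterion(pred, target))                     # scalar mean loss
    print(criterion(pred, target, reduction='sum'))    # scalar summed loss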