Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*- | |
import math | |
import torch | |
from torch import nn | |
# torch.log and math.log is e based | |
class WingLoss(nn.Module): | |
def __init__(self, omega=0.01, epsilon=2): | |
super(WingLoss, self).__init__() | |
self.omega = omega | |
self.epsilon = epsilon | |
def forward(self, pred, target): | |
y = target | |
y_hat = pred | |
delta_2 = (y - y_hat).pow(2).sum(dim=-1, keepdim=False) | |
# delta = delta_2.sqrt() | |
delta = delta_2.clamp(min=1e-6).sqrt() | |
C = self.omega - self.omega * math.log(1 + self.omega / self.epsilon) | |
loss = torch.where( | |
delta < self.omega, | |
self.omega * torch.log(1 + delta / self.epsilon), | |
delta - C | |
) | |
return loss.mean() | |