Spaces:
Veein
/
Runtime error

File size: 791 Bytes
17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# -*- coding: utf-8 -*-

import math
import torch
from torch import nn


# All logarithms below (torch and math) are natural, base-e logarithms.
class WingLoss(nn.Module):
    """Wing loss computed on the Euclidean distance between ``pred`` and ``target``.

    For every element, the distance ``d`` is the L2 norm of ``target - pred``
    taken over the last dimension. The per-element loss is::

        omega * ln(1 + d / epsilon)   if d < omega
        d - C                         otherwise

    with ``C = omega - omega * ln(1 + omega / epsilon)``, chosen so the two
    pieces meet at ``d == omega``. This is the piecewise Wing-loss form,
    applied here to per-point distances rather than per-coordinate errors.

    Args:
        omega: Width of the logarithmic region; must be > 0.
        epsilon: Curvature parameter of the logarithmic region; must be > 0.
        reduction: ``"mean"`` (default, the original behavior), ``"sum"``, or
            ``"none"`` to return the unreduced per-distance losses.

    Raises:
        ValueError: If ``omega`` or ``epsilon`` is not positive, or
            ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")
    # Floor applied to the squared distance before sqrt. The derivative of
    # sqrt is infinite at 0, so exact matches would otherwise yield NaN
    # gradients. Squared distances below this value are raised to it, so
    # near-exact matches get a small constant loss and zero gradient.
    _MIN_SQ_DIST = 1e-6

    def __init__(self, omega=0.01, epsilon=2, *, reduction="mean"):
        super().__init__()
        if omega <= 0:
            raise ValueError(f"omega must be positive, got {omega!r}")
        if epsilon <= 0:
            raise ValueError(f"epsilon must be positive, got {epsilon!r}")
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.omega = omega
        self.epsilon = epsilon
        self.reduction = reduction

    def extra_repr(self):
        """Show the loss hyper-parameters in ``repr(module)``."""
        return (
            f"omega={self.omega}, epsilon={self.epsilon}, "
            f"reduction={self.reduction!r}"
        )

    def forward(self, pred, target):
        """Return the Wing loss between ``pred`` and ``target``.

        Args:
            pred: Predicted tensor. Its last dimension is collapsed into a
                Euclidean distance (presumably point coordinates, e.g.
                (..., num_points, 2) — TODO confirm with callers).
            target: Ground-truth tensor, broadcastable against ``pred``.

        Returns:
            A 0-dim tensor for ``"mean"``/``"sum"``. For ``"none"``, the
            per-distance losses (the broadcast shape minus its last dimension).
        """
        # Squared Euclidean distance over the last dimension.
        delta_2 = (target - pred).pow(2).sum(dim=-1)
        delta = delta_2.clamp(min=self._MIN_SQ_DIST).sqrt()
        # Offset that makes the linear branch continuous with the log branch
        # at delta == omega. log1p keeps precision when omega/epsilon is tiny
        # (0.005 with the defaults), where 1 + x would round away digits.
        C = self.omega - self.omega * math.log1p(self.omega / self.epsilon)
        loss = torch.where(
            delta < self.omega,
            self.omega * torch.log1p(delta / self.epsilon),
            delta - C,
        )
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss