yuandong513
feat: init
17cd746
import torch
import torch.nn as nn
import torch.nn.functional as F
class AWingLoss(nn.Module):
    """Adaptive Wing loss for heatmap regression.

    For an absolute error ``delta = |output - groundtruth|`` the per-element
    loss is::

        omega * log(1 + (delta / epsilon) ** (alpha - y))   if delta < theta
        A * delta - C                                        otherwise

    where ``y`` is the ground-truth value and ``A`` / ``C`` are chosen so the
    two pieces join continuously (in value and slope) at ``delta == theta``.

    Args:
        omega: Scale of the logarithmic (small-error) part.
        theta: Error threshold at which the loss switches from the
            logarithmic part to the linear part.
        epsilon: Divisor applied to the error inside the logarithm.
        alpha: Base exponent; the effective exponent is ``alpha - y``.
        use_weight_map: When true, per-element losses are multiplied by the
            map returned by :meth:`generate_weight_map` before averaging.
    """

    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1, use_weight_map=True):
        super().__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha
        self.use_weight_map = use_weight_map

    def __repr__(self):
        return "AWingLoss()"

    def generate_weight_map(self, heatmap: torch.Tensor, k_size: int = 3, w: float = 10) -> torch.Tensor:
        """Build a per-element weight map that emphasises high-response regions.

        The heatmap is dilated with a ``k_size`` x ``k_size`` max filter
        (stride 1); elements whose dilated value is at least 0.2 get weight
        ``w + 1``, all others get weight ``1``.

        Args:
            heatmap: Tensor accepted by ``F.max_pool2d`` (the ``forward``
                docstring describes it as ``b x n x h x w``).
            k_size: Odd, positive side length of the dilation window.
            w: Extra weight added on top of 1 inside the dilated region.

        Returns:
            A tensor with the same shape as ``heatmap``.

        Raises:
            ValueError: If ``k_size`` is not a positive odd integer, because
                no symmetric padding keeps the output aligned with the input.
        """
        if k_size < 1 or k_size % 2 == 0:
            raise ValueError(f"k_size must be a positive odd integer, got {k_size}")
        # Padding must track the kernel size; a fixed padding of 1 only keeps
        # the spatial size unchanged when k_size == 3.
        dilate = F.max_pool2d(heatmap, kernel_size=k_size, stride=1, padding=k_size // 2)
        weight_map = torch.where(dilate < 0.2, torch.zeros_like(heatmap), torch.ones_like(heatmap))
        return w * weight_map + 1

    def forward(self, output: torch.Tensor, groundtruth: torch.Tensor) -> torch.Tensor:
        """Compute the mean (optionally weighted) Adaptive Wing loss.

        Args:
            output: Predicted heatmaps, ``b x n x h x w``.
            groundtruth: Target heatmaps with the same shape as ``output``.

        Returns:
            A scalar tensor: the mean of the per-element losses.
        """
        delta = (output - groundtruth).abs()

        # Quantities shared by A and C: the target-dependent exponent and the
        # threshold ratio raised to it.
        exponent = self.alpha - groundtruth
        ratio = self.theta / self.epsilon
        ratio_pow = torch.pow(ratio, exponent)

        # A is the slope of the logarithmic part at delta == theta, and C the
        # offset that makes the linear part meet it there.
        A = self.omega * (1 / (1 + ratio_pow)) * exponent * \
            (torch.pow(ratio, exponent - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + ratio_pow)

        loss = torch.where(
            delta < self.theta,
            self.omega * torch.log(1 + torch.pow(delta / self.epsilon, exponent)),
            A * delta - C,
        )
        if self.use_weight_map:
            weight = self.generate_weight_map(groundtruth)
            loss = loss * weight
        return loss.mean()