import torch
import torch.nn as nn
import torch.nn.functional as F

from .smoothL1Loss import SmoothL1Loss
from .wingLoss import WingLoss

def get_channel_sum(input):
    # Sum each channel's spatial map: [b, n, h, w] -> [b, n].
    temp = torch.sum(input, dim=3)
    output = torch.sum(temp, dim=2)
    return output


def expand_two_dimensions_at_end(input, dim1, dim2):
    # Broadcast a [b, n] tensor to [b, n, dim1, dim2] without copying data.
    input = input.unsqueeze(-1).unsqueeze(-1)
    input = input.expand(-1, -1, dim1, dim2)
    return input
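

# Shape sanity check for the two helpers above (a quick sketch; the sizes are
# illustrative, not taken from the model):
#   get_channel_sum(torch.ones(2, 68, 64, 64)).shape               -> [2, 68]
#   expand_two_dimensions_at_end(torch.ones(2, 68), 64, 64).shape  -> [2, 68, 64, 64]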


class STARLoss(nn.Module):
    def __init__(self, w=1, dist='smoothl1', num_dim_image=2, EPSILON=1e-5):
        super(STARLoss, self).__init__()
        self.w = w
        self.num_dim_image = num_dim_image
        self.EPSILON = EPSILON
        self.dist = dist
        if self.dist == 'smoothl1':
            self.dist_func = SmoothL1Loss()
        elif self.dist == 'l1':
            self.dist_func = F.l1_loss
        elif self.dist == 'l2':
            self.dist_func = F.mse_loss
        elif self.dist == 'wing':
            self.dist_func = WingLoss()
        else:
            raise NotImplementedError("dist must be one of: 'smoothl1', 'l1', 'l2', 'wing'")

    def __repr__(self):
        return "STARLoss()"

    def _make_grid(self, h, w):
        # Normalized coordinate grid spanning [-1, 1] along each axis.
        yy, xx = torch.meshgrid(
            torch.arange(h).float() / (h - 1) * 2 - 1,
            torch.arange(w).float() / (w - 1) * 2 - 1,
            indexing='ij')
        return yy, xx

    def weighted_mean(self, heatmap):
        # Soft-argmax: the expected (x, y) position under each normalized
        # heatmap. A spike at the top-left pixel maps to (-1, -1), one at the
        # bottom-right to (1, 1).
        batch, npoints, h, w = heatmap.shape
        yy, xx = self._make_grid(h, w)
        yy = yy.view(1, 1, h, w).to(heatmap)
        xx = xx.view(1, 1, h, w).to(heatmap)
        yy_coord = (yy * heatmap).sum([2, 3])  # batch x npoints
        xx_coord = (xx * heatmap).sum([2, 3])  # batch x npoints
        coords = torch.stack([xx_coord, yy_coord], dim=-1)
        return coords

    def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5):
        batch_size, num_points, height, width = htp.shape
        yv, xv = self._make_grid(height, width)
        # .to(htp) replaces the deprecated Variable / .cuda() pattern and also
        # matches dtype, not just device.
        xv = xv.to(htp)
        yv = yv.to(htp)
        xmean = means[:, :, 0]
        xv_minus_mean = (xv.expand(batch_size, num_points, -1, -1)
                         - expand_two_dimensions_at_end(xmean, height, width))  # [batch_size, 68, 64, 64]
        ymean = means[:, :, 1]
        yv_minus_mean = (yv.expand(batch_size, num_points, -1, -1)
                         - expand_two_dimensions_at_end(ymean, height, width))  # [batch_size, 68, 64, 64]
        wt_xv_minus_mean = xv_minus_mean.view(batch_size * num_points, 1, height * width)  # [batch_size*68, 1, 4096]
        wt_yv_minus_mean = yv_minus_mean.view(batch_size * num_points, 1, height * width)  # [batch_size*68, 1, 4096]
        vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1)  # [batch_size*68, 2, 4096]
        htp_vec = htp.view(batch_size * num_points, 1, height * width)
        htp_vec = htp_vec.expand(-1, 2, -1)
        covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2))  # [batch_size*68, 2, 2]
        covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image)  # [batch_size, 68, 2, 2]
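        # Bias correction for a covariance estimated with probability weights
        # w_i (the normalized heatmap values):
        #   cov = sum_i w_i (p_i - mu)(p_i - mu)^T / (V_1 - V_2 / V_1),
        #   with V_1 = sum_i w_i and V_2 = sum_i w_i ** 2.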
        V_1 = htp.sum([2, 3]) + EPSILON  # [batch_size, 68]
        V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON  # [batch_size, 68]
        denominator = V_1 - (V_2 / V_1)
        covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image)
        return covariance

    def ambiguity_guided_decompose(self, pts, eigenvalues, eigenvectors):
        # Rotate the prediction error into the heatmap's principal axes and
        # scale each axis by its standard deviation (sqrt of the eigenvalue).
        batch_size, npoints = pts.shape[:2]
        rotate = torch.matmul(pts.view(batch_size, npoints, 1, 2), eigenvectors.transpose(-1, -2))
        scale = rotate.view(batch_size, npoints, 2) / torch.sqrt(eigenvalues + self.EPSILON)
        return scale

    def eigenvalue_restriction(self, evalues, batch, npoints):
        # Penalize the principal-axis variances to keep each heatmap compact.
        eigen_loss = torch.abs(evalues.view(batch * npoints, 2)).sum(-1)
        return eigen_loss.mean()

    def forward(self, heatmap, groundtruth):
        """
        heatmap:     b x n x 64 x 64
        groundtruth: b x n x 2
        output:      scalar loss (per-point terms reduced to a single value)
        """
        # Normalize each heatmap so it can be treated as a probability distribution.
        bs, npoints, h, w = heatmap.shape
        heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
        heatmap = heatmap / heatmap_sum.view(bs, npoints, 1, 1)
        means = self.weighted_mean(heatmap)  # [bs, 68, 2]
        covars = self.unbiased_weighted_covariance(heatmap, means)  # [bs, 68, 2, 2]
        # Eigen-decomposition on CPU; see the GPU limitations discussed in
        # https://github.com/pytorch/pytorch/issues/60537. torch.symeig has been
        # removed from recent PyTorch, so torch.linalg.eigh (its documented
        # replacement, eigenvalues in ascending order) is used instead.
        _covars = covars.view(bs * npoints, 2, 2).cpu()
        evalues, evectors = torch.linalg.eigh(_covars)  # evalues [bs*68, 2], evectors [bs*68, 2, 2]
        evalues = evalues.view(bs, npoints, 2).to(heatmap)
        evectors = evectors.view(bs, npoints, 2, 2).to(heatmap)
        # STAR loss = ambiguity-guided decomposition + eigenvalue restriction.
        error = self.ambiguity_guided_decompose(groundtruth - means, evalues, evectors)
        loss_trans = self.dist_func(torch.zeros_like(error), error)
        loss_eigen = self.eigenvalue_restriction(evalues, bs, npoints)
        star_loss = loss_trans + self.w * loss_eigen
        return star_loss
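

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the original module. It assumes
    # the file is executed inside its package (e.g. `python -m <package>.<module>`,
    # names are placeholders) so the relative imports above resolve, and uses
    # dist='l2' to stay on built-in torch losses. Shapes follow forward()'s docstring.
    torch.manual_seed(0)
    bs, npoints, h, w = 2, 68, 64, 64
    heatmap = torch.rand(bs, npoints, h, w)           # unnormalized heatmaps
    groundtruth = torch.rand(bs, npoints, 2) * 2 - 1  # landmarks in [-1, 1]
    criterion = STARLoss(w=1, dist='l2')
    loss = criterion(heatmap, groundtruth)
    print(loss)  # scalar tensor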