import torch
import torch.nn as nn
import torch.nn.functional as F

from .smoothL1Loss import SmoothL1Loss
from .wingLoss import WingLoss


def get_channel_sum(input):
    # Sum a (batch, npoints, h, w) tensor over its two spatial dimensions -> (batch, npoints).
    temp = torch.sum(input, dim=3)
    output = torch.sum(temp, dim=2)
    return output


def expand_two_dimensions_at_end(input, dim1, dim2):
    # Broadcast a (batch, npoints) tensor to (batch, npoints, dim1, dim2).
    input = input.unsqueeze(-1).unsqueeze(-1)
    input = input.expand(-1, -1, dim1, dim2)
    return input


class STARLoss(nn.Module):
    def __init__(self, w=1, dist='smoothl1', num_dim_image=2, EPSILON=1e-5):
        """
        Args:
            w: weight of the eigenvalue-restriction term in the total loss.
            dist: distance applied to the decomposed error: 'smoothl1', 'l1', 'l2' or 'wing'.
            num_dim_image: dimensionality of the image coordinates (2 for x/y).
            EPSILON: small constant guarding divisions and square roots.
        """
        super(STARLoss, self).__init__()
        self.w = w
        self.num_dim_image = num_dim_image
        self.EPSILON = EPSILON
        self.dist = dist
        if self.dist == 'smoothl1':
            self.dist_func = SmoothL1Loss()
        elif self.dist == 'l1':
            self.dist_func = F.l1_loss
        elif self.dist == 'l2':
            self.dist_func = F.mse_loss
        elif self.dist == 'wing':
            self.dist_func = WingLoss()
        else:
            raise NotImplementedError

    def __repr__(self):
        return "STARLoss()"

    def _make_grid(self, h, w):
        # Coordinate grids normalized to [-1, 1]; 'ij' indexing, so yy varies over rows and xx over columns.
        yy, xx = torch.meshgrid(
            torch.arange(h).float() / (h - 1) * 2 - 1,
            torch.arange(w).float() / (w - 1) * 2 - 1,
            indexing='ij')
        return yy, xx

    def weighted_mean(self, heatmap):
        """Soft-argmax: expected (x, y) location of each heatmap on the normalized [-1, 1] grid."""
        batch, npoints, h, w = heatmap.shape

        yy, xx = self._make_grid(h, w)
        yy = yy.view(1, 1, h, w).to(heatmap)
        xx = xx.view(1, 1, h, w).to(heatmap)

        yy_coord = (yy * heatmap).sum([2, 3])  # batch x npoints
        xx_coord = (xx * heatmap).sum([2, 3])  # batch x npoints
        coords = torch.stack([xx_coord, yy_coord], dim=-1)
        return coords

    def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5):
        batch_size, num_points, height, width = htp.shape

        yv, xv = self._make_grid(height, width)
        xv = xv.to(htp)
        yv = yv.to(htp)

        # Deviation of every grid coordinate from the per-heatmap mean.
        xmean = means[:, :, 0]
        ymean = means[:, :, 1]
        xv_minus_mean = (xv.expand(batch_size, num_points, -1, -1)
                         - expand_two_dimensions_at_end(xmean, height, width))  # [batch_size, 68, 64, 64]
        yv_minus_mean = (yv.expand(batch_size, num_points, -1, -1)
                         - expand_two_dimensions_at_end(ymean, height, width))  # [batch_size, 68, 64, 64]

        wt_xv_minus_mean = xv_minus_mean.view(batch_size * num_points, 1, height * width)  # [batch_size*68, 1, 4096]
        wt_yv_minus_mean = yv_minus_mean.view(batch_size * num_points, 1, height * width)  # [batch_size*68, 1, 4096]
        vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1)  # [batch_size*68, 2, 4096]

        htp_vec = htp.view(batch_size * num_points, 1, height * width)
        htp_vec = htp_vec.expand(-1, 2, -1)

        # Heatmap-weighted outer products of the deviations.
        covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2))  # [batch_size*68, 2, 2]
        covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image)  # [batch_size, 68, 2, 2]

        # Unbiased normalization for weighted samples: V_1 - V_2 / V_1.
        V_1 = htp.sum([2, 3]) + EPSILON  # [batch_size, 68]
        V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON  # [batch_size, 68]

        denominator = V_1 - (V_2 / V_1)
        covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image)

        return covariance

    def ambiguity_guided_decompose(self, pts, eigenvalues, eigenvectors):
        # Project the error onto the principal axes of each heatmap and scale every component
        # by 1 / sqrt(eigenvalue), so directions with larger ambiguity are down-weighted.
        batch_size, npoints = pts.shape[:2]
        rotate = torch.matmul(pts.view(batch_size, npoints, 1, 2), eigenvectors.transpose(-1, -2))
        scale = rotate.view(batch_size, npoints, 2) / torch.sqrt(eigenvalues + self.EPSILON)
        return scale

    def eigenvalue_restriction(self, evalues, batch, npoints):
        # Penalize the eigenvalue magnitudes, which keeps the predicted heatmaps compact.
        eigen_loss = torch.abs(evalues.view(batch * npoints, 2)).sum(-1)
        return eigen_loss.mean()

    def forward(self, heatmap, groundtruth):
        """

            heatmap:     b x n x 64 x 64

            groundtruth: b x n x 2

            output:      b x n x 1 => 1

        """
        # normalize
        bs, npoints, h, w = heatmap.shape
        heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
        heatmap = heatmap / heatmap_sum.view(bs, npoints, 1, 1)

        means = self.weighted_mean(heatmap)  # [bs, 68, 2]
        covars = self.unbiased_weighted_covariance(heatmap, means)  # covars [bs, 68, 2, 2]

        # TODO: GPU-based eigen-decomposition
        # https://github.com/pytorch/pytorch/issues/60537
        _covars = covars.view(bs * npoints, 2, 2).cpu()
        # torch.linalg.eigh replaces the removed Tensor.symeig(eigenvectors=True); like symeig it
        # returns eigenvalues in ascending order and eigenvectors as columns.
        evalues, evectors = torch.linalg.eigh(_covars)  # evalues [bs * 68, 2], evectors [bs * 68, 2, 2]
        evalues = evalues.view(bs, npoints, 2).to(heatmap)
        evectors = evectors.view(bs, npoints, 2, 2).to(heatmap)

        # STAR Loss
        # Ambiguity-guided Decomposition
        error = self.ambiguity_guided_decompose(groundtruth - means, evalues, evectors)
        loss_trans = self.dist_func(torch.zeros_like(error), error)
        # Eigenvalue Restriction
        loss_eigen = self.eigenvalue_restriction(evalues, bs, npoints)
        star_loss = loss_trans + self.w * loss_eigen

        return star_loss
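

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file: it builds a random but properly
    # normalized heatmap batch and dummy ground-truth landmarks purely to illustrate the
    # expected shapes and value ranges. The grid built by _make_grid spans [-1, 1], so the
    # dummy ground-truth coordinates are assumed to live in that range. dist='l2' is chosen
    # here because it maps to the built-in F.mse_loss; the relative imports at the top mean
    # this only runs when the file is executed as a module inside its package.
    torch.manual_seed(0)
    bs, npoints, h, w = 2, 68, 64, 64

    # Softmax over the flattened spatial dimensions yields a valid per-landmark probability
    # map; forward() re-normalizes anyway via the clamped spatial sum.
    logits = torch.randn(bs, npoints, h * w)
    heatmap = torch.softmax(logits, dim=-1).view(bs, npoints, h, w)

    # Dummy ground-truth landmarks in normalized [-1, 1] coordinates.
    groundtruth = torch.empty(bs, npoints, 2).uniform_(-1, 1)

    criterion = STARLoss(w=1, dist='l2')
    loss = criterion(heatmap, groundtruth)
    print('STAR loss:', loss.item())  # a single scalar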