|
import torch
|
|
|
|
|
|
class decoder_default:
    """Decode heatmaps into normalized keypoint coordinates (soft-argmax).

    Each heatmap channel is treated as weights over a pixel grid whose
    coordinates span [-1, +1] along each axis; the decoded point is the
    weighted mean of those grid coordinates.
    """

    def __init__(self, weight=1, use_weight_map=False):
        """
        Args:
            weight: multiplier applied to the heatmap when ``use_weight_map``
                is true; must broadcast against a ``batch x npoints x h x w``
                heatmap (e.g. a scalar or a tensor).
            use_weight_map: if true, multiply the heatmap by ``weight``
                before decoding.
        """
        self.weight = weight
        self.use_weight_map = use_weight_map

    @staticmethod
    def _axis_coords(n):
        """Return ``n`` evenly spaced float32 positions spanning [-1, +1].

        A length-1 axis has no extent to span, so it maps to the center (0)
        instead of the NaN that ``0 / (n - 1)`` would produce.
        """
        if n == 1:
            return torch.zeros(1)
        return torch.arange(n).float() / (n - 1) * 2 - 1

    def _make_grid(self, h, w):
        """Build ``h x w`` normalized coordinate grids in ``ij`` layout.

        Returns:
            (yy, xx): ``yy[i, j]`` is the normalized row position of pixel
            ``(i, j)`` and ``xx[i, j]`` its normalized column position,
            both in [-1, +1].
        """
        ys = self._axis_coords(h)
        xs = self._axis_coords(w)
        # Same result as torch.meshgrid(ys, xs, indexing="ij"), but without the
        # missing-`indexing` deprecation warning and without requiring the
        # `indexing` keyword (torch >= 1.10).
        yy = ys.view(h, 1).expand(h, w)
        xx = xs.view(1, w).expand(h, w)
        return yy, xx

    def get_coords_from_heatmap(self, heatmap):
        """Decode each heatmap channel to the weighted mean of grid coordinates.

        inputs:
            - heatmap: batch x npoints x h x w

        outputs:
            - coords: batch x npoints x 2 (x, y); within [-1, +1] when the
              heatmap is non-negative
        """
        # Unpacking also enforces a rank-4 input.
        _, _, h, w = heatmap.shape

        if self.use_weight_map:
            heatmap = heatmap * self.weight

        yy, xx = self._make_grid(h, w)
        # Move the grids to the heatmap's dtype/device so the products are valid;
        # reshape (not view) because the grids are expanded, stride-0 tensors.
        yy = yy.reshape(1, 1, h, w).to(heatmap)
        xx = xx.reshape(1, 1, h, w).to(heatmap)

        # The clamp avoids division by zero; an all-zero channel decodes to (0, 0).
        heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)

        yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum
        xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum
        return torch.stack([xx_coord, yy_coord], dim=-1)
|
|
|