r""" Provides functions that manipulate boxes and points """ | |
import math | |
import torch.nn.functional as F | |
import torch | |

class Geometry(object):

    @classmethod
    def initialize(cls, img_size):
        cls.img_size = img_size

        # Side length of the feature map (input image downscaled by a factor of 8)
        cls.spatial_side = int(img_size / 8)
        norm_grid1d = torch.linspace(-1, 1, cls.spatial_side)

        cls.norm_grid_x = norm_grid1d.view(1, -1).repeat(cls.spatial_side, 1).view(1, 1, -1)
        cls.norm_grid_y = norm_grid1d.view(-1, 1).repeat(1, cls.spatial_side).view(1, 1, -1)
        # (spatial_side, spatial_side, 2) grid of normalized (x, y) coordinates
        cls.grid = torch.stack(list(reversed(torch.meshgrid(norm_grid1d, norm_grid1d)))).permute(1, 2, 0)

        cls.feat_idx = torch.arange(0, cls.spatial_side).float()

    @classmethod
    def normalize_kps(cls, kps):
        r""" Maps pixel keypoint coordinates to [-1, 1]; padding entries (-2) are left unchanged """
        kps = kps.clone().detach()
        kps[kps != -2] -= (cls.img_size // 2)
        kps[kps != -2] /= (cls.img_size // 2)
        return kps

    @classmethod
    def unnormalize_kps(cls, kps):
        r""" Maps normalized keypoint coordinates back to pixels; padding entries (-2) are left unchanged """
        kps = kps.clone().detach()
        kps[kps != -2] *= (cls.img_size // 2)
        kps[kps != -2] += (cls.img_size // 2)
        return kps

    @classmethod
    def attentive_indexing(cls, kps, thres=0.1):
        r"""kps: normalized keypoints x, y (N, 2)
            returns attentive index map (N, spatial_side, spatial_side)
        """
        nkps = kps.size(0)
        kps = kps.view(nkps, 1, 1, 2)

        # Distance from each keypoint to every grid cell; cells within radius `thres`
        # receive positive weights, which are then normalized to sum to one
        eps = 1e-5
        attmap = (cls.grid.unsqueeze(0).repeat(nkps, 1, 1, 1) - kps).pow(2).sum(dim=3)
        attmap = (attmap + eps).pow(0.5)
        attmap = (thres - attmap).clamp(min=0).view(nkps, -1)
        attmap = attmap / attmap.sum(dim=1, keepdim=True)
        attmap = attmap.view(nkps, cls.spatial_side, cls.spatial_side)

        return attmap

    @classmethod
    def apply_gaussian_kernel(cls, corr, sigma=17):
        r""" Suppresses correlation scores far from each row's argmax with a Gaussian window """
        bsz, side, side = corr.size()

        # Argmax over target positions for every source position
        center = corr.max(dim=2)[1]
        center_y = center // cls.spatial_side
        center_x = center % cls.spatial_side

        y = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_y.size(1), 1) - center_y.unsqueeze(2)
        x = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_x.size(1), 1) - center_x.unsqueeze(2)

        y = y.unsqueeze(3).repeat(1, 1, 1, cls.spatial_side)
        x = x.unsqueeze(2).repeat(1, 1, cls.spatial_side, 1)

        gauss_kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
        filtered_corr = gauss_kernel * corr.view(bsz, -1, cls.spatial_side, cls.spatial_side)
        filtered_corr = filtered_corr.view(bsz, side, side)

        return filtered_corr

    @classmethod
    def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
        r""" Transfers source keypoints to the target image by weighted average over the correlation """
        if not normalized:
            src_kps = Geometry.normalize_kps(src_kps)
        confidence_ts = cls.apply_gaussian_kernel(confidence_ts)

        # Soft-argmax: expected normalized target coordinates for every source position
        pdf = F.softmax(confidence_ts, dim=2)
        prd_x = (pdf * cls.norm_grid_x).sum(dim=2)
        prd_y = (pdf * cls.norm_grid_y).sum(dim=2)

        prd_kps = []
        for idx, (x, y, src_kp, np) in enumerate(zip(prd_x, prd_y, src_kps, n_pts)):
            max_pts = src_kp.size()[1]
            prd_xy = torch.stack([x, y]).t()

            # Aggregate predicted coordinates around each valid source keypoint
            src_kp = src_kp[:, :np].t()
            attmap = cls.attentive_indexing(src_kp).view(np, -1)
            prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()

            # Pad unused keypoint slots with -2
            pads = (torch.zeros((2, max_pts - np)) - 2)
            prd_kp = torch.cat([prd_kp, pads], dim=1)
            prd_kps.append(prd_kp)

        return torch.stack(prd_kps)

    @staticmethod
    def get_coord1d(coord4d, ksz):
        r""" Flattens a 4D index (i, j, k, l) on a (ksz, ksz, ksz, ksz) grid into a 1D offset """
        i, j, k, l = coord4d
        coord1d = i * (ksz ** 3) + j * (ksz ** 2) + k * (ksz) + l
        return coord1d
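
    # Worked example (illustrative): with ksz = 4 and coord4d = (1, 2, 3, 0),
    # get_coord1d returns 1*64 + 2*16 + 3*4 + 0 = 108, the row-major offset of
    # that index in a (4, 4, 4, 4) grid.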

    @staticmethod
    def get_distance(coord1, coord2):
        r""" Squared Euclidean distance between two (y, x) coordinates """
        delta_y = int(math.pow(coord1[0] - coord2[0], 2))
        delta_x = int(math.pow(coord1[1] - coord2[1], 2))
        dist = delta_y + delta_x
        return dist

    @staticmethod
    def interpolate4d(tensor4d, size):
        r""" Bilinearly upsamples a (bsz, h1, w1, h2, w2) correlation tensor so that both 2D planes have spatial size `size` (assumed square) """
        bsz, h1, w1, h2, w2 = tensor4d.size()
        # Upsample over the (h1, w1) plane, then over the (h2, w2) plane
        tensor4d = tensor4d.view(bsz, h1, w1, -1).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
        tensor4d = tensor4d.view(bsz, h2, w2, -1).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
        tensor4d = tensor4d.view(bsz, size[0], size[0], size[0], size[0])
        return tensor4d
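
    # Illustrative shape walk-through (assumed sizes): a (bsz, 16, 16, 16, 16)
    # correlation passed with size=(32, 32) is returned as (bsz, 32, 32, 32, 32);
    # the first interpolate call upsamples the (h1, w1) plane and the second the
    # (h2, w2) plane.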

    @staticmethod
    def init_idx4d(ksz):
        r""" Enumerates all ksz**4 indices of a (ksz, ksz, ksz, ksz) grid in row-major order """
        i0 = torch.arange(0, ksz).repeat(ksz ** 3)
        i1 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz).view(-1).repeat(ksz ** 2)
        i2 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 2).view(-1).repeat(ksz)
        i3 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 3).view(-1)
        idx4d = torch.stack([i3, i2, i1, i0]).t().numpy()

        return idx4d
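

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): the image size,
    # batch size and keypoint counts below are illustrative assumptions, and the
    # correlation tensor is random noise standing in for real matching scores.
    Geometry.initialize(img_size=256)                # spatial_side = 256 / 8 = 32
    hw = Geometry.spatial_side ** 2                  # 1024 flattened feature positions

    bsz, max_pts = 2, 5
    n_pts = [3, 5]                                   # number of valid keypoints per sample
    confidence_ts = torch.rand(bsz, hw, hw)          # (bsz, hw_src, hw_trg) correlation
    src_kps = torch.rand(bsz, 2, max_pts) * (Geometry.img_size - 1)  # pixel coordinates

    prd_kps = Geometry.transfer_kps(confidence_ts, src_kps, n_pts, normalized=False)
    print(prd_kps.shape)                             # torch.Size([2, 2, 5]); unused columns hold -2
    print(Geometry.unnormalize_kps(prd_kps)[0])      # predictions mapped back to pixel coordinates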