Spaces:
Build error
Build error
r""" 4D and 6D convolutional Hough matching layers """ | |
from torch.nn.modules.conv import _ConvNd | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import torch | |
from common.logger import Logger | |
from . import chm_kernel | |
def fast4d(corr, kernel, bias=None):
    r""" Optimized implementation of 4D convolution

    The 4D cross-correlation over the last four axes (with 'same' zero padding)
    is decomposed into 3D convolutions: each slice of the kernel along the first
    spatial axis is applied with ``F.conv3d`` to every slice of ``corr`` along
    that axis. The partial results are then accumulated with the matching shift.

    Args:
        corr: tensor of shape (bsz, in_channels, srch, srcw, trgh, trgw).
        kernel: tensor of shape (out_channels, in_channels, k1, k2, k3, k4).
            Kernel sizes are assumed odd: padding by ``k // 2`` only preserves
            the spatial size for odd ``k``.
        bias: optional tensor with ``out_channels`` elements, added to every
            output position.

    Returns:
        Tensor of shape (bsz, out_channels, srch, srcw, trgh, trgw), on the
        device and with the dtype of ``corr``.
    """
    bsz, ch, srch, srcw, trgh, trgw = corr.size()
    out_channels = kernel.size(0)
    ksz_sh, ksz_sw, ksz_th, ksz_tw = kernel.shape[2:]
    psz = ksz_sh // 2
    # 'same' zero padding for the three axes handled by conv3d
    pad3d = (ksz_sw // 2, ksz_th // 2, ksz_tw // 2)

    # Allocate like corr so the accumulation never mixes devices or dtypes
    out_corr = corr.new_zeros((bsz, out_channels, srch, srcw, trgh, trgw))
    # Fold the first spatial axis into the batch: (bsz * srch, ch, srcw, trgh, trgw)
    corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)

    for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
        # Output row h gathers input row h + offset for this kernel row
        offset = pidx - psz
        if abs(offset) >= srch:
            # This kernel row never overlaps the input. Skipping it also avoids
            # negative slice bounds, which Python would count from the end.
            continue

        inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=pad3d)
        inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2)

        add_sid, add_fid = max(-offset, 0), srch - max(offset, 0)
        slc_sid, slc_fid = max(offset, 0), srch - max(-offset, 0)
        out_corr[:, :, add_sid:add_fid] += inter_corr[:, :, slc_sid:slc_fid]

    if bias is not None:
        out_corr += bias.view(1, out_channels, 1, 1, 1, 1)

    return out_corr
def fast6d(corr, kernel, bias, diagonal_idx):
    r""" Optimized implementation of 6D convolutional Hough matching

    The two leading 6D-only axes are handled in two steps. First, ``fast4d`` is
    run once per pair of (position, kernel offset) on those axes. Then the
    partial results are summed along the index groups given by ``diagonal_idx``,
    first over the second axis and then over the first.

    NOTE: this function only supports kernel size of (3, 3, 5, 5, 5, 5).

    Args:
        corr: 8D tensor (bsz, ch, s6d, s6d, s4d, s4d, s4d, s4d). The views below
            only work for ch == 1.
        kernel: 8D tensor (out, in, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d). The 4D
            convolution requires out * in == 1.
        bias: tensor added in place to every output position. It must broadcast
            into a single channel, i.e. hold one element.
        diagonal_idx: sequence of index tensors into a flattened (s6d * ks6d)
            axis; each group is summed into one slot. The views below require
            ``len(diagonal_idx) == s6d + 2 * (ks6d // 2)``.

    Returns:
        Tensor of shape (bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d).
    """
    bsz, _, s6d, s6d, s4d, s4d, s4d, s4d = corr.size()
    _, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d = kernel.size()

    # Move the channel axis next to the 4D axes and fold (bsz, s6d, s6d, ch) into
    # the batch: (bsz * s6d * s6d * ch, 1, s4d, s4d, s4d, s4d)
    corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
    # Each of the ks6d * ks6d 4D sub-kernels becomes one output channel of fast4d:
    # (ks6d ** 2, out * in, ks4d, ks4d, ks4d, ks4d)
    kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
    # One 4D convolution per (6D position, 6D kernel offset) pair
    corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
    # Reorder to (bsz, i, ki, j, kj, ...), i.e. position/offset of the first and
    # second 6D axes, then flatten to (bsz * s6d * ks6d, s6d * ks6d, ...) so the
    # second axis index is j * ks6d + kj
    corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
        contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)

    # Number of output slots of a zero-padded ('full') 1D convolution along one 6D axis
    ndiag = s6d + (ks6d // 2) * 2
    # First pass: sum each (j, kj) group into one slot of the second 6D axis
    first_sum = []
    for didx in diagonal_idx:
        first_sum.append(corr[:, didx, :, :, :, :].sum(dim=1))
    first_sum = torch.stack(first_sum).transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)

    # Second pass: the same grouping on the flattened (i, ki) axis of the first 6D axis
    corr = []
    for didx in diagonal_idx:
        corr.append(first_sum[:, didx, :, :, :, :, :].sum(dim=1))

    # Keep the central s6d x s6d window (drop ks6d // 2 slots on each end) and add
    # a channel axis: (bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)
    sidx = ks6d // 2
    eidx = ndiag - sidx
    corr = torch.stack(corr).transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()
    corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)

    # Index [s6d*s6d - 1, ..., 0] reverses the flattened s6d x s6d grid, i.e. flips
    # both 6D axes.
    # NOTE(review): presumably this undoes diagonal_idx listing its groups from the
    # largest (position - offset) to the smallest; confirm against the caller's indices.
    reverse_idx = torch.linspace(s6d * s6d - 1, 0, s6d * s6d).long()
    corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
        view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)

    return corr
def init_param_idx4d(param_dict):
    r""" Convert a kernel-sharing dictionary into a list of index tensors

    Args:
        param_dict: mapping whose values are sequences of flat 4D-kernel
            positions. In CHM4d/CHM6d.init_kernel, each entry's positions all
            receive one shared weight.

    Returns:
        List with one ``torch.tensor`` per value, in the mapping's iteration
        order. The keys are not inspected.
    """
    return [torch.tensor(indices) for indices in param_dict.values()]
class CHM4d(_ConvNd): | |
r""" 4D convolutional Hough matching layer | |
NOTE: this function only supports in_channels=1 and out_channels=1. | |
r""" | |
def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True): | |
super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4, | |
(1,) * 4, (0,) * 4, (1,) * 4, False, (0,) * 4, | |
1, bias, padding_mode='zeros') | |
# Zero kernel initialization | |
self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d)) | |
self.nkernels = in_channels * out_channels | |
# Initialize kernel indices | |
param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate() | |
param_shared = param_dict4d is not None | |
if param_shared: | |
# Initialize the shared parameters (multiplied by the number of times being shared) | |
self.param_idx = init_param_idx4d(param_dict4d) | |
weights = torch.abs(torch.randn(len(self.param_idx) * self.nkernels)) * 1e-3 | |
for weight, param_idx in zip(weights.sort()[0], self.param_idx): | |
weight *= len(param_idx) | |
self.weight = nn.Parameter(weights) | |
else: # full kernel initialziation | |
self.param_idx = None | |
self.weight = nn.Parameter(torch.abs(self.weight)) | |
if bias: self.bias = nn.Parameter(torch.tensor(0.0)) | |
Logger.info('(%s) # params in CHM 4D: %d' % (ktype, len(self.weight.view(-1)))) | |
def forward(self, x): | |
kernel = self.init_kernel() | |
x = fast4d(x, kernel, self.bias) | |
return x | |
def init_kernel(self): | |
# Initialize CHM kernel (divided by the number of times being shared) | |
ksz = self.kernel_size[-1] | |
if self.param_idx is None: | |
kernel = self.weight | |
else: | |
kernel = torch.zeros_like(self.zero_kernel4d) | |
for idx, pdx in enumerate(self.param_idx): | |
kernel = kernel.view(-1, ksz, ksz, ksz, ksz) | |
for jdx, kernel_single in enumerate(kernel): | |
weight = self.weight[idx + jdx * len(self.param_idx)].repeat(len(pdx)) / len(pdx) | |
kernel_single.view(-1)[pdx] += weight | |
kernel = kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz) | |
return kernel | |
class CHM6d(_ConvNd):
    r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5)

    The kernel is learned in one of two ways:
    - Full: when the 4D kernel generator returns no sharing pattern, the kernel
      is learned freely.
    - Shared: otherwise, there is one parameter vector per group of positions in
      the 3x3 scale-space grid. Each vector has one value per shared 4D group.

    NOTE: this function only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
        kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
        super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6,
                                    (0,) * 6, (1,) * 6, False, (0,) * 6,
                                    1, bias=True, padding_mode='zeros')

        # Zero kernel templates. init_kernel only reads their shapes and allocates
        # on the parameters' device/dtype.
        self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
        self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        # Indices in scale-space where 4D convolutions are performed (3 by 3 scale-space)
        self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]

        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:  # psi & iso kernel initialization
            # Groups of flattened 3x3 scale-space positions sharing one parameter vector
            if ktype == 'psi':
                scale_groups = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
            elif ktype == 'iso':
                scale_groups = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
            else:
                raise ValueError('Unsupported shared kernel type for CHM 6D: %s' % ktype)
            self.param_dict6d = [torch.tensor(group) for group in scale_groups]

            # Initialize the shared parameters (multiplied by the number of times being
            # shared; init_kernel divides by the same counts)
            self.param_idx = init_param_idx4d(param_dict4d)
            share_counts = torch.tensor([len(pdx) for pdx in self.param_idx])
            params = []
            for scale_idx in self.param_dict6d:
                weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
                params.append(nn.Parameter(weights * (share_counts * len(scale_idx))))
            self.param = nn.ParameterList(params)
        else:  # full kernel initialization
            self.param_idx = None
            self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)
        Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum([len(x.view(-1)) for x in self.param])))
        self.weight = None

    def forward(self, corr):
        r""" Apply ``fast6d`` with the current kernel, bias and diagonal indices """
        kernel = self.init_kernel()
        return fast6d(corr, kernel, self.bias, self.diagonal_idx)

    def init_kernel(self):
        r""" Build the 6D kernel from the (possibly shared) parameters

        In the shared case, ``param[g][j]`` is divided by the size of 4D group
        ``j`` and by the size of scale group ``g``. The result is written to
        every (scale position, 4D position) pair of those two groups.

        Returns:
            Tensor of shape (1, 1, ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d) in
            the shared case, or the raw full-kernel parameter otherwise.
        """
        if self.param_idx is None:
            return self.param

        ksz4d = self.kernel_size[-1]
        reference = self.param[0]
        # Allocate on the parameters' device/dtype (the zero templates stay on the CPU)
        kernel6d = reference.new_zeros(self.zero_kernel6d.shape)
        kernel6d_flat = kernel6d.view(-1, ksz4d, ksz4d, ksz4d, ksz4d)
        for param, scale_idx in zip(self.param, self.param_dict6d):
            kernel4d = reference.new_zeros(self.zero_kernel4d.shape)
            kernel4d_flat = kernel4d.view(-1)
            for jdx, pdx in enumerate(self.param_idx):
                kernel4d_flat[pdx] += ((param[jdx] / len(pdx)) / len(scale_idx))
            kernel6d_flat[scale_idx] += kernel4d
        return kernel6d.unsqueeze(0).unsqueeze(0)