|
import torch |
|
from torch.autograd import Function |
|
|
|
from pointops._C import interpolation_forward_cuda, interpolation_backward_cuda |
|
from .query import knn_query |
|
|
|
|
|
def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3):
    """Interpolate features onto new points by inverse-distance weighting.

    For every row of ``new_xyz`` the ``k`` nearest rows of ``xyz`` are looked
    up with ``knn_query``; their features are blended with weights
    proportional to the reciprocal of the returned distances, normalised so
    that the ``k`` weights of each new point sum to one.

    input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b)
    output: (n, c), float32, allocated on the same device as ``feat``
    """
    assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous()

    idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset)

    # The small epsilon keeps the reciprocal finite when a distance is zero.
    dist_recip = 1.0 / (dist + 1e-8)
    norm = torch.sum(dist_recip, dim=1, keepdim=True)
    weight = dist_recip / norm

    # Allocate next to ``feat`` instead of on whatever CUDA device happens to
    # be current; the legacy ``torch.cuda.FloatTensor`` constructor ignored
    # the device of the inputs. The dtype stays float32 as before.
    new_feat = torch.zeros(
        (new_xyz.shape[0], feat.shape[1]),
        dtype=torch.float32,
        device=feat.device,
    )
    # Accumulate one neighbour at a time: gather that neighbour's features and
    # scale them by the matching per-point weight column.
    for i in range(k):
        new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1)
    return new_feat
|
|
|
|
|
class Interpolation(Function):
    """Autograd function for k-nearest-neighbour inverse-distance interpolation.

    The weighted gather and its gradient are delegated to
    ``interpolation_forward_cuda`` / ``interpolation_backward_cuda``. Only the
    feature tensor (third argument of ``forward``) receives a gradient.
    """

    @staticmethod
    def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3):
        """
        input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b)
        output: (n, c), float32, allocated on the same device as ``input``
        """
        assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous()

        idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset)

        # The small epsilon keeps the reciprocal finite when a distance is zero.
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=1, keepdim=True)
        weight = dist_recip / norm

        n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0]
        # Allocate next to ``input`` rather than on the current CUDA device,
        # which is what the legacy ``torch.cuda.FloatTensor`` constructor did.
        output = torch.zeros((n, c), dtype=torch.float32, device=input.device)
        interpolation_forward_cuda(n, c, k, input, idx, weight, output)

        # ``m`` and ``k`` are plain ints, so they are stored as attributes;
        # the tensors needed by backward go through save_for_backward.
        ctx.m, ctx.k = m, k
        ctx.save_for_backward(idx, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (n, c)
        output: one entry per forward argument; only the feature tensor gets
        a gradient, of shape (m, c), the remaining five entries are None
        """
        m, k = ctx.m, ctx.k
        idx, weight = ctx.saved_tensors
        n, c = grad_output.shape

        # Autograd may hand over a non-contiguous gradient (e.g. an expanded
        # or transposed view). Forward asserts contiguity for the tensors it
        # passes to the extension, so give the backward kernel a dense buffer
        # as well. This is a no-op when the gradient is already contiguous.
        grad_output = grad_output.contiguous()

        # Same device fix as in forward: follow the incoming gradient instead
        # of the current CUDA device.
        grad_input = torch.zeros(
            (m, c), dtype=torch.float32, device=grad_output.device
        )
        interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input)
        return None, None, grad_input, None, None, None


# Functional entry point for the CUDA-backed autograd implementation.
interpolation2 = Interpolation.apply
|
|