|
import torch |
|
from torch.autograd import Function |
|
|
|
from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda |
|
|
|
|
|
class Subtraction(Function):
    """Autograd wrapper around the ``pointops`` subtraction CUDA kernels.

    ``forward`` calls ``subtraction_forward_cuda`` to fill an
    ``(n, nsample, c)`` output from two ``(n, c)`` feature tensors and an
    ``(n, nsample)`` index tensor. ``backward`` calls
    ``subtraction_backward_cuda`` to produce gradients for both feature
    tensors; ``idx`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, input1, input2, idx):
        """Run the subtraction forward kernel.

        Args:
            ctx: autograd context; the (contiguous) ``idx`` is saved on it for
                ``backward``.
            input1: feature tensor of shape ``(n, c)``; must be contiguous.
            input2: feature tensor of shape ``(n, c)``; must be contiguous.
            idx: index tensor of shape ``(n, nsample)``.

        Returns:
            A float32 tensor of shape ``(n, nsample, c)``, allocated on the
            same device as ``input1`` and filled by the kernel.

        Raises:
            AssertionError: if ``input1`` or ``input2`` is not contiguous.
        """
        assert input1.is_contiguous() and input2.is_contiguous()
        # The kernel is handed only sizes (n, nsample, c), never strides, so
        # every tensor it reads must be densely laid out. ``contiguous()`` is
        # a no-op when ``idx`` already is.
        idx = idx.contiguous()
        n, c = input1.shape
        nsample = idx.shape[-1]
        # Allocate next to the inputs instead of via the legacy
        # ``torch.cuda.FloatTensor`` constructor, which always targets the
        # *current* CUDA device. float32 matches the original allocation
        # (presumably what the kernel expects).
        output = torch.zeros(
            (n, nsample, c), dtype=torch.float32, device=input1.device
        )
        subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the subtraction backward kernel.

        Args:
            ctx: autograd context holding the ``idx`` saved by ``forward``.
            grad_output: gradient w.r.t. the forward output, shape
                ``(n, nsample, c)``.

        Returns:
            ``(grad_input1, grad_input2, None)``: float32 ``(n, c)``
            gradients for ``input1`` and ``input2``, and ``None`` for ``idx``.
        """
        (idx,) = ctx.saved_tensors
        # Incoming gradients are often strided views (e.g. the stride-0
        # ``expand`` produced by ``.sum()``); the kernel gets no strides, so
        # it must see a dense buffer or it reads the wrong elements.
        grad_output = grad_output.contiguous()
        n, nsample, c = grad_output.shape
        grad_input1 = torch.zeros(
            (n, c), dtype=torch.float32, device=grad_output.device
        )
        grad_input2 = torch.zeros_like(grad_input1)
        subtraction_backward_cuda(
            n, nsample, c, idx, grad_output, grad_input1, grad_input2
        )
        return grad_input1, grad_input2, None
|
|
|
|
|
# Functional entry point: ``subtraction(input1, input2, idx)`` invokes the
# ``Subtraction`` autograd Function (forward now, backward on ``.backward()``).
subtraction = Subtraction.apply
|
|