import torch
from torch.autograd import Function

from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda


class Subtraction(Function):
    @staticmethod
    def forward(ctx, input1, input2, idx):
        """
        input: input1: (n, c), input2: (n, c), idx: (n, nsample)
        output: (n, nsample, c)
        """
        assert input1.is_contiguous() and input2.is_contiguous()
        n, c = input1.shape
        nsample = idx.shape[-1]
        # Allocate the output on the same device as the inputs; the CUDA
        # kernel fills it in place.
        output = torch.zeros(n, nsample, c, device=input1.device)
        subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        # Only the neighbor indices are needed to scatter gradients back.
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (n, nsample, c)
        output: grad_input1: (n, c), grad_input2: (n, c)
        """
        (idx,) = ctx.saved_tensors
        n, nsample, c = grad_output.shape
        grad_input1 = torch.zeros(n, c, device=grad_output.device)
        grad_input2 = torch.zeros(n, c, device=grad_output.device)
        subtraction_backward_cuda(
            n, nsample, c, idx, grad_output, grad_input1, grad_input2
        )
        # forward took three arguments (input1, input2, idx);
        # idx is an integer index tensor and receives no gradient.
        return grad_input1, grad_input2, None


subtraction = Subtraction.apply
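
The docstring shapes suggest the kernel computes, for each point i and each of its nsample neighbors j, the difference input1[i] - input2[idx[i, j]]. The following pure-PyTorch sketch expresses that presumed semantics; it is an illustration rather than part of pointops, and the name subtraction_reference is hypothetical. It can double as a CPU fallback or a correctness check against the CUDA path.

import torch


def subtraction_reference(input1, input2, idx):
    """Sketch of the presumed op: output[i, j, :] = input1[i] - input2[idx[i, j]].

    input1: (n, c), input2: (n, c), idx: (n, nsample) long tensor of row
    indices into input2. Returns (n, nsample, c).
    """
    # input2[idx] gathers each point's neighbors -> (n, nsample, c);
    # broadcasting against input1 (n, 1, c) gives the pairwise differences.
    return input1.unsqueeze(1) - input2[idx]


# Assumed usage, e.g. to sanity-check the CUDA op on small random inputs:
#   out = subtraction(input1.cuda(), input2.cuda(), idx.cuda())
#   torch.allclose(out.cpu(), subtraction_reference(input1, input2, idx))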