# Last commit: 57746f1 "Update Code" (kairunwen); original file size 1.23 kB.
# (Header recovered from a web file-viewer page; its UI labels were dropped.)
import torch
from torch.autograd import Function
from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda
class Subtraction(Function):
    """Autograd wrapper around the pointops ``subtraction`` CUDA kernels.

    Shape contract:
        input1: (n, c), input2: (n, c), idx: (n, nsample)  ->  output: (n, nsample, c)

    The element-wise formula lives in ``subtraction_forward_cuda`` and is not
    visible here (presumably ``input1[i] - input2[idx[i, j]]`` -- TODO confirm
    against the CUDA source). Output and gradients are float32.
    """

    @staticmethod
    def forward(ctx, input1, input2, idx):
        """
        input: input1: (n, c), input2: (n, c), idx: (n, nsample)
        output: (n, nsample, c)

        Raises AssertionError if input1 or input2 is not contiguous.
        """
        # The kernel works on raw buffers, so it needs dense inputs.
        assert input1.is_contiguous() and input2.is_contiguous()
        # Same requirement for idx; .contiguous() is a no-op when already dense.
        idx = idx.contiguous()
        n, c = input1.shape
        nsample = idx.shape[-1]
        device = input1.device
        # Allocate on the inputs' device. The legacy torch.cuda.FloatTensor
        # constructor always used the *current* device, which may differ.
        output = torch.zeros((n, nsample, c), dtype=torch.float32, device=device)
        # Make the tensors' device current for the kernel launch.
        with torch.cuda.device(device):
            subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        # Only idx is needed to scatter gradients back in backward().
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, nsample, c)
        output: grad_input1: (n, c), grad_input2: (n, c)

        idx is an integer index tensor and receives no gradient (None).
        """
        (idx,) = ctx.saved_tensors
        # Autograd may pass a non-contiguous gradient (e.g. from a downstream
        # expand/permute); the kernel reads raw memory, so densify it first.
        grad_output = grad_output.contiguous()
        n, nsample, c = grad_output.shape
        device = grad_output.device
        grad_input1 = torch.zeros((n, c), dtype=torch.float32, device=device)
        grad_input2 = torch.zeros((n, c), dtype=torch.float32, device=device)
        with torch.cuda.device(device):
            subtraction_backward_cuda(
                n, nsample, c, idx, grad_output, grad_input1, grad_input2
            )
        return grad_input1, grad_input2, None
# Functional entry point: subtraction(input1, input2, idx) runs Subtraction
# through autograd, so gradients flow back to input1 and input2.
subtraction = Subtraction.apply