kairunwen's picture
Update Code
57746f1
raw
history blame contribute delete
1.76 kB
import torch
from torch.autograd import Function
from pointops._C import aggregation_forward_cuda, aggregation_backward_cuda
class Aggregation(Function):
@staticmethod
def forward(ctx, input, position, weight, idx):
"""
input: input: (n, c), position: (n, nsample, c), weight : (n, nsample, c'), idx: (n, nsample)
output: (n, c)
"""
assert (
input.is_contiguous()
and position.is_contiguous()
and weight.is_contiguous()
)
n, nsample, c = position.shape
w_c = weight.shape[-1]
output = torch.cuda.FloatTensor(n, c).zero_()
aggregation_forward_cuda(
n, nsample, c, w_c, input, position, weight, idx, output
)
ctx.save_for_backward(input, position, weight, idx)
return output
@staticmethod
def backward(ctx, grad_output):
"""
input: grad_out: (n, c)
output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight : (n, nsample, c')
"""
input, position, weight, idx = ctx.saved_tensors
n, nsample, c = position.shape
w_c = weight.shape[-1]
grad_input = torch.cuda.FloatTensor(n, c).zero_()
grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_()
grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_()
aggregation_backward_cuda(
n,
nsample,
c,
w_c,
input,
position,
weight,
idx,
grad_output,
grad_input,
grad_position,
grad_weight,
)
return grad_input, grad_position, grad_weight, None
aggregation = Aggregation.apply