import torch
from torch.autograd import Function

from pointops._C import aggregation_forward_cuda, aggregation_backward_cuda
class Aggregation(Function):
    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """
        input: (n, c), position: (n, nsample, c), weight: (n, nsample, c'), idx: (n, nsample)
        output: (n, c)
        """
        assert (
            input.is_contiguous()
            and position.is_contiguous()
            and weight.is_contiguous()
        )
        n, nsample, c = position.shape
        w_c = weight.shape[-1]  # number of weight channels (c' in the docstring)
        # aggregated output features, one per point
        output = torch.cuda.FloatTensor(n, c).zero_()
        aggregation_forward_cuda(
            n, nsample, c, w_c, input, position, weight, idx, output
        )
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (n, c)
        output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, c')
        """
        input, position, weight, idx = ctx.saved_tensors
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        grad_input = torch.cuda.FloatTensor(n, c).zero_()
        grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_()
        grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_()
        aggregation_backward_cuda(
            n,
            nsample,
            c,
            w_c,
            input,
            position,
            weight,
            idx,
            grad_output,
            grad_input,
            grad_position,
            grad_weight,
        )
        # idx is an index tensor and receives no gradient
        return grad_input, grad_position, grad_weight, None


aggregation = Aggregation.apply
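

# A minimal usage sketch, assuming the compiled pointops CUDA extension is
# importable and a CUDA device is present. Tensor shapes follow the docstrings
# above; the sizes, the random data, and the int32 dtype chosen for idx are
# illustrative assumptions, not guarantees about the library.
if __name__ == "__main__" and torch.cuda.is_available():
    n, nsample, c, w_c = 1024, 16, 32, 8  # hypothetical sizes; w_c plays the role of c'
    feat = torch.randn(n, c, device="cuda").contiguous()
    position = torch.randn(n, nsample, c, device="cuda").contiguous()
    weight = torch.randn(n, nsample, w_c, device="cuda").contiguous()
    idx = torch.randint(0, n, (n, nsample), device="cuda", dtype=torch.int32)
    out = aggregation(feat, position, weight, idx)
    print(out.shape)  # expected: torch.Size([1024, 32])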