import torch
from torch.autograd import Function
from pointops._C import aggregation_forward_cuda, aggregation_backward_cuda


class Aggregation(Function):
    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """
        input: input (n, c), position (n, nsample, c), weight (n, nsample, c'), idx (n, nsample)
        output: output (n, c)
        """
        assert (
            input.is_contiguous()
            and position.is_contiguous()
            and weight.is_contiguous()
        )
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        # Zero-filled output buffer on the input's device for the kernel to write into.
        output = torch.zeros(n, c, device=input.device)
        aggregation_forward_cuda(
            n, nsample, c, w_c, input, position, weight, idx, output
        )
        # Save the inputs so backward can compute the per-tensor gradients.
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output (n, c)
        output: grad_input (n, c), grad_position (n, nsample, c), grad_weight (n, nsample, c')
        """
        input, position, weight, idx = ctx.saved_tensors
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        # Autograd may hand us a non-contiguous grad_output; the forward pass asserts
        # contiguity, so the kernel presumably expects contiguous memory here as well.
        grad_output = grad_output.contiguous()
        # Zero-filled gradient buffers for the kernel to accumulate into.
        grad_input = torch.zeros(n, c, device=input.device)
        grad_position = torch.zeros(n, nsample, c, device=input.device)
        grad_weight = torch.zeros(n, nsample, w_c, device=input.device)
        aggregation_backward_cuda(
            n,
            nsample,
            c,
            w_c,
            input,
            position,
            weight,
            idx,
            grad_output,
            grad_input,
            grad_position,
            grad_weight,
        )
        # idx holds integer indices, so it gets no gradient (the trailing None).
        return grad_input, grad_position, grad_weight, None


aggregation = Aggregation.apply
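

# Example usage: a minimal sketch, assuming the pointops CUDA extension is built
# and a CUDA device is available. idx is assumed to be an int32 neighbor-index
# tensor (as produced elsewhere in pointops), and c is assumed divisible by the
# weight channel count w_c; both are assumptions, not verified here.
if __name__ == "__main__":
    if torch.cuda.is_available():
        n, nsample, c, w_c = 1024, 16, 32, 8
        input = torch.randn(n, c, device="cuda", requires_grad=True)
        position = torch.randn(n, nsample, c, device="cuda", requires_grad=True)
        weight = torch.randn(n, nsample, w_c, device="cuda", requires_grad=True)
        idx = torch.randint(0, n, (n, nsample), device="cuda", dtype=torch.int32)
        output = aggregation(input, position, weight, idx)  # (n, c)
        output.sum().backward()  # exercises aggregation_backward_cuda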