import torch
from torch.autograd import Function
from pointops._C import (
attention_relation_step_forward_cuda,
attention_relation_step_backward_cuda,
attention_fusion_step_forward_cuda,
attention_fusion_step_backward_cuda,
)
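
# The *_cuda symbols above come from the compiled pointops CUDA extension; each
# forward/backward pair backs one of the custom autograd Functions defined below.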


class AttentionRelationStep(Function):
    @staticmethod
    def forward(ctx, query, key, weight, index_target, index_refer):
        """
        input - query: (n, g, c), key: (n, g, c), weight: (c) - 1_c for scatter attention,
                index_target: (m), index_refer: (m)
        output - relation: (m, g)
        """
        assert (
            query.is_contiguous()
            and key.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
            and weight.is_contiguous()
        )
        assert index_target.shape[0] == index_refer.shape[0]
        _, g, c = query.shape
        m = index_target.shape[0]
        output = torch.cuda.FloatTensor(m, g).zero_()
        attention_relation_step_forward_cuda(
            m, g, c, query, key, weight, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(query, key, weight, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input - grad_output: (m, g)
        output - grad_query: (n, g, c), grad_key: (n, g, c), None, None, None
        """
        query, key, weight, index_target, index_refer = ctx.saved_tensors
        n, g, c = query.shape
        m = index_target.shape[0]
        grad_query = torch.cuda.FloatTensor(n, g, c).zero_()
        grad_key = torch.cuda.FloatTensor(n, g, c).zero_()
        grad_weight = torch.cuda.FloatTensor(c).zero_()
        attention_relation_step_backward_cuda(
            m,
            g,
            c,
            query,
            grad_query,
            key,
            grad_key,
            weight,
            grad_weight,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        return grad_query, grad_key, None, None, None
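

# Pure-PyTorch reference sketch (an assumption for sanity checking, not part of the
# compiled extension): this assumes the relation kernel computes a weight-scaled,
# per-group dot product over the indexed query/key pairs,
#   relation[i, g] = sum_c query[index_target[i], g, c] * key[index_refer[i], g, c] * weight[c],
# which matches the shapes documented in AttentionRelationStep.forward.
def _attention_relation_step_reference(query, key, weight, index_target, index_refer):
    # query/key: (n, g, c), weight: (c), index_target/index_refer: (m) -> (m, g)
    q = query[index_target.long()]  # (m, g, c) features of the target points
    k = key[index_refer.long()]  # (m, g, c) features of the referred points
    return (q * k * weight.view(1, 1, -1)).sum(-1)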


class AttentionFusionStep(Function):
    @staticmethod
    def forward(ctx, weight, value, index_target, index_refer):
        """
        input - weight: (m, g), value: (n, g, c),
                index_target: (m), index_refer: (m)
        output - output: (n, g, c)
        """
        assert (
            weight.is_contiguous()
            and value.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
        )
        assert index_target.shape[0] == index_refer.shape[0]
        n, g, c = value.shape
        m = index_refer.shape[0]
        output = torch.cuda.FloatTensor(n, g, c).zero_()
        attention_fusion_step_forward_cuda(
            m, g, c, weight, value, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(weight, value, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input - grad_output: (n, g, c)
        output - grad_weight: (m, g), grad_value: (n, g, c), None, None
        """
        weight, value, index_target, index_refer = ctx.saved_tensors
        n, g, c = value.shape
        m = index_target.shape[0]
        grad_weight = torch.cuda.FloatTensor(m, g).zero_()
        grad_value = torch.cuda.FloatTensor(n, g, c).zero_()
        attention_fusion_step_backward_cuda(
            m,
            g,
            c,
            weight,
            grad_weight,
            value,
            grad_value,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        return grad_weight, grad_value, None, None
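

# Pure-PyTorch reference sketch (an assumption for sanity checking, not part of the
# compiled extension): this assumes the fusion kernel scatter-adds weighted values
# onto the target points,
#   output[index_target[i], g, c] += weight[i, g] * value[index_refer[i], g, c],
# which matches the shapes documented in AttentionFusionStep.forward.
def _attention_fusion_step_reference(weight, value, index_target, index_refer):
    # weight: (m, g), value: (n, g, c), index_target/index_refer: (m) -> (n, g, c)
    n, g, c = value.shape
    weighted = weight.unsqueeze(-1) * value[index_refer.long()]  # (m, g, c)
    output = torch.zeros(n, g, c, dtype=value.dtype, device=value.device)
    output.index_add_(0, index_target.long(), weighted)  # scatter-add per target index
    return output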


attention_relation_step = AttentionRelationStep.apply
attention_fusion_step = AttentionFusionStep.apply
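

# Minimal usage sketch (hypothetical shapes and random indices; requires a CUDA
# device and the compiled pointops extension). The relation step scores indexed
# query/key pairs and the fusion step scatters the weighted values back onto the
# target points; the plain softmax below is illustrative only.
if __name__ == "__main__":
    assert torch.cuda.is_available(), "pointops kernels require a CUDA device"
    n, g, c, m = 8, 4, 16, 32
    query = torch.randn(n, g, c, device="cuda", requires_grad=True)
    key = torch.randn(n, g, c, device="cuda", requires_grad=True)
    value = torch.randn(n, g, c, device="cuda", requires_grad=True)
    weight = torch.ones(c, device="cuda")  # 1_c, as described in the forward docstring
    index_target = torch.randint(0, n, (m,), device="cuda")
    index_refer = torch.randint(0, n, (m,), device="cuda")

    relation = attention_relation_step(query, key, weight, index_target, index_refer)  # (m, g)
    attn = torch.softmax(relation, dim=0)  # illustrative normalization only
    output = attention_fusion_step(attn, value, index_target, index_refer)  # (n, g, c)
    output.sum().backward()  # exercises both backward kernels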