import torch
import torch.nn as nn

from .. import SparseTensor
from .. import DEBUG
from . import SPCONV_ALGO


class SparseConv3d(nn.Module):
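    """3D sparse convolution on a SparseTensor, backed by spconv.

    With ``stride == 1`` and no explicit padding this wraps a submanifold
    convolution (``spconv.SubMConv3d``), which leaves the set of active
    voxels unchanged; otherwise it wraps a regular ``spconv.SparseConv3d``.
    """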

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        padding=None,
        bias=True,
        indice_key=None,
    ):
        super(SparseConv3d, self).__init__()
        if "spconv" not in globals():
            # Bind the lazy import at module scope: a plain local import
            # would only be visible inside __init__, but forward() below
            # also needs spconv.
            global spconv
            import spconv.pytorch as spconv
        # Select the spconv kernel algorithm from the package-level setting.
        algo = None
        if SPCONV_ALGO == "native":
            algo = spconv.ConvAlgo.Native
        elif SPCONV_ALGO == "implicit_gemm":
            algo = spconv.ConvAlgo.MaskImplicitGemm
        if stride == 1 and (padding is None):
            # Submanifold convolution: preserves the active-voxel layout.
            self.conv = spconv.SubMConv3d(
                in_channels,
                out_channels,
                kernel_size,
                dilation=dilation,
                bias=bias,
                indice_key=indice_key,
                algo=algo,
            )
        else:
            # Strided or padded case: a regular sparse convolution, which
            # can change the set of active voxels.
            self.conv = spconv.SparseConv3d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                bias=bias,
                indice_key=indice_key,
                algo=algo,
            )
        self.stride = (
            tuple(stride)
            if isinstance(stride, (list, tuple))
            else (stride, stride, stride)
        )
        self.padding = padding

    def forward(self, x: SparseTensor) -> SparseTensor:
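        """Apply the convolution; if it changes the spatial layout, re-sort
        the output by batch index and cache what SparseInverseConv3d needs
        to undo the re-sort."""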
        spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
        new_data = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        new_layout = None if spatial_changed else x.layout

        if spatial_changed and (x.shape[0] != 1):
            # A strided spconv convolution does not keep the output voxels
            # grouped by batch index, so sort them back into batch order.
            # `fwd` is the sorting permutation and `bwd` its inverse.
            fwd = new_data.indices[:, 0].argsort()
            bwd = torch.zeros_like(fwd).scatter_(
                0, fwd, torch.arange(fwd.shape[0], device=fwd.device)
            )
            sorted_feats = new_data.features[fwd]
            sorted_coords = new_data.indices[fwd]
            unsorted_data = new_data
            new_data = spconv.SparseConvTensor(
                sorted_feats,
                sorted_coords,
                unsorted_data.spatial_shape,
                unsorted_data.batch_size,
            )

        out = SparseTensor(
            new_data,
            shape=torch.Size(new_shape),
            layout=new_layout,
            scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )

        if spatial_changed and (x.shape[0] != 1):
            # Cache the pre-sort tensor and the inverse permutation so the
            # matching SparseInverseConv3d can restore spconv's ordering.
            out.register_spatial_cache(
                f"conv_{self.stride}_unsorted_data", unsorted_data
            )
            out.register_spatial_cache(f"conv_{self.stride}_sort_bwd", bwd)

        return out


class SparseInverseConv3d(nn.Module):
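    """Transposed counterpart of a strided SparseConv3d.

    Wraps ``spconv.SparseInverseConv3d``; ``indice_key`` must name the
    matching forward convolution so spconv can invert its index mapping.
    """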

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        indice_key=None,
    ):
        super(SparseInverseConv3d, self).__init__()
        if "spconv" not in globals():
            # Same module-scoped lazy import as in SparseConv3d.
            global spconv
            import spconv.pytorch as spconv
        self.conv = spconv.SparseInverseConv3d(
            in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key
        )
        self.stride = (
            tuple(stride)
            if isinstance(stride, (list, tuple))
            else (stride, stride, stride)
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
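        """Restore spconv's original voxel ordering from the cache written
        by the matching SparseConv3d, then apply the inverse convolution."""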
        spatial_changed = any(s != 1 for s in self.stride)
        if spatial_changed:
            # Recover the ordering spconv produced before SparseConv3d
            # re-sorted by batch index: take the cached unsorted tensor and
            # scatter the current features back with the inverse permutation.
            data = x.get_spatial_cache(f"conv_{self.stride}_unsorted_data")
            bwd = x.get_spatial_cache(f"conv_{self.stride}_sort_bwd")
            data = data.replace_feature(x.feats[bwd])
            if DEBUG:
                assert torch.equal(
                    data.indices, x.coords[bwd]
                ), "Failed to recover the original order"
        else:
            data = x.data

        new_data = self.conv(data)
        new_shape = [x.shape[0], self.conv.out_channels]
        new_layout = None if spatial_changed else x.layout
        out = SparseTensor(
            new_data,
            shape=torch.Size(new_shape),
            layout=new_layout,
            scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )
        return out
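

# Minimal usage sketch (illustrative only; assumes a SparseTensor `x` built
# elsewhere in this package from features and integer voxel coordinates):
#
#   down = SparseConv3d(32, 64, kernel_size=3, stride=2, indice_key="enc0")
#   up = SparseInverseConv3d(64, 32, kernel_size=3, stride=2, indice_key="enc0")
#   y = down(x)   # downsamples; multiplies the tensor's scale by the stride
#   z = up(y)     # restores x's resolution via the shared indice_key
#
# stride and indice_key must match between the pair so the inverse conv can
# find both spconv's index data and this module's spatial-cache entries.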