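"""spconv-backed sparse 3D convolution layers operating on SparseTensor."""
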
import torch
import torch.nn as nn
from .. import SparseTensor
from .. import DEBUG
from . import SPCONV_ALGO


class SparseConv3d(nn.Module):
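    """3D sparse convolution backed by spconv.

    Uses a submanifold convolution (``spconv.SubMConv3d``) when ``stride == 1`` and
    no padding is given, and a regular ``spconv.SparseConv3d`` otherwise. When the
    spatial resolution changes and the batch size is not 1, the output coordinates
    are re-sorted by batch index, and the unsorted tensor plus the inverse
    permutation are cached so a matching ``SparseInverseConv3d`` can undo the sort.
    """
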
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        padding=None,
        bias=True,
        indice_key=None,
    ):
        super(SparseConv3d, self).__init__()
if "spconv" not in globals():
import spconv.pytorch as spconv
algo = None
if SPCONV_ALGO == "native":
algo = spconv.ConvAlgo.Native
elif SPCONV_ALGO == "implicit_gemm":
algo = spconv.ConvAlgo.MaskImplicitGemm
        if stride == 1 and (padding is None):
            # Submanifold convolution: the output keeps the input's active sites.
            self.conv = spconv.SubMConv3d(
                in_channels,
                out_channels,
                kernel_size,
                dilation=dilation,
                bias=bias,
                indice_key=indice_key,
                algo=algo,
            )
        else:
            # Strided / padded sparse convolution: the set of active sites changes.
            self.conv = spconv.SparseConv3d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                bias=bias,
                indice_key=indice_key,
                algo=algo,
            )
        self.stride = (
            tuple(stride)
            if isinstance(stride, (list, tuple))
            else (stride, stride, stride)
        )
        self.padding = padding

    def forward(self, x: SparseTensor) -> SparseTensor:
        spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
        new_data = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        new_layout = None if spatial_changed else x.layout

        if spatial_changed and (x.shape[0] != 1):
            # A non-1 stride in spconv breaks the batch-contiguous ordering of the
            # output, so sort by the batch index in the coords and remember the
            # permutation that undoes the sort.
            fwd = new_data.indices[:, 0].argsort()
            bwd = torch.zeros_like(fwd).scatter_(
                0, fwd, torch.arange(fwd.shape[0], device=fwd.device)
            )
            sorted_feats = new_data.features[fwd]
            sorted_coords = new_data.indices[fwd]
            unsorted_data = new_data
            new_data = spconv.SparseConvTensor(
                sorted_feats,
                sorted_coords,
                unsorted_data.spatial_shape,
                unsorted_data.batch_size,
            )  # type: ignore

        out = SparseTensor(
            new_data,
            shape=torch.Size(new_shape),
            layout=new_layout,
            scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )

        if spatial_changed and (x.shape[0] != 1):
            # Cache the unsorted tensor and the inverse permutation so that a
            # matching SparseInverseConv3d can restore the original spconv order.
            out.register_spatial_cache(
                f"conv_{self.stride}_unsorted_data", unsorted_data
            )
            out.register_spatial_cache(f"conv_{self.stride}_sort_bwd", bwd)

        return out


class SparseInverseConv3d(nn.Module):
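    """Inverse of a strided ``SparseConv3d``, backed by ``spconv.SparseInverseConv3d``.

    When the stride is not 1, the input is first restored to the original spconv
    order using the ``unsorted_data`` and ``sort_bwd`` entries cached by the
    matching forward convolution.
    """
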
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        indice_key=None,
    ):
        super(SparseInverseConv3d, self).__init__()
        # Lazily import spconv, binding it at module scope (same pattern as
        # SparseConv3d).
        global spconv
        if "spconv" not in globals():
            import spconv.pytorch as spconv
        # The inverse convolution reuses the indices registered under indice_key
        # by the corresponding downsampling convolution.
        self.conv = spconv.SparseInverseConv3d(
            in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key
        )
        self.stride = (
            tuple(stride)
            if isinstance(stride, (list, tuple))
            else (stride, stride, stride)
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        spatial_changed = any(s != 1 for s in self.stride)
        if spatial_changed:
            # Recover the original spconv order from the cache written by the
            # matching SparseConv3d, then swap in the current features.
            data = x.get_spatial_cache(f"conv_{self.stride}_unsorted_data")
            bwd = x.get_spatial_cache(f"conv_{self.stride}_sort_bwd")
            data = data.replace_feature(x.feats[bwd])
            if DEBUG:
                assert torch.equal(
                    data.indices, x.coords[bwd]
                ), "Recovering the original order failed"
        else:
            data = x.data

        new_data = self.conv(data)
        new_shape = [x.shape[0], self.conv.out_channels]
        new_layout = None if spatial_changed else x.layout
        out = SparseTensor(
            new_data,
            shape=torch.Size(new_shape),
            layout=new_layout,
            scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )
        return out