File size: 4,655 Bytes
db6a3b7
 
 
 
690b53e
db6a3b7
a6bbecf
db6a3b7
a6bbecf
 
 
 
 
 
 
 
 
 
 
db6a3b7
a6bbecf
db6a3b7
690b53e
a6bbecf
690b53e
a6bbecf
690b53e
db6a3b7
a6bbecf
 
 
 
 
 
 
 
 
db6a3b7
a6bbecf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db6a3b7
 
 
 
 
 
 
 
 
 
 
a6bbecf
 
 
db6a3b7
 
 
 
 
 
a6bbecf
 
 
db6a3b7
 
 
 
 
a6bbecf
 
 
 
 
db6a3b7
 
 
 
a6bbecf
 
 
 
 
 
 
 
 
 
db6a3b7
a6bbecf
db6a3b7
a6bbecf
 
 
 
 
 
 
 
db6a3b7
 
 
 
 
a6bbecf
 
db6a3b7
 
a6bbecf
 
 
db6a3b7
 
 
 
 
 
 
a6bbecf
 
 
db6a3b7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import torch.nn as nn
from .. import SparseTensor
from .. import DEBUG
from . import SPCONV_ALGO


class SparseConv3d(nn.Module):
    """Sparse 3D convolution that wraps an spconv layer.

    A ``spconv.SubMConv3d`` is built when ``stride == 1`` and ``padding`` is
    ``None``; in every other case a ``spconv.SparseConv3d`` is built.

    Args:
        in_channels: Forwarded to the spconv layer.
        out_channels: Forwarded to the spconv layer.
        kernel_size: Forwarded to the spconv layer.
        stride: Scalar or 3-element list/tuple. Stored on the module as a
            3-tuple in ``self.stride``.
        dilation: Forwarded to the spconv layer.
        padding: ``None`` selects the ``SubMConv3d`` path (together with
            ``stride == 1``); any other value is forwarded to
            ``spconv.SparseConv3d``.
        bias: Forwarded to the spconv layer.
        indice_key: Forwarded to the spconv layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        padding=None,
        bias=True,
        indice_key=None,
    ):
        super(SparseConv3d, self).__init__()
        # Import lazily so the module can be imported without spconv installed.
        # The import must be unconditional: it binds a *local* name, so
        # skipping it (as a `"spconv" in globals()` guard would) leaves that
        # local unbound and the references below raise UnboundLocalError.
        import spconv.pytorch as spconv

        # None lets spconv pick its default algorithm.
        algo = None
        if SPCONV_ALGO == "native":
            algo = spconv.ConvAlgo.Native
        elif SPCONV_ALGO == "implicit_gemm":
            algo = spconv.ConvAlgo.MaskImplicitGemm

        # NOTE(review): a tuple stride such as (1, 1, 1) fails `stride == 1`
        # and therefore takes the SparseConv3d branch — confirm this is
        # intended.
        if stride == 1 and (padding is None):
            self.conv = spconv.SubMConv3d(
                in_channels,
                out_channels,
                kernel_size,
                dilation=dilation,
                bias=bias,
                indice_key=indice_key,
                algo=algo,
            )
        else:
            self.conv = spconv.SparseConv3d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                bias=bias,
                indice_key=indice_key,
                algo=algo,
            )
        self.stride = (
            tuple(stride)
            if isinstance(stride, (list, tuple))
            else (stride, stride, stride)
        )
        self.padding = padding

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the convolution to ``x`` and wrap the result in a SparseTensor.

        When the layer changes the spatial layout (any stride != 1, or an
        explicit padding) and ``x.shape[0] != 1``, the output rows are
        re-sorted by the first index column. The unsorted spconv tensor and
        the inverse permutation are then registered in the output's spatial
        cache under ``conv_{stride}_unsorted_data`` and
        ``conv_{stride}_sort_bwd``.

        Args:
            x: Input sparse tensor. ``x.shape[0]`` is presumably the batch
                size — verify against ``SparseTensor``.

        Returns:
            A new ``SparseTensor`` whose scale is ``x._scale`` multiplied
            element-wise by ``self.stride`` and which shares
            ``x._spatial_cache``.
        """
        spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
        needs_sort = spatial_changed and (x.shape[0] != 1)

        new_data = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        # The old layout no longer describes the output once coordinates change.
        new_layout = None if spatial_changed else x.layout

        if needs_sort:
            # `spconv` is not a module-level name in this file, so import it
            # here; without this the SparseConvTensor construction below
            # raises NameError.
            import spconv.pytorch as spconv

            # spconv does not keep the output rows grouped contiguously when
            # the spatial layout changes, so sort them by the first index
            # column (presumably the batch index — TODO confirm).
            fwd = new_data.indices[:, 0].argsort()
            # bwd is the inverse permutation of fwd: sorted[bwd] == unsorted.
            bwd = torch.zeros_like(fwd).scatter_(
                0, fwd, torch.arange(fwd.shape[0], device=fwd.device)
            )
            unsorted_data = new_data
            new_data = spconv.SparseConvTensor(
                unsorted_data.features[fwd],
                unsorted_data.indices[fwd],
                unsorted_data.spatial_shape,
                unsorted_data.batch_size,
            )

        out = SparseTensor(
            new_data,
            shape=torch.Size(new_shape),
            layout=new_layout,
            scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )

        if needs_sort:
            # Keep what is needed to restore spconv's original row order later
            # (e.g. for an inverse convolution using the same indice_key).
            out.register_spatial_cache(
                f"conv_{self.stride}_unsorted_data", unsorted_data
            )
            out.register_spatial_cache(f"conv_{self.stride}_sort_bwd", bwd)

        return out


class SparseInverseConv3d(nn.Module):
    """Sparse 3D inverse convolution that wraps ``spconv.SparseInverseConv3d``.

    Args:
        in_channels: Forwarded to the spconv layer.
        out_channels: Forwarded to the spconv layer.
        kernel_size: Forwarded to the spconv layer.
        stride: Scalar or 3-element list/tuple. It is NOT forwarded to
            spconv; it is only stored (as a 3-tuple) to build the spatial
            cache keys and to rescale the output in ``forward``.
        dilation: Accepted for signature parity but not used by this class.
        bias: Forwarded to the spconv layer.
        indice_key: Forwarded to the spconv layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        indice_key=None,
    ):
        super(SparseInverseConv3d, self).__init__()
        # Import lazily so the module can be imported without spconv installed.
        # The import must be unconditional: it binds a *local* name, so
        # skipping it (as a `"spconv" in globals()` guard would) leaves that
        # local unbound and the reference below raises UnboundLocalError.
        import spconv.pytorch as spconv

        self.conv = spconv.SparseInverseConv3d(
            in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key
        )
        self.stride = (
            tuple(stride)
            if isinstance(stride, (list, tuple))
            else (stride, stride, stride)
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the inverse convolution to ``x``.

        When any stride component is not 1, the features of ``x`` are first
        put back into the row order stored under the spatial cache keys
        ``conv_{stride}_unsorted_data`` / ``conv_{stride}_sort_bwd``. If
        either entry is missing, ``x.data`` is used unchanged.

        Args:
            x: Input sparse tensor.

        Returns:
            A new ``SparseTensor`` whose scale is ``x._scale`` floor-divided
            element-wise by ``self.stride`` and which shares
            ``x._spatial_cache``.
        """
        spatial_changed = any(s != 1 for s in self.stride)

        data = None
        if spatial_changed:
            # Recover the original spconv row order.
            unsorted_data = x.get_spatial_cache(f"conv_{self.stride}_unsorted_data")
            bwd = x.get_spatial_cache(f"conv_{self.stride}_sort_bwd")
            # SparseConv3d.forward only registers these entries when
            # x.shape[0] != 1, so they can be absent. This assumes
            # get_spatial_cache returns None on a miss — TODO confirm.
            if unsorted_data is not None and bwd is not None:
                data = unsorted_data.replace_feature(x.feats[bwd])
                if DEBUG:
                    assert torch.equal(
                        data.indices, x.coords[bwd]
                    ), "Recover the original order failed"
        if data is None:
            # Nothing was reordered (or nothing was cached): use x as-is.
            data = x.data

        new_data = self.conv(data)
        new_shape = [x.shape[0], self.conv.out_channels]
        # The old layout no longer describes the output once coordinates change.
        new_layout = None if spatial_changed else x.layout
        out = SparseTensor(
            new_data,
            shape=torch.Size(new_shape),
            layout=new_layout,
            scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )
        return out