|
import sys |
|
import torch.nn as nn |
|
import spconv.pytorch as spconv |
|
from collections import OrderedDict |
|
from pointcept.models.utils.structure import Point |
|
|
|
|
|
class PointModule(nn.Module):
    r"""Marker base class for modules that consume a ``Point`` directly.

    It adds no behaviour on top of ``nn.Module``. ``PointSequential`` checks
    for this type and, for such children, passes its current input through
    unchanged instead of unpacking features first.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through so subclasses stay free to forward arbitrary
        # constructor arguments along the MRO.
        super().__init__(*args, **kwargs)
|
|
|
|
|
class PointSequential(PointModule):
    r"""A sequential container that dispatches on the kind of each child.

    Modules are registered in the order they are passed to the constructor.
    Alternatively, a single ``OrderedDict`` of ``name -> module`` can be
    passed, and extra named modules can be supplied as keyword arguments.

    In ``forward`` each child is applied according to its type:

    * ``PointModule`` children receive the current input as-is.
    * spconv modules (per ``spconv.modules.is_spconv_module``) are applied to
      ``input.sparse_conv_feat`` when the input is a ``Point``, otherwise to
      the input itself.
    * Any other module is applied to the feature tensor (``input.feat`` for a
      ``Point``, ``input.features`` for a ``spconv.SparseConvTensor``), or to
      the input itself for any other input type.
    """

    def __init__(self, *args, **kwargs):
        """Register the given modules.

        Args:
            *args: Either one ``OrderedDict`` mapping names to modules, or any
                number of modules, which are named ``"0"``, ``"1"``, ...
            **kwargs: Additional ``name=module`` pairs, registered after the
                positional ones.

        Raises:
            ValueError: If a keyword name is already registered, or if
                keyword modules are used on Python older than 3.6.
        """
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

        # The interpreter version cannot change between kwargs, so check it
        # once up front instead of once per item.
        if kwargs and sys.version_info < (3, 6):
            raise ValueError("kwargs only supported in py36+")
        for name, module in kwargs.items():
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        """Return the child at position ``idx``, or a sub-container for a slice.

        Args:
            idx: An integer position (negative values count from the end) or
                a ``slice``.

        Returns:
            The selected module for an integer index; a new
            ``PointSequential`` holding the selected children (with their
            original names) for a slice.

        Raises:
            IndexError: If an integer ``idx`` is outside ``[-len, len)``.
        """
        if isinstance(idx, slice):
            # Built as a plain PointSequential because its constructor is
            # known to accept a single OrderedDict.
            selected = list(self._modules.items())[idx]
            return PointSequential(OrderedDict(selected))

        size = len(self)
        if not (-size <= idx < size):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += size
        return list(self._modules.values())[idx]

    def __len__(self):
        """Return the number of registered child modules."""
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module`` to the container.

        Args:
            module: The module to register.
            name: Name to register it under; defaults to the current number
                of children rendered as a string.

        Raises:
            KeyError: If ``name`` is already registered.
        """
        if name is None:
            name = str(len(self._modules))
        if name in self._modules:
            raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        """Run ``input`` through every child in registration order.

        Args:
            input: A ``Point``, a ``spconv.SparseConvTensor``, or whatever
                the children accept directly.

        Returns:
            The result after the last child. A ``Point`` input is updated in
            place and returned.
        """
        for module in self._modules.values():
            if isinstance(module, PointModule):
                # Point-aware modules get the input untouched.
                input = module(input)
            elif spconv.modules.is_spconv_module(module):
                if isinstance(input, Point):
                    # Run the sparse module on the sparse tensor, then copy
                    # its features back to ``feat`` so both views agree.
                    input.sparse_conv_feat = module(input.sparse_conv_feat)
                    input.feat = input.sparse_conv_feat.features
                else:
                    input = module(input)
            else:
                # Neither Point-aware nor spconv: apply to raw features.
                if isinstance(input, Point):
                    input.feat = module(input.feat)
                    if "sparse_conv_feat" in input.keys():
                        # Push the new features into the sparse tensor too.
                        input.sparse_conv_feat = (
                            input.sparse_conv_feat.replace_feature(input.feat)
                        )
                elif isinstance(input, spconv.SparseConvTensor):
                    # The module is skipped when the tensor has no index rows.
                    if input.indices.shape[0] != 0:
                        input = input.replace_feature(module(input.features))
                else:
                    input = module(input)
        return input
|
|