import sys
from collections import OrderedDict

import torch.nn as nn
import spconv.pytorch as spconv

from pointcept.models.utils.structure import Point


class PointModule(nn.Module):
    r"""Marker base class for modules that consume a ``Point`` directly.

    ``PointSequential.forward`` checks ``isinstance(module, PointModule)``
    and, when true, hands the whole input object to the module instead of
    unpacking its features first.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class PointSequential(PointModule):
    r"""A sequential container.

    Modules are added in the order they are passed to the constructor.
    Alternatively, a single ``OrderedDict`` of ``name -> module`` can be
    passed, or modules can be given as keyword arguments (the keyword is
    used as the module name).
    """

    def __init__(self, *args, **kwargs):
        """Register the given modules in order.

        Args:
            *args: Either a single ``OrderedDict`` mapping names to modules,
                or any number of modules, which are named ``"0"``, ``"1"``, ...
            **kwargs: Extra modules registered under their keyword name,
                after the positional ones.

        Raises:
            ValueError: If keyword modules are used on Python < 3.6, or a
                keyword name collides with an already registered module.
        """
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

        # Checked once up front instead of on every kwarg; the outcome is the
        # same because the original raised before adding the first kwarg.
        if kwargs and sys.version_info < (3, 6):
            raise ValueError("kwargs only supported in py36+")
        for name, module in kwargs.items():
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        """Return the module at position ``idx`` (negative indices allowed).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        # ``_modules`` preserves registration order, so positional indexing
        # into its values (including negative indices) selects the module.
        return list(self._modules.values())[idx]

    def __len__(self):
        """Return the number of registered modules."""
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module`` to the container.

        Args:
            module: The module to register.
            name: Name to register it under; defaults to the current number
                of registered modules, as a string.

        Raises:
            KeyError: If ``name`` is already registered.
        """
        if name is None:
            name = str(len(self._modules))
        if name in self._modules:
            raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        """Run ``input`` through every registered module in order.

        How each module is applied depends on the module type and on the
        current type of ``input``:

        * ``PointModule``: called with ``input`` as-is.
        * spconv module: for a ``Point``, applied to
          ``input.sparse_conv_feat`` and ``input.feat`` is refreshed from the
          result's ``features``; otherwise called with ``input`` directly.
        * any other module: for a ``Point``, applied to ``input.feat`` (and
          ``sparse_conv_feat`` is updated via ``replace_feature`` when that
          key is present); for a ``SparseConvTensor``, applied to its
          ``features`` unless the tensor has zero index rows; otherwise
          called with ``input`` directly.

        Returns:
            The result after the last module (the same ``Point`` object,
            mutated in place, when a ``Point`` was passed through non-
            ``PointModule`` layers).
        """
        for module in self._modules.values():
            # Point module
            if isinstance(module, PointModule):
                input = module(input)
            # Spconv module
            elif spconv.modules.is_spconv_module(module):
                if isinstance(input, Point):
                    input.sparse_conv_feat = module(input.sparse_conv_feat)
                    # Keep the dense feature view in sync with the sparse one.
                    input.feat = input.sparse_conv_feat.features
                else:
                    input = module(input)
            # PyTorch module
            else:
                if isinstance(input, Point):
                    input.feat = module(input.feat)
                    if "sparse_conv_feat" in input.keys():
                        # Mirror the new features into the sparse tensor.
                        input.sparse_conv_feat = (
                            input.sparse_conv_feat.replace_feature(input.feat)
                        )
                elif isinstance(input, spconv.SparseConvTensor):
                    # Skip the module entirely when there are no index rows.
                    if input.indices.shape[0] != 0:
                        input = input.replace_feature(module(input.features))
                else:
                    input = module(input)
        return input