"""
ScanNet20 / ScanNet200 / ScanNet Data Efficient Dataset

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import os
import glob
import numpy as np
import torch
from copy import deepcopy
from torch.utils.data import Dataset
from collections.abc import Sequence

from pointcept.utils.logger import get_root_logger
from pointcept.utils.cache import shared_dict

from .builder import DATASETS
from .defaults import DefaultDataset
from .transform import Compose, TRANSFORMS
from .preprocessing.scannet.meta_data.scannet200_constants import (
    VALID_CLASS_IDS_20,
    VALID_CLASS_IDS_200,
)


@DATASETS.register_module()
class ScanNetDataset(DefaultDataset):
    """ScanNet (20-class) scene dataset with optional data-efficient subsets.

    Each scene is a directory of ``<asset>.npy`` files. Only the assets named
    in ``VALID_ASSETS`` are loaded.

    Args:
        lr_file: optional text file of scene names. When given, it replaces
            the default data list, and every name is resolved under
            ``<data_root>/train``. (Presumably the "limited reconstructions"
            data-efficient split — confirm.)
        la_file: optional file read with ``torch.load``. It must map a scene
            name to the indices of points whose labels are kept. All other
            points get ``ignore_index``. (Presumably the "limited
            annotations" split — confirm.)
        **kwargs: forwarded to ``DefaultDataset``.
    """

    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "segment20",
        "instance",
    ]
    class2id = np.array(VALID_CLASS_IDS_20)

    def __init__(
        self,
        lr_file=None,
        la_file=None,
        **kwargs,
    ):
        # np.atleast_1d: for a single-line file np.loadtxt returns a 0-d
        # array, which cannot be iterated in get_data_list.
        self.lr = (
            np.atleast_1d(np.loadtxt(lr_file, dtype=str))
            if lr_file is not None
            else None
        )
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # point la_file at trusted files.
        self.la = torch.load(la_file) if la_file is not None else None
        # lr/la are set before super().__init__, which presumably builds the
        # data list via get_data_list (it reads self.lr).
        super().__init__(**kwargs)

    def get_data_list(self):
        """Return scene directories, restricted to ``lr_file`` when given."""
        if self.lr is None:
            return super().get_data_list()
        return [os.path.join(self.data_root, "train", name) for name in self.lr]

    @staticmethod
    def _flat_int32(array):
        """Flatten a label array and cast it to int32."""
        return array.reshape([-1]).astype(np.int32)

    def get_data(self, idx):
        """Load one scene as a dict of numpy arrays.

        If ``self.cache`` is set, the shared cached dict is returned as-is.
        Otherwise the valid ``.npy`` assets are loaded. ``coord`` (required),
        ``color`` and ``normal`` are cast to float32. The labels become flat
        int32 ``segment`` and ``instance`` arrays, filled with -1 when the
        scene has no such asset. With ``la_file``, the labels of unsampled
        points are replaced by ``ignore_index``.
        """
        data_path = self.data_list[idx % len(self.data_list)]
        name = self.get_data_name(idx)
        if self.cache:
            cache_name = f"pointcept-{name}"
            return shared_dict(cache_name)

        data_dict = {}
        for asset in os.listdir(data_path):
            if not asset.endswith(".npy"):
                continue
            key = asset[:-4]
            if key not in self.VALID_ASSETS:
                continue
            data_dict[key] = np.load(os.path.join(data_path, asset))
        data_dict["name"] = name

        # coord is required (it sizes the fallback labels below). color and
        # normal are cast only when present, so a scene without them does
        # not raise KeyError.
        data_dict["coord"] = data_dict["coord"].astype(np.float32)
        for key in ("color", "normal"):
            if key in data_dict:
                data_dict[key] = data_dict[key].astype(np.float32)

        num_points = data_dict["coord"].shape[0]
        if "segment20" in data_dict:
            data_dict["segment"] = self._flat_int32(data_dict.pop("segment20"))
        elif "segment200" in data_dict:
            data_dict["segment"] = self._flat_int32(data_dict.pop("segment200"))
        else:
            data_dict["segment"] = np.full(num_points, -1, dtype=np.int32)

        if "instance" in data_dict:
            data_dict["instance"] = self._flat_int32(data_dict.pop("instance"))
        else:
            data_dict["instance"] = np.full(num_points, -1, dtype=np.int32)

        if self.la:
            # Keep labels only at the sampled points and ignore the rest.
            sampled_index = self.la[name]
            mask = np.ones_like(data_dict["segment"], dtype=bool)
            mask[sampled_index] = False
            data_dict["segment"][mask] = self.ignore_index
            data_dict["sampled_index"] = sampled_index
        return data_dict


@DATASETS.register_module()
class ScanNet200Dataset(ScanNetDataset):
    """ScanNet200 variant that loads ``segment200`` labels instead of ``segment20``."""

    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "segment200",
        "instance",
    ]
    class2id = np.array(VALID_CLASS_IDS_200)