|
""" |
|
ArkitScenes Dataset |
|
|
|
Author: Xiaoyang Wu ([email protected]) |
|
Please cite our work if the code is helpful to you. |
|
""" |
|
|
|
import os |
|
import glob |
|
import numpy as np |
|
import torch |
|
from copy import deepcopy |
|
from torch.utils.data import Dataset |
|
|
|
from pointcept.utils.logger import get_root_logger |
|
from .builder import DATASETS |
|
from .transform import Compose, TRANSFORMS |
|
from .preprocessing.scannet.meta_data.scannet200_constants import VALID_CLASS_IDS_200 |
|
|
|
|
|
@DATASETS.register_module()
class ArkitScenesDataset(Dataset):
    """Dataset over preprocessed ARKitScenes ``.pth`` sample files.

    Each sample file is loaded with ``torch.load`` and must provide the keys
    ``"coord"``, ``"color"`` and ``"normal"``. No labels are read from disk:
    ``segment`` is an all-zero placeholder with one entry per row of
    ``coord``.

    Args:
        split: Sub-directory name (``str``) or list of sub-directory names
            under ``data_root`` whose ``*.pth`` files make up the dataset.
            Any other type raises ``NotImplementedError``.
        data_root: Directory that contains the split sub-directories.
        transform: Transform config(s) handed to ``Compose``; applied to
            every sample.
        test_mode: When true, ``__getitem__`` returns test-time inputs (see
            ``prepare_test_data``) and ``loop`` is forced to 1.
        test_cfg: Only used when ``test_mode`` is true; must expose
            ``voxelize``, ``crop``, ``post_transform`` and ``aug_transform``.
        loop: How many times the file list is repeated; ``__len__`` returns
            ``len(data_list) * loop``.
    """

    def __init__(
        self,
        split="Training",
        data_root="data/ARKitScenesMesh",
        transform=None,
        test_mode=False,
        test_cfg=None,
        loop=1,
    ):
        super(ArkitScenesDataset, self).__init__()
        self.data_root = data_root
        self.split = split
        self.transform = Compose(transform)
        # Repeating the data list only makes sense for training epochs.
        self.loop = loop if not test_mode else 1
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None
        # Not read inside this class; presumably consumed by evaluation code
        # elsewhere — verify before removing.
        self.class2id = np.array(VALID_CLASS_IDS_200)

        if test_mode:
            self.test_voxelize = TRANSFORMS.build(self.test_cfg.voxelize)
            self.test_crop = TRANSFORMS.build(self.test_cfg.crop)
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform]

        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in {} set.".format(
                len(self.data_list), self.loop, split
            )
        )

    def get_data_list(self):
        """Collect the ``*.pth`` sample paths of every requested split.

        Paths are sorted within each split so that the sample order (and thus
        any index-based seeding or logging) does not depend on the arbitrary
        order in which the filesystem lists directory entries. Splits keep the
        order in which they were given.

        Returns:
            list[str]: Paths of all sample files.

        Raises:
            NotImplementedError: If ``self.split`` is neither ``str`` nor
                ``list``.
        """
        if isinstance(self.split, str):
            splits = [self.split]
        elif isinstance(self.split, list):
            splits = self.split
        else:
            raise NotImplementedError

        data_list = []
        for split in splits:
            data_list += sorted(glob.glob(os.path.join(self.data_root, split, "*.pth")))
        return data_list

    def get_data(self, idx):
        """Load sample ``idx`` (wrapped modulo the number of files).

        Returns:
            dict: ``coord``, ``normal`` and ``color`` as stored in the file,
            plus ``segment``, a float ``np.zeros`` placeholder of length
            ``coord.shape[0]``.
        """
        # NOTE(review): torch.load unpickles the file, so only use it on
        # trusted data. Newer PyTorch releases default to weights_only=True,
        # which may refuse non-tensor objects such as numpy arrays — confirm
        # against the installed PyTorch version.
        data = torch.load(self.data_list[idx % len(self.data_list)])
        coord = data["coord"]
        color = data["color"]
        normal = data["normal"]
        # No labels are loaded; every point gets the placeholder value 0.
        segment = np.zeros(coord.shape[0])
        data_dict = dict(coord=coord, normal=normal, color=color, segment=segment)
        return data_dict

    def get_data_name(self, idx):
        """Return the file name of sample ``idx`` up to its first ``"."``.

        Uses the same wrap-around indexing as ``get_data``. (The previous
        implementation read the never-assigned ``self.data_idx`` and always
        raised ``AttributeError``.)
        """
        data_path = self.data_list[idx % len(self.data_list)]
        return os.path.basename(data_path).split(".")[0]

    def prepare_train_data(self, idx):
        """Load sample ``idx`` and run it through ``self.transform``."""
        data_dict = self.get_data(idx)
        data_dict = self.transform(data_dict)
        return data_dict

    def prepare_test_data(self, idx):
        """Build the test-time inputs for sample ``idx``.

        Pipeline: ``transform`` -> each test-time augmentation (on its own
        deep copy) -> ``test_voxelize`` (iterable of parts) -> ``test_crop``
        (iterable of crops per part) -> ``post_transform`` on every crop.

        Returns:
            tuple: ``(input_dict_list, segment)`` where ``segment`` is the
            placeholder label array, popped before any transform is applied.
        """
        data_dict = self.get_data(idx)
        segment = data_dict.pop("segment")
        data_dict = self.transform(data_dict)

        # Each augmentation gets an independent copy of the transformed dict.
        data_dict_list = [aug(deepcopy(data_dict)) for aug in self.aug_transform]

        input_dict_list = []
        for data in data_dict_list:
            for data_part in self.test_voxelize(data):
                # Distinct name: the old code rebound the list being iterated.
                crop_list = self.test_crop(data_part)
                input_dict_list += crop_list

        input_dict_list = [self.post_transform(part) for part in input_dict_list]
        return input_dict_list, segment

    def __getitem__(self, idx):
        """Dispatch to the test or train pipeline depending on ``test_mode``."""
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)

    def __len__(self):
        """Number of files times ``loop`` (``loop`` is 1 in test mode)."""
        return len(self.data_list) * self.loop
|
|