File size: 2,237 Bytes
57746f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
"""
ScanNet++ dataset
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import os
import numpy as np
import glob
from pointcept.utils.cache import shared_dict
from .builder import DATASETS
from .defaults import DefaultDataset
@DATASETS.register_module()
class ScanNetPPDataset(DefaultDataset):
    """ScanNet++ dataset that assembles each sample from per-scene ``.npy`` files.

    Every entry of ``self.data_list`` is treated as a directory. Files in it
    named ``<asset>.npy`` with ``<asset>`` listed in ``VALID_ASSETS`` are
    loaded into the sample dictionary under the key ``<asset>``.
    """

    # File stems that ``get_data`` is allowed to load from a scene directory.
    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "segment",
        "instance",
    ]

    def __init__(
        self,
        multilabel=False,
        **kwargs,
    ):
        """Forward ``kwargs`` to ``DefaultDataset`` and store ``multilabel``.

        Args:
            multilabel: Stored on the instance. Only a falsy value is
                supported by ``get_data``; a truthy value makes it raise
                ``NotImplementedError`` on the non-cached path.
            **kwargs: Passed unchanged to ``DefaultDataset.__init__``.
        """
        super().__init__(**kwargs)
        self.multilabel = multilabel

    def get_data(self, idx):
        """Return the sample dictionary for index ``idx``.

        ``idx`` is wrapped modulo ``len(self.data_list)`` to pick the scene
        directory. When ``self.cache`` is truthy the result of
        ``shared_dict("pointcept-<name>")`` is returned directly and nothing
        is read from disk.

        Otherwise the returned dict holds the loaded assets plus ``"name"``.
        ``coord``/``color``/``normal`` are cast to float32 when present.
        ``segment`` and ``instance`` are reduced to column 0 and cast to
        int32 when present; when absent they are filled with ``-1``, one
        entry per row of ``coord``.

        Raises:
            NotImplementedError: If ``self.multilabel`` is truthy (raised
                after the assets have been loaded).
        """
        scene_dir = self.data_list[idx % len(self.data_list)]
        scene_name = self.get_data_name(idx)
        if self.cache:
            return shared_dict(f"pointcept-{scene_name}")

        sample = {}
        for filename in os.listdir(scene_dir):
            stem, ext = os.path.splitext(filename)
            if ext == ".npy" and stem in self.VALID_ASSETS:
                sample[stem] = np.load(os.path.join(scene_dir, filename))
        sample["name"] = scene_name

        # Point attributes are stored as float32 whenever they were loaded.
        for key in ("coord", "color", "normal"):
            if key in sample:
                sample[key] = sample[key].astype(np.float32)

        if self.multilabel:
            raise NotImplementedError

        for key in ("segment", "instance"):
            if key in sample:
                # Only column 0 is kept; presumably the remaining columns are
                # alternative labels -- confirm against the preprocessing.
                sample[key] = sample[key][:, 0].astype(np.int32)
            else:
                # No annotation on disk: one -1 per point, sized from coord.
                sample[key] = np.full(
                    sample["coord"].shape[0], -1, dtype=np.int32
                )
        return sample
|