# kairunwen — commit 57746f1 ("Update Code")
"""
ScanNet20 / ScanNet200 / ScanNet Data Efficient Dataset
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import os
import glob
import numpy as np
import torch
from copy import deepcopy
from torch.utils.data import Dataset
from collections.abc import Sequence
from pointcept.utils.logger import get_root_logger
from pointcept.utils.cache import shared_dict
from .builder import DATASETS
from .defaults import DefaultDataset
from .transform import Compose, TRANSFORMS
from .preprocessing.scannet.meta_data.scannet200_constants import (
VALID_CLASS_IDS_20,
VALID_CLASS_IDS_200,
)
@DATASETS.register_module()
class ScanNetDataset(DefaultDataset):
    """ScanNet20 dataset with optional data-efficient label settings.

    Each scene is a directory of ``<asset>.npy`` files. Only files whose stem
    is listed in ``VALID_ASSETS`` are loaded.

    Args:
        lr_file: Optional text file of scene names (whitespace separated, one
            per line). When given, the data list becomes
            ``<data_root>/train/<name>`` for each name instead of the parent
            class's listing (presumably the "limited reconstruction" setting).
        la_file: Optional file read with ``torch.load`` that maps a scene name
            to the indices of points whose labels are kept. All other points
            get ``ignore_index`` (presumably the "limited annotation"
            setting).
        **kwargs: Forwarded to ``DefaultDataset``.
    """

    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "segment20",
        "instance",
    ]
    class2id = np.array(VALID_CLASS_IDS_20)

    def __init__(
        self,
        lr_file=None,
        la_file=None,
        **kwargs,
    ):
        # Set before super().__init__: get_data_list() reads self.lr and is
        # presumably invoked from DefaultDataset.__init__.
        # ndmin=1 keeps a single-line file as a 1-element array; otherwise
        # np.loadtxt returns a 0-d array, which cannot be iterated.
        self.lr = (
            np.loadtxt(lr_file, dtype=str, ndmin=1) if lr_file is not None else None
        )
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # files. Recent torch releases default to weights_only=True, which may
        # reject non-tensor payloads — confirm against the la_file format.
        self.la = torch.load(la_file) if la_file is not None else None
        super().__init__(**kwargs)

    def get_data_list(self):
        """Return the scene directories to iterate over.

        Without ``lr_file`` the parent's listing is used. Otherwise each name
        read from ``lr_file`` is resolved under ``<data_root>/train``.
        """
        if self.lr is None:
            return super().get_data_list()
        return [os.path.join(self.data_root, "train", name) for name in self.lr]

    @staticmethod
    def _flat_int32_or_fill(array, num_points):
        """Flatten ``array`` to int32, or return ``num_points`` values of -1 if None."""
        if array is None:
            return np.full(num_points, -1, dtype=np.int32)
        # astype copies, so later in-place edits never touch the loaded array.
        return array.reshape([-1]).astype(np.int32)

    def get_data(self, idx):
        """Load one scene as a dict of numpy arrays.

        The returned dict contains:
            name: the scene name from ``get_data_name``.
            coord: float32, required.
            color, normal: float32, only when the asset exists.
            segment, instance: flat int32 arrays, filled with -1 when the
                corresponding asset is missing.
            sampled_index: added only when ``la_file`` was given. In that case
                the labels of every point not in ``sampled_index`` are set to
                ``ignore_index``.

        When caching is enabled, the shared cached dict is returned as-is.
        """
        data_path = self.data_list[idx % len(self.data_list)]
        name = self.get_data_name(idx)
        if self.cache:
            # NOTE(review): cached entries bypass the la_file masking below.
            return shared_dict(f"pointcept-{name}")

        data_dict = {}
        for asset in os.listdir(data_path):
            if not asset.endswith(".npy"):
                continue
            key = asset[:-4]
            if key in self.VALID_ASSETS:
                data_dict[key] = np.load(os.path.join(data_path, asset))
        data_dict["name"] = name

        data_dict["coord"] = data_dict["coord"].astype(np.float32)
        # color/normal are optional: cast only what was actually loaded.
        for key in ("color", "normal"):
            if key in data_dict:
                data_dict[key] = data_dict[key].astype(np.float32)
        num_points = data_dict["coord"].shape[0]

        # segment20 takes precedence; only the chosen label key is popped.
        segment = None
        for key in ("segment20", "segment200"):
            if key in data_dict:
                segment = data_dict.pop(key)
                break
        data_dict["segment"] = self._flat_int32_or_fill(segment, num_points)
        data_dict["instance"] = self._flat_int32_or_fill(
            data_dict.pop("instance", None), num_points
        )

        if self.la:
            sampled_index = self.la[name]
            # Keep labels only at the sampled points; ignore everything else.
            mask = np.ones_like(data_dict["segment"], dtype=bool)
            mask[sampled_index] = False
            data_dict["segment"][mask] = self.ignore_index
            data_dict["sampled_index"] = sampled_index
        return data_dict
@DATASETS.register_module()
class ScanNet200Dataset(ScanNetDataset):
    """ScanNet200 variant that reads ``segment200`` labels instead of ``segment20``.

    Only the asset list and the class-id table are overridden. Loading,
    caching and the ``lr_file``/``la_file`` options come from ScanNetDataset.
    """

    # Asset file stems loaded from each scene directory.
    VALID_ASSETS = ["coord", "color", "normal", "segment200", "instance"]
    # Class-id table built from VALID_CLASS_IDS_200.
    class2id = np.array(VALID_CLASS_IDS_200)