|
""" |
|
Semantic KITTI dataset |
|
|
|
Author: Xiaoyang Wu ([email protected]) |
|
Please cite our work if the code is helpful to you. |
|
""" |
|
|
|
import os |
|
import numpy as np |
|
|
|
from .builder import DATASETS |
|
from .defaults import DefaultDataset |
|
|
|
|
|
@DATASETS.register_module()
class SemanticKITTIDataset(DefaultDataset):
    """SemanticKITTI LiDAR semantic-segmentation dataset.

    Scans are read from ``<data_root>/dataset/sequences/<seq>/velodyne/*.bin``.
    When it exists, per-point labels are read from the frame's sibling
    ``labels/<frame>.label`` file. Raw SemanticKITTI label ids are remapped to
    the contiguous training ids 0-18 via :meth:`get_learning_map`. Excluded or
    unlabeled classes map to ``ignore_index``.
    """

    def __init__(self, ignore_index=-1, **kwargs):
        """Build the label maps, then defer to ``DefaultDataset``.

        Args:
            ignore_index: Training id given to unlabeled / excluded classes.
                It is also forwarded to ``DefaultDataset``.
            **kwargs: Forwarded unchanged to ``DefaultDataset``.
        """
        # The maps are set before super().__init__(), presumably so they
        # already exist if the parent constructor touches samples.
        # NOTE(review): confirm against DefaultDataset.
        self.ignore_index = ignore_index
        self.learning_map = self.get_learning_map(ignore_index)
        self.learning_map_inv = self.get_learning_map_inv(ignore_index)
        super().__init__(ignore_index=ignore_index, **kwargs)

    def get_data_list(self):
        """Return the sorted velodyne ``.bin`` paths of the sequences in ``self.split``.

        ``self.split`` is either one split name (``"train"``, ``"val"``,
        ``"test"``) or a list/tuple of names. The sequences of multiple
        splits are concatenated in the given order.

        Raises:
            KeyError: If a split name is unknown.
            NotImplementedError: If ``self.split`` is neither a string nor a
                list/tuple.
        """
        split2seq = dict(
            train=[0, 1, 2, 3, 4, 5, 6, 7, 9, 10],
            val=[8],
            test=[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
        )
        if isinstance(self.split, str):
            seq_list = split2seq[self.split]
        elif isinstance(self.split, (list, tuple)):
            seq_list = [seq for split in self.split for seq in split2seq[split]]
        else:
            raise NotImplementedError(
                f"Unsupported split type: {type(self.split).__name__}"
            )

        data_list = []
        for seq in seq_list:
            velodyne_dir = os.path.join(
                self.data_root, "dataset", "sequences", str(seq).zfill(2), "velodyne"
            )
            # Only .bin point clouds are samples. Any other file in the folder
            # would fail the reshape(-1, 4) in get_data().
            data_list += [
                os.path.join(velodyne_dir, file_name)
                for file_name in sorted(os.listdir(velodyne_dir))
                if file_name.endswith(".bin")
            ]
        return data_list

    def get_data(self, idx):
        """Load one scan and its remapped labels.

        Args:
            idx: Sample index. It is wrapped modulo ``len(self.data_list)``.

        Returns:
            dict with these keys:
                - ``coord``: (N, 3) float32 array.
                - ``strength``: (N, 1) float32 array (last scan column).
                - ``segment``: (N,) int32 training ids.
                - ``name``: see :meth:`get_data_name`.
        """
        data_path = self.data_list[idx % len(self.data_list)]
        with open(data_path, "rb") as f:
            # Each point is stored as 4 float32 values: x, y, z, strength.
            scan = np.fromfile(f, dtype=np.float32).reshape(-1, 4)
        coord = scan[:, :3]
        strength = scan[:, -1].reshape([-1, 1])

        label_file = self._get_label_path(data_path)
        if os.path.exists(label_file):
            with open(label_file, "rb") as f:
                raw_label = np.fromfile(f, dtype=np.int32).reshape(-1)
            segment = self._map_labels(raw_label)
        else:
            # No label file (e.g. test sequences): every point gets id 0.
            # NOTE(review): 0 is a real training id (raw label 10), not
            # ignore_index. It is kept for backward compatibility.
            segment = np.zeros(scan.shape[0], dtype=np.int32)

        return dict(
            coord=coord,
            strength=strength,
            segment=segment,
            name=self.get_data_name(idx),
        )

    @staticmethod
    def _get_label_path(data_path):
        """Return the ``<seq>/labels/<frame>.label`` path for a velodyne scan path.

        The path is built from its components rather than by string
        replacement. A ``data_root`` that contains "velodyne" or ".bin" is
        therefore not rewritten.
        """
        velodyne_dir, file_name = os.path.split(data_path)
        seq_folder = os.path.dirname(velodyne_dir)
        frame_name = os.path.splitext(file_name)[0]
        return os.path.join(seq_folder, "labels", frame_name + ".label")

    def _map_labels(self, raw_label):
        """Map raw SemanticKITTI labels to int32 training ids.

        Only the lower 16 bits (the semantic id) are used. In the
        SemanticKITTI label format the upper 16 bits hold the instance id.
        Semantic ids missing from ``self.learning_map`` map to
        ``self.ignore_index``.
        """
        # A dense table over every 16-bit semantic id gives one vectorized
        # gather instead of a Python dict lookup per point. Unlike
        # np.vectorize, it also handles empty label files.
        lut = np.full(0x10000, self.ignore_index, dtype=np.int32)
        for raw_id, train_id in self.learning_map.items():
            lut[raw_id] = train_id
        return lut[np.asarray(raw_label) & 0xFFFF]

    def get_data_name(self, idx):
        """Return ``"<sequence>_<frame>"`` for sample ``idx``, e.g. ``"08_000123"``."""
        file_path = self.data_list[idx % len(self.data_list)]
        dir_path, file_name = os.path.split(file_path)
        # dir_path is .../sequences/<seq>/velodyne, so its parent folder is the
        # sequence id.
        sequence_name = os.path.basename(os.path.dirname(dir_path))
        frame_name = os.path.splitext(file_name)[0]
        return f"{sequence_name}_{frame_name}"

    @staticmethod
    def get_learning_map(ignore_index):
        """Return the dict that maps raw SemanticKITTI semantic ids to training ids 0-18.

        Ids 0, 1, 52 and 99 map to ``ignore_index``. The 25x ids are folded
        onto the same training ids as their counterparts (e.g. 252 -> 0,
        like 10).
        """
        learning_map = {
            0: ignore_index,
            1: ignore_index,
            10: 0,
            11: 1,
            13: 4,
            15: 2,
            16: 4,
            18: 3,
            20: 4,
            30: 5,
            31: 6,
            32: 7,
            40: 8,
            44: 9,
            48: 10,
            49: 11,
            50: 12,
            51: 13,
            52: ignore_index,
            60: 8,
            70: 14,
            71: 15,
            72: 16,
            80: 17,
            81: 18,
            99: ignore_index,
            252: 0,
            253: 6,
            254: 5,
            255: 7,
            256: 4,
            257: 4,
            258: 3,
            259: 4,
        }
        return learning_map

    @staticmethod
    def get_learning_map_inv(ignore_index):
        """Return the dict that maps training ids 0-18 back to one raw SemanticKITTI id each.

        ``ignore_index`` maps to itself.
        """
        learning_map_inv = {
            ignore_index: ignore_index,
            0: 10,
            1: 11,
            2: 15,
            3: 18,
            4: 20,
            5: 30,
            6: 31,
            7: 32,
            8: 40,
            9: 44,
            10: 48,
            11: 49,
            12: 50,
            13: 51,
            14: 70,
            15: 71,
            16: 72,
            17: 80,
            18: 81,
        }
        return learning_map_inv
|
|