File size: 7,681 Bytes
abd2a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import collections
import collections.abc
import copy
import numbers
import os.path as osp
import re
import warnings

import torch.utils.data

from .makedirs import makedirs
def to_list(x):
    r"""Wraps :obj:`x` in a list unless it already is a non-string iterable.

    Strings are iterable but are treated as a single item, so ``'a.pt'``
    becomes ``['a.pt']`` rather than ``['a', '.', 'p', 't']``.

    Args:
        x: A single item or an iterable of items.

    Returns:
        :obj:`x` itself if it is a non-string iterable, otherwise ``[x]``.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC has lived
    # in ``collections.abc`` since Python 3.3.
    if not isinstance(x, collections.abc.Iterable) or isinstance(x, str):
        x = [x]
    return x
def files_exist(files):
    r"""Returns :obj:`True` if every path in :obj:`files` exists on disk.

    An empty iterable yields :obj:`True`. The generator form stops at the
    first missing path instead of checking every file up front.

    Args:
        files (iterable of str): Paths to check.
    """
    return all(osp.exists(f) for f in files)
def __repr__(obj):
    r"""Returns a repr of :obj:`obj` with instance-specific detail removed.

    Within a ``<...>`` segment, everything from the first whitespace after
    the ``<`` up to the closing ``>`` is dropped, e.g.
    ``<function f at 0x7f>`` becomes ``<function>``. This keeps the string
    stable across runs so it can be compared with a previously saved one.
    :obj:`None` maps to the literal ``'None'``.
    """
    if obj is not None:
        text = obj.__repr__()
        return re.sub(r'(<.*?)\s.*(>)', r'\1\2', text)
    return 'None'
class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating datasets with pre-processing.

    Based on Dataset from pytorch-geometric: see `here
    <https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html>`__
    for the accompanying tutorial.

    Subclasses implement :meth:`len` and :meth:`get`. If the concrete class
    itself (not a parent class) defines :meth:`download` or :meth:`process`,
    :meth:`__init__` runs it, unless the files named by
    :attr:`raw_file_names` / :attr:`processed_file_names` already exist.

    Args:
        root (string, optional): Root directory where the dataset should be
            saved. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every
            integer-indexed access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version, meant to be applied before saving to disk.
            This base class does not call it; it only records its repr next
            to the processed files. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. As with :obj:`pre_transform`, only its repr is
            recorded here. (default: :obj:`None`)
    """

    # Tensor dtypes accepted as lists of positions. ``torch.uint8`` is
    # deliberately absent: like ``torch.bool`` it is treated as a mask.
    _INTEGER_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)

    @property
    def raw_file_names(self):
        r"""The name of the files to find in the :obj:`self.raw_dir` folder in
        order to skip the download."""
        raise NotImplementedError

    @property
    def processed_file_names(self):
        r"""The name of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def len(self):
        r"""Returns the number of examples in the full (unindexed) dataset."""
        raise NotImplementedError

    def get(self, idx):
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(self, root=None, transform=None, pre_transform=None,
                 pre_filter=None):
        super().__init__()
        if isinstance(root, str):
            root = osp.expanduser(osp.normpath(root))
        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        # ``None`` means "all examples"; views created by ``index_select``
        # hold an explicit range/list of positions into ``get``.
        self.__indices__ = None

        # Only hooks defined directly on the concrete class trigger the
        # pipeline; the stubs on this base class just raise.
        if 'download' in self.__class__.__dict__:
            self._download()
        if 'process' in self.__class__.__dict__:
            self._process()

    def indices(self):
        r"""Returns the positions (passed to :meth:`get`) this view covers."""
        if self.__indices__ is not None:
            return self.__indices__
        return range(len(self))

    @property
    def raw_dir(self):
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, 'processed')

    @property
    def raw_paths(self):
        r"""The filepaths to find in order to skip the download."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self):
        r"""The filepaths to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        r"""Calls :meth:`download` unless every raw file is already present."""
        if files_exist(self.raw_paths):  # pragma: no cover
            return
        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        r"""Calls :meth:`process` unless every processed file already exists.

        Before that, it warns if the :obj:`pre_transform` / :obj:`pre_filter`
        repr saved by an earlier run differs from the current one. After
        processing, it saves the current reprs to
        ``processed_dir/pre_transform.pt`` and ``processed_dir/pre_filter.pt``.
        """
        f = osp.join(self.processed_dir, 'pre_transform.pt')
        if osp.exists(f) and torch.load(f) != __repr__(self.pre_transform):
            warnings.warn(
                'The `pre_transform` argument differs from the one used in '
                'the pre-processed version of this dataset. If you really '
                'want to make use of another pre-processing technique, make '
                'sure to delete `{}` first.'.format(self.processed_dir))
        f = osp.join(self.processed_dir, 'pre_filter.pt')
        if osp.exists(f) and torch.load(f) != __repr__(self.pre_filter):
            warnings.warn(
                'The `pre_filter` argument differs from the one used in the '
                'pre-processed version of this dataset. If you really want to '
                'make use of another pre-filtering technique, make sure to '
                'delete `{}` first.'.format(self.processed_dir))

        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print('Processing...')
        makedirs(self.processed_dir)
        self.process()

        path = osp.join(self.processed_dir, 'pre_transform.pt')
        torch.save(__repr__(self.pre_transform), path)
        path = osp.join(self.processed_dir, 'pre_filter.pt')
        torch.save(__repr__(self.pre_filter), path)
        print('Done!')

    def __len__(self):
        r"""The number of examples in the dataset (or in this view)."""
        if self.__indices__ is not None:
            return len(self.__indices__)
        return self.len()

    def __getitem__(self, idx):
        r"""Gets the data object at index :obj:`idx` and transforms it (in case
        a :obj:`self.transform` is given).

        A scalar :obj:`idx` may be a Python integer, any
        :class:`numbers.Integral` (e.g. a NumPy integer scalar), or a
        0-dimensional integer tensor. In case :obj:`idx` is a slicing object,
        *e.g.*, :obj:`[2:5]`, a list, a tuple, an integer tensor or a
        bool/uint8 mask tensor, returns a subset of the dataset at the
        specified indices (see :meth:`index_select`). No transform is applied
        in the subset case.
        """
        is_scalar_tensor = (torch.is_tensor(idx) and idx.dim() == 0
                            and idx.dtype in self._INTEGER_DTYPES)
        if isinstance(idx, numbers.Integral) or is_scalar_tensor:
            data = self.get(self.indices()[int(idx)])
            return data if self.transform is None else self.transform(data)
        return self.index_select(idx)

    def index_select(self, idx):
        r"""Returns a shallow copy of the dataset restricted to :obj:`idx`.

        Args:
            idx (slice, list, tuple or Tensor): Positions relative to the
                current view. Tensors must be of an integer dtype (positions)
                or :obj:`torch.bool` / :obj:`torch.uint8` (mask).

        Raises:
            IndexError: If :obj:`idx` is of any other type or tensor dtype.
        """
        indices = self.indices()
        if isinstance(idx, slice):
            indices = indices[idx]
        elif torch.is_tensor(idx) and idx.dtype in self._INTEGER_DTYPES:
            return self.index_select(idx.tolist())
        elif torch.is_tensor(idx) and idx.dtype in (torch.bool, torch.uint8):
            return self.index_select(idx.nonzero().flatten().tolist())
        elif isinstance(idx, (list, tuple)):
            indices = [indices[i] for i in idx]
        else:
            # Previously, tensors of other dtypes (e.g. float) fell through
            # here without modifying ``indices`` and silently returned the
            # whole dataset.
            raise IndexError(
                'Only integers, slices (`:`), list, tuples, and long or bool '
                'tensors are valid indices (got {}).'.format(
                    type(idx).__name__))

        dataset = copy.copy(self)
        dataset.__indices__ = indices
        return dataset

    def shuffle(self, return_perm=False):
        r"""Randomly shuffles the examples in the dataset.

        Args:
            return_perm (bool, optional): If set to :obj:`True`, will
                additionally return the random permutation used to shuffle the
                dataset. (default: :obj:`False`)
        """
        perm = torch.randperm(len(self))
        dataset = self.index_select(perm)
        return (dataset, perm) if return_perm is True else dataset

    def __repr__(self):  # pragma: no cover
        return f'{self.__class__.__name__}({len(self)})'
|