from lib.kits.basic import *  # expected to provide `pl`, `torch`, `DictConfig`, `to_numpy`, etc.

import webdataset as wds

from lib.data.datasets.skel_hmr2_fashion.image_dataset import ImageDataset


class MixedWebDataset(wds.WebDataset):
    def __init__(self) -> None:
        # Deliberately skip wds.WebDataset.__init__ and initialize the underlying
        # pipeline directly; the actual streams are attached later via `append()`.
        super(wds.WebDataset, self).__init__()


class DataModule(pl.LightningDataModule):
    def __init__(self, name: str, cfg: DictConfig):
        super().__init__()
        self.name = name
        self.cfg = cfg
        self.cfg_eval = self.cfg.pop('eval', None)
        self.cfg_train = self.cfg.pop('train', None)

    def setup(self, stage=None):
        if stage in ['test', None, '_debug_eval']:
            self._setup_eval()
        if stage in ['fit', None, '_debug_train']:
            self._setup_train()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset = self.train_dataset,
            **self.cfg_train.dataloader,
        )

    def val_dataloader(self):
        # No dedicated validation set here, so reuse the test dataloader(s).
        return self.test_dataloader()

    def test_dataloader(self):
        # One dataloader per evaluation dataset; Lightning accepts a list of
        # dataloaders from `test_dataloader`.
        # TODO: Support mixing multiple datasets through ConcatDataset
        # (still need to figure out how to mix them with weights).
        return [
            torch.utils.data.DataLoader(
                dataset = dataset,
                **self.cfg_eval.dataloader,
            )
            for dataset in self.eval_datasets.values()
        ]

    # ========== Internal Modules to Setup Datasets ==========

    def _setup_train(self):
        hack_cfg = {
            'IMAGE_SIZE': self.cfg.policy.img_patch_size,
            'IMAGE_MEAN': self.cfg.policy.img_mean,
            'IMAGE_STD' : self.cfg.policy.img_std,
            'BBOX_SHAPE': None,
            'augm'      : self.cfg.augm,
        }
        datasets, weights = [], []
        opt = self.cfg_train.get('shared_ds_opt', {})
        for dataset_cfg in self.cfg_train.datasets:
            cur_cfg = {**hack_cfg, **opt}
            dataset = ImageDataset.load_tars_as_webdataset(
                cfg        = cur_cfg,
                urls       = dataset_cfg.item.urls,
                train      = True,
                epoch_size = dataset_cfg.item.epoch_size,
            )
            weights.append(dataset_cfg.weight)
            datasets.append(dataset)
        # Normalize the mixing weights into sampling probabilities.
        weights = to_numpy(weights)
        weights = weights / weights.sum()
        # Mix the per-dataset web datasets according to the normalized weights.
        self.train_dataset = MixedWebDataset()
        self.train_dataset.append(wds.RandomMix(datasets, weights, longest=False))
        self.train_dataset = self.train_dataset.with_epoch(100_000).shuffle(4000)

    def _setup_eval(self):
        hack_cfg = {
            'IMAGE_SIZE': self.cfg.policy.img_patch_size,
            'IMAGE_MEAN': self.cfg.policy.img_mean,
            'IMAGE_STD' : self.cfg.policy.img_std,
            'BBOX_SHAPE': [192, 256],
            'augm'      : self.cfg.augm,
        }
        self.eval_datasets = {}
        opt = self.cfg_eval.get('shared_ds_opt', {})  # read shared options from the eval config, not the train one
        for dataset_cfg in self.cfg_eval.datasets:
            cur_cfg = {**hack_cfg, **opt}
            dataset = ImageDataset(
                cfg          = cur_cfg,  # use the merged config rather than the bare hack_cfg
                dataset_file = dataset_cfg.item.dataset_file,
                img_dir      = dataset_cfg.item.img_root,
                train        = False,
            )
            dataset._kp_list_ = dataset_cfg.item.kp_list
            self.eval_datasets[dataset_cfg.name] = dataset
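

# ---------------------------------------------------------------------------
# Illustrative usage sketch (a minimal example, not part of the module itself):
# shows how this DataModule might be driven from an OmegaConf config. All of
# the concrete values below (patch size, tar url pattern, dataset files,
# dataloader options) are placeholder assumptions, not the project's real
# configuration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from omegaconf import OmegaConf

    _demo_cfg = OmegaConf.create({
        'policy': {
            'img_patch_size': 256,
            'img_mean'      : [0.485, 0.456, 0.406],
            'img_std'       : [0.229, 0.224, 0.225],
        },
        'augm': {},  # augmentation options consumed by ImageDataset (placeholder)
        'train': {
            'dataloader': {'batch_size': 32, 'num_workers': 4},
            'datasets': [{
                'name'  : 'demo-train',
                'weight': 1.0,
                'item'  : {'urls': 'data/demo-{000000..000009}.tar', 'epoch_size': 10_000},
            }],
        },
        'eval': {
            'dataloader': {'batch_size': 32, 'num_workers': 4},
            'datasets': [{
                'name': 'demo-eval',
                'item': {'dataset_file': 'data/demo_eval.npz', 'img_root': 'data/images', 'kp_list': None},
            }],
        },
    })

    dm = DataModule(name='demo', cfg=_demo_cfg)
    # With real shards / annotation files in place, the usual Lightning flow applies:
    #   dm.setup('fit');  train_loader = dm.train_dataloader()
    #   dm.setup('test'); test_loaders = dm.test_dataloader()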