IsshikiHugh's picture
feat: CPU demo
5ac1897
from lib.kits.basic import *
import webdataset as wds
from lib.data.datasets.skel_hmr2_fashion.image_dataset import ImageDataset
class MixedWebDataset(wds.WebDataset):
def __init__(self) -> None:
super(wds.WebDataset, self).__init__()
class DataModule(pl.LightningDataModule):
def __init__(self, name:str, cfg:DictConfig):
super().__init__()
self.name = name
self.cfg = cfg
self.cfg_eval = self.cfg.pop('eval', None)
self.cfg_train = self.cfg.pop('train', None)
def setup(self, stage=None):
if stage in ['test', None, '_debug_eval']:
self._setup_eval()
if stage in ['fit', None, '_debug_train']:
self._setup_train()
def train_dataloader(self):
return torch.utils.data.DataLoader(
dataset = self.train_dataset,
**self.cfg_train.dataloader,
)
def val_dataloader(self):
# Since we don't need validation here.
return self.test_dataloader()
def test_dataloader(self):
# return torch.utils.data.DataLoader(
# dataset = self.eval_datasets['LSP-EXTENDED'], # TODO: Support multiple datasets through ConcatDataset (but to figure out how to mix with weights)
# **self.cfg_eval.dataloader,
# )
return torch.utils.data.DataLoader(
dataset = self.eval_datasets, # TODO: Support multiple datasets through ConcatDataset (but to figure out how to mix with weights)
**self.cfg_eval.dataloader,
)
# ========== Internal Modules to Setup Datasets ==========
def _setup_train(self):
hack_cfg = {
'IMAGE_SIZE': self.cfg.policy.img_patch_size,
'IMAGE_MEAN': self.cfg.policy.img_mean,
'IMAGE_STD' : self.cfg.policy.img_std,
'BBOX_SHAPE': None,
'augm': self.cfg.augm,
}
self.train_datasets = [] # [(dataset:Dataset, weight:float), ...]
datasets, weights = [], []
opt = self.cfg_train.get('shared_ds_opt', {})
for dataset_cfg in self.cfg_train.datasets:
cur_cfg = {**hack_cfg, **opt}
dataset = ImageDataset.load_tars_as_webdataset(
cfg = cur_cfg,
urls = dataset_cfg.item.urls,
train = True,
epoch_size = dataset_cfg.item.epoch_size,
)
weights.append(dataset_cfg.weight)
datasets.append(dataset)
weights = to_numpy(weights)
weights = weights / weights.sum()
self.train_dataset = MixedWebDataset()
self.train_dataset.append(wds.RandomMix(datasets, weights, longest=False))
self.train_dataset = self.train_dataset.with_epoch(100_000).shuffle(4000)
def _setup_eval(self):
hack_cfg = {
'IMAGE_SIZE' : self.cfg.policy.img_patch_size,
'IMAGE_MEAN' : self.cfg.policy.img_mean,
'IMAGE_STD' : self.cfg.policy.img_std,
'BBOX_SHAPE' : [192, 256],
'augm' : self.cfg.augm,
}
self.eval_datasets = {}
opt = self.cfg_train.get('shared_ds_opt', {})
for dataset_cfg in self.cfg_eval.datasets:
cur_cfg = {**hack_cfg, **opt}
dataset = ImageDataset(
cfg = hack_cfg,
dataset_file = dataset_cfg.item.dataset_file,
img_dir = dataset_cfg.item.img_root,
train = False,
)
dataset._kp_list_ = dataset_cfg.item.kp_list
self.eval_datasets[dataset_cfg.name] = dataset