Spaces:
Running
on
L4
Running
on
L4
File size: 4,344 Bytes
5ac1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
from lib.kits.basic import *
from lib.data.datasets.hsmr_v1.mocap_dataset import MoCapDataset
from lib.data.datasets.hsmr_v1.wds_loader import load_tars_as_wds
import webdataset as wds
class MixedWebDataset(wds.WebDataset):
    """A `wds.WebDataset` subclass that is constructed without any arguments.

    In this file it is instantiated empty and a `wds.RandomMix` stage is
    appended to it afterwards (see `DataModule._setup_train`).
    """

    def __init__(self) -> None:
        # NOTE: `super(wds.WebDataset, self)` starts the MRO lookup *after*
        # `wds.WebDataset`, so `WebDataset.__init__` itself is skipped and the
        # initializer of the next class in the MRO runs instead.
        # Presumably this avoids WebDataset's own constructor arguments -- confirm
        # against the installed webdataset version before changing it.
        super(wds.WebDataset, self).__init__()
class DataModule(pl.LightningDataModule):
    """Lightning data module that builds the train / mocap / eval datasets from a config.

    The config may contain `eval`, `train` and `mocap` entries. Each entry that
    is present enables the corresponding dataset when `setup()` is called for a
    matching stage.
    """

    def __init__(self, name, cfg):
        """Store `name` and split `cfg` into its per-purpose sub-configs.

        Args:
            name: Identifier kept on the instance as `self.name`.
            cfg: Mapping-like config object. NOTE: the `eval`, `train` and
                `mocap` entries are popped, i.e. removed from the object that
                was passed in; what remains is kept as `self.cfg`.
        """
        super().__init__()
        self.name = name
        self.cfg = cfg
        self.cfg_eval = self.cfg.pop('eval', None)
        self.cfg_train = self.cfg.pop('train', None)
        self.cfg_mocap = self.cfg.pop('mocap', None)

    def setup(self, stage=None):
        """Build the datasets needed for `stage`.

        Args:
            stage: `'test'` builds the evaluation datasets, `'fit'` builds the
                training and mocap datasets, `None` builds everything. The
                `'_debug_eval'`, `'_debug_train'` and `'_debug_mocap'` values
                build exactly one of them. A dataset is only built when its
                sub-config was provided.
        """
        if stage in ['test', None, '_debug_eval'] and self.cfg_eval is not None:
            self._setup_eval()

        if stage in ['fit', None, '_debug_train'] and self.cfg_train is not None:
            self._setup_train()

        if stage in ['fit', None, '_debug_mocap'] and self.cfg_mocap is not None:
            self._setup_mocap()

    def train_dataloader(self):
        """Return the training data loaders keyed by purpose.

        Returns:
            A dict with `'img_ds'` (loader over `self.train_dataset`) and, when
            a mocap config was provided, `'mocap_ds'` (loader over
            `self.mocap_dataset`). The loader keyword arguments come from the
            `dataloader` entry of the respective sub-config.
        """
        loaders = {
            'img_ds': torch.utils.data.DataLoader(
                dataset=self.train_dataset,
                **self.cfg_train.dataloader,
            ),
        }

        if self.cfg_mocap is not None:
            loaders['mocap_ds'] = torch.utils.data.DataLoader(
                dataset=self.mocap_dataset,
                **self.cfg_mocap.dataloader,
            )

        return loaders

    # ========== Internal Modules to Setup Datasets ==========

    def _setup_train(self):
        """Load every configured training source and mix them by normalized weight.

        Sets `self.train_datasets` (the individual datasets) and
        `self.train_dataset` (the weighted mix).

        Raises:
            ValueError: If the configured weights do not sum to a positive
                value (this includes an empty dataset list), because such
                weights cannot be normalized into mixing probabilities.
        """
        ld_cfg = self.cfg_train.cfg  # cfg for initializing wds loading pipeline

        datasets, weights = [], []
        for ds_cfg in self.cfg_train.datasets:
            datasets.append(
                load_tars_as_wds(
                    ld_cfg,
                    ds_cfg.item.urls,
                    ds_cfg.item.epoch_size,
                )
            )
            weights.append(ds_cfg.weight)

        # Normalize the weights; fail loudly instead of propagating NaN/inf.
        weights = to_numpy(weights)
        total = weights.sum()
        if not total > 0:
            raise ValueError(
                f"Training dataset weights must sum to a positive value, got {total}."
            )
        weights = weights / total

        # Mix the datasets.
        self.train_datasets = datasets
        self.train_dataset = MixedWebDataset()
        self.train_dataset.append(wds.RandomMix(datasets, weights))
        # Fixed nominal epoch length and shuffle buffer size (hard-coded here,
        # independent of the per-dataset `epoch_size` used above).
        self.train_dataset = self.train_dataset.with_epoch(50_000).shuffle(1000, initial=1000)

    def _setup_mocap(self):
        """Create `self.mocap_dataset` from the `cfg` entry of the mocap config."""
        self.mocap_dataset = MoCapDataset(**self.cfg_mocap.cfg)

    def _setup_eval(self, selected_ds_names:Optional[List[str]]=None):
        """Create the evaluation datasets and store them in `self.eval_datasets`.

        Args:
            selected_ds_names: When given, only datasets whose configured name
                is in this list are built; `None` builds all of them.
        """
        # Imported lazily so the module is only required when evaluation is set up.
        from lib.data.datasets.skel_hmr2_fashion.image_dataset import ImageDataset

        # Hard-coded settings handed to `ImageDataset` as its `cfg`; only `augm`
        # is taken from the runtime config.
        hack_cfg = {
            'IMAGE_SIZE'               : 256,
            'IMAGE_MEAN'               : [0.485, 0.456, 0.406],
            'IMAGE_STD'                : [0.229, 0.224, 0.225],
            'BBOX_SHAPE'               : [192, 256],
            'augm'                     : self.cfg.image_augmentation,
            'SUPPRESS_KP_CONF_THRESH'  : 0.3,
            'FILTER_NUM_KP'            : 4,
            'FILTER_NUM_KP_THRESH'     : 0.0,
            'FILTER_REPROJ_THRESH'     : 31000,
            'SUPPRESS_BETAS_THRESH'    : 3.0,
            'SUPPRESS_BAD_POSES'       : False,
            'POSES_BETAS_SIMULTANEOUS' : True,
            'FILTER_NO_POSES'          : False,
            'BETAS_REG'                : True,
        }

        self.eval_datasets = {}
        for dataset_cfg in self.cfg_eval.datasets:
            if selected_ds_names is not None and dataset_cfg.name not in selected_ds_names:
                continue

            dataset = ImageDataset(
                cfg          = hack_cfg,
                dataset_file = dataset_cfg.item.dataset_file,
                img_dir      = dataset_cfg.item.img_root,
                train        = False,
            )
            # Attach the configured keypoint list to the dataset instance.
            dataset._kp_list_ = dataset_cfg.item.kp_list
            self.eval_datasets[dataset_cfg.name] = dataset
|