from lib.kits.basic import * import webdataset as wds from .utils import * from .stream_pipelines import * # This line is to fix the problem of "OSError: image file is truncated" when loading images. from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True def load_tars_as_wds( cfg : DictConfig, urls : Union[str, List[str]], resampled : bool = False, epoch_size : int = None, cache_dir : str = None, train : bool = True, ): urls = expand_urls(urls) # to list of URL strings dataset : wds.WebDataset = wds.WebDataset( urls, nodesplitter = wds.split_by_node, shardshuffle = True, resampled = resampled, cache_dir = cache_dir, ) if train: dataset = dataset.shuffle(100) # A lot of processes to initialize the dataset. Check the pipeline generator function for more details. # The order of the pipeline is important, since some of the process are dependent on some previous ones. dataset = apply_corrupt_filter(dataset) dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png') dataset = apply_multi_ppl_splitter(dataset) dataset = apply_keys_adapter(dataset) #* This adapter is only in HSMR's design, not in the baseline. dataset = apply_bad_pgt_params_nan_suppressor(dataset) dataset = apply_bad_pgt_params_kp2d_err_suppressor(dataset, cfg.get('suppress_pgt_params_kp2d_err_thresh', 0.0)) dataset = apply_bad_pgt_params_pve_max_suppressor(dataset, cfg.get('suppress_pgt_params_pve_max_thresh', 0.0)) dataset = apply_bad_kp_suppressor(dataset, cfg.get('suppress_kp_conf_thresh', 0.0)) dataset = apply_bad_betas_suppressor(dataset, cfg.get('suppress_betas_thresh', 0.0)) # dataset = apply_bad_pose_suppressor(dataset, cfg.get('suppress_pose_thresh', 0.0)) # Not used in baseline, so not implemented. dataset = apply_params_synchronizer(dataset, cfg.get('poses_betas_simultaneous', False)) # dataset = apply_no_pose_filter(dataset, cfg.get('no_pose_filter', False)) # Not used in baseline, so not implemented. dataset = apply_insuff_kp_filter(dataset, cfg.get('filter_insufficient_kp_cnt', 4), cfg.get('suppress_insufficient_kp_thresh', 0.0)) dataset = apply_bbox_size_filter(dataset, cfg.get('filter_bbox_size_thresh', None)) dataset = apply_reproj_err_filter(dataset, cfg.get('filter_reproj_err_thresh', 0.0)) dataset = apply_invalid_betas_regularizer(dataset, cfg.get('regularize_invalid_betas', False)) # Final preprocess / format of the data. (Consider to extract the augmentation process.) dataset = apply_example_formatter(dataset, cfg) if epoch_size is not None: dataset = dataset.with_epoch(epoch_size) return dataset