File size: 2,781 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from lib.kits.basic import *

import webdataset as wds

from .utils import *
from .stream_pipelines import *

# Let PIL decode truncated image files instead of raising
# "OSError: image file is truncated" when loading images.
# NOTE: this assigns a module-level flag on `PIL.ImageFile`, so once this module is
# imported it applies to every PIL image load in the process, not only to this pipeline.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

def load_tars_as_wds(
    cfg        : DictConfig,
    urls       : Union[str, List[str]],
    resampled  : bool = False,
    epoch_size : Union[int, None] = None,
    cache_dir  : Union[str, None] = None,
    train      : bool = True,
):
    '''
    Build a WebDataset pipeline over tar shards and run every sample through the
    project's corruption filter, decoding, suppression/filter stages and final formatter.

    Args:
        cfg: Dataset config. Optional thresholds/switches are read with `cfg.get(...)`
            (missing keys fall back to the defaults written in the calls below), and the
            whole config is also handed to `apply_example_formatter`.
        urls: A single URL/pattern string or a list of them; expanded to a list of URL
            strings by `expand_urls`.
        resampled: Forwarded to `wds.WebDataset(resampled=...)`.
        epoch_size: If not None, the pipeline is wrapped with `.with_epoch(epoch_size)`.
        cache_dir: Forwarded to `wds.WebDataset(cache_dir=...)`.
        train: If True, an additional sample-level shuffle buffer of size 100 is added.

    Returns:
        The composed webdataset pipeline (the object returned by the last stage applied).
    '''
    urls = expand_urls(urls)  # to list of URL strings

    dataset : wds.WebDataset = wds.WebDataset(
            urls,
            nodesplitter = wds.split_by_node,
            # NOTE(review): shard shuffling is enabled even when `train=False`, so the
            # evaluation sample order is not deterministic — confirm this is intended.
            shardshuffle = True,
            resampled    = resampled,
            cache_dir    = cache_dir,
        )
    if train:
        # Sample-level shuffling, on top of the shard-level shuffling above.
        dataset = dataset.shuffle(100)

    # A lot of processes to initialize the dataset. Check the pipeline generator function for more details.
    # The order of the pipeline is important, since some of the processes depend on previous ones.
    dataset = apply_corrupt_filter(dataset)
    # Decode images to uint8 RGB and expose any of jpg/jpeg/png under the single key 'jpg'.
    dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png')
    dataset = apply_multi_ppl_splitter(dataset)
    dataset = apply_keys_adapter(dataset)  #* This adapter is only in HSMR's design, not in the baseline.
    dataset = apply_bad_pgt_params_nan_suppressor(dataset)
    dataset = apply_bad_pgt_params_kp2d_err_suppressor(dataset, cfg.get('suppress_pgt_params_kp2d_err_thresh', 0.0))
    dataset = apply_bad_pgt_params_pve_max_suppressor(dataset, cfg.get('suppress_pgt_params_pve_max_thresh', 0.0))
    dataset = apply_bad_kp_suppressor(dataset, cfg.get('suppress_kp_conf_thresh', 0.0))
    dataset = apply_bad_betas_suppressor(dataset, cfg.get('suppress_betas_thresh', 0.0))
    # dataset = apply_bad_pose_suppressor(dataset, cfg.get('suppress_pose_thresh', 0.0))  # Not used in baseline, so not implemented.
    dataset = apply_params_synchronizer(dataset, cfg.get('poses_betas_simultaneous', False))
    # dataset = apply_no_pose_filter(dataset, cfg.get('no_pose_filter', False))  # Not used in baseline, so not implemented.
    dataset = apply_insuff_kp_filter(dataset, cfg.get('filter_insufficient_kp_cnt', 4), cfg.get('suppress_insufficient_kp_thresh', 0.0))
    dataset = apply_bbox_size_filter(dataset, cfg.get('filter_bbox_size_thresh', None))
    dataset = apply_reproj_err_filter(dataset, cfg.get('filter_reproj_err_thresh', 0.0))
    dataset = apply_invalid_betas_regularizer(dataset, cfg.get('regularize_invalid_betas', False))

    # Final preprocess / format of the data. (Consider extracting the augmentation process.)
    dataset = apply_example_formatter(dataset, cfg)

    if epoch_size is not None:
        dataset = dataset.with_epoch(epoch_size)
    return dataset