File size: 4,228 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import importlib

import torch
import torch.distributed as dist

from imaginaire.utils.distributed import master_only_print as print


def _get_train_and_val_dataset_objects(cfg):
    r"""Build the training and validation dataset objects.

    The training dataset is the ``Dataset`` class of the module named by
    ``cfg.data.type``. If ``cfg.data.val`` has its own ``type``, then its
    ``type``, ``input_types`` and ``input_image`` entries are copied onto
    ``cfg.data`` (mutating ``cfg`` in place). The validation dataset is then
    built from that module. Otherwise the training module is reused for
    validation.

    Args:
        cfg (obj): Global configuration file.

    Returns:
        (tuple):
          - train_dataset (obj): PyTorch training dataset object.
          - val_dataset (obj): PyTorch validation dataset object.
    """
    train_module = importlib.import_module(cfg.data.type)
    train_dataset = train_module.Dataset(cfg, is_inference=False)

    val_module = train_module
    val_cfg = cfg.data.val
    if hasattr(val_cfg, 'type'):
        # NOTE: this overwrites cfg.data in place. get_train_and_val_dataloader
        # reads cfg.data.type after this call and therefore sees the val type.
        for attr_name in ('type', 'input_types', 'input_image'):
            setattr(cfg.data, attr_name, getattr(val_cfg, attr_name))
        val_module = importlib.import_module(cfg.data.type)
    val_dataset = val_module.Dataset(cfg, is_inference=True)

    print('Train dataset length:', len(train_dataset))
    print('Val dataset length:', len(val_dataset))
    return train_dataset, val_dataset


def _get_data_loader(cfg, dataset, batch_size, not_distributed=False,
                     shuffle=True, drop_last=True, seed=0):
    r"""Return data loader.

    Args:
        cfg (obj): Global configuration file. ``cfg.data.num_workers``
            (default 8) and ``cfg.data.persistent_workers`` (default False)
            are read if present.
        dataset (obj): PyTorch dataset object.
        batch_size (int): Batch size.
        not_distributed (bool): Do not use distributed samplers. A
            distributed sampler is also skipped when ``torch.distributed``
            has not been initialized.
        shuffle (bool): Whether to shuffle the samples. Honored both by the
            plain loader and by the distributed sampler.
        drop_last (bool): Whether to drop the last incomplete batch.
        seed (int): Shuffling seed for the distributed sampler.

    Return:
        (obj): Data loader.
    """
    not_distributed = not_distributed or not dist.is_initialized()
    if not_distributed:
        sampler = None
    else:
        # DistributedSampler shuffles by default. Forward ``shuffle`` so that
        # shuffle=False (validation/test) gives a deterministic order.
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=shuffle, seed=seed)
    num_workers = getattr(cfg.data, 'num_workers', 8)
    persistent_workers = getattr(cfg.data, 'persistent_workers', False)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler;
        # in the distributed case the sampler does the shuffling.
        shuffle=shuffle and (sampler is None),
        sampler=sampler,
        pin_memory=True,
        num_workers=num_workers,
        drop_last=drop_last,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=persistent_workers if num_workers > 0 else False,
    )
    return data_loader


def get_train_and_val_dataloader(cfg, seed=0):
    r"""Build the training and validation data loaders.

    Args:
        cfg (obj): Global configuration file.
        seed (int): Seed forwarded to the (distributed) data loaders.

    Returns:
        (tuple):
          - train_data_loader (obj): Train data loader.
          - val_data_loader (obj): Val data loader.
    """
    train_dataset, val_dataset = _get_train_and_val_dataset_objects(cfg)
    train_data_loader = _get_data_loader(
        cfg, train_dataset, cfg.data.train.batch_size,
        drop_last=True, seed=seed)

    # Validation skips the distributed sampler when explicitly requested or
    # when the dataset type name contains 'video'. cfg.data.type is read
    # after the call above, which may have replaced it with the val type.
    requested_flag = getattr(cfg.data, 'val_data_loader_not_distributed', False)
    val_not_distributed = 'video' in cfg.data.type or requested_flag

    val_data_loader = _get_data_loader(
        cfg, val_dataset, cfg.data.val.batch_size,
        not_distributed=val_not_distributed,
        shuffle=False,
        drop_last=getattr(cfg.data.val, 'drop_last', False),
        seed=seed)
    return train_data_loader, val_data_loader


def _get_test_dataset_object(cfg):
    r"""Instantiate the dataset used for testing.

    The ``Dataset`` class is taken from the module named by
    ``cfg.test_data.type`` and constructed in inference + test mode.

    Args:
        cfg (obj): Global configuration file.

    Returns:
        (obj): PyTorch dataset object.
    """
    dataset_cls = importlib.import_module(cfg.test_data.type).Dataset
    return dataset_cls(cfg, is_inference=True, is_test=True)


def get_test_dataloader(cfg):
    r"""Return the data loader for testing.

    Args:
        cfg (obj): Global configuration file. ``cfg.test_data.test.drop_last``
            (default False) controls whether the final partial batch is
            dropped.

    Returns:
        (obj): Test data loader. It may not contain the ground truth.
    """
    test_dataset = _get_test_dataset_object(cfg)
    not_distributed = getattr(
        cfg.test_data, 'val_data_loader_not_distributed', False)
    # Dataset types whose name contains 'video' never use a distributed sampler.
    not_distributed = 'video' in cfg.test_data.type or not_distributed
    # _get_data_loader defaults to drop_last=True, which would silently skip
    # the final partial batch of test samples. Default to keeping it, as the
    # validation loader does.
    test_data_loader = _get_data_loader(
        cfg, test_dataset, cfg.test_data.test.batch_size, not_distributed,
        shuffle=False,
        drop_last=getattr(cfg.test_data.test, 'drop_last', False))
    return test_data_loader