# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import random

from imaginaire.datasets.base import BaseDataset


class Dataset(BaseDataset):
    r"""Unpaired image dataset for use in MUNIT.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
        is_test (bool): In test mode?
    """

    def __init__(self, cfg, is_inference=False, is_test=False):
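        # Flag the dataset as unpaired: each data type (domain) is indexed
        # and sampled independently of the others.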
        self.paired = False
        super(Dataset, self).__init__(cfg, is_inference, is_test)

    def _create_mapping(self):
        r"""Creates mapping from idx to key in LMDB.

        Returns:
            (tuple):
              - self.mapping (dict): Dict with data type as key mapping idx to
              LMDB key.
              - self.epoch_length (int): Number of samples in an epoch.
        """
        idx_to_key = {}
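        # self.sequence_lists[i] holds, for the i-th LMDB root, a dict of
        # data_type -> {sequence_name: [filenames]}. Flatten it into one
        # list of key dicts per data type, e.g. (hypothetical data types):
        #   self.mapping = {
        #       'images_a': [{'lmdb_root': ..., 'lmdb_idx': 0,
        #                     'sequence_name': ..., 'filename': ...}, ...],
        #       'images_b': [...],
        #   }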
        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
            for data_type, data_type_sequence_list in sequence_list.items():
                if data_type not in idx_to_key:
                    idx_to_key[data_type] = []
                for sequence_name, filenames in data_type_sequence_list.items():
                    for filename in filenames:
                        idx_to_key[data_type].append({
                            'lmdb_root': self.lmdb_roots[lmdb_idx],
                            'lmdb_idx': lmdb_idx,
                            'sequence_name': sequence_name,
                            'filename': filename,
                        })
        self.mapping = idx_to_key
        self.epoch_length = max(
            len(lmdb_keys) for lmdb_keys in self.mapping.values())
        return self.mapping, self.epoch_length

    def _sample_keys(self, index):
        r"""Gets files to load for this sample.

        Args:
            index (int): Index in [0, len(dataset) - 1].
        Returns:
            keys (dict): Maps each data type to an LMDB key dict with:
                lmdb_root (str): Root path of the chosen LMDB dataset.
                lmdb_idx (int): Index of the chosen LMDB dataset root.
                sequence_name (str): Chosen sequence in chosen dataset.
                filename (str): Chosen filename in chosen sequence.
        """
        keys = {}
        for data_type in self.data_types:
            lmdb_keys = self.mapping[data_type]
            if self.is_inference:
                # Modulo ensures valid indexing in case A and B have different
                # number of files.
                keys[data_type] = lmdb_keys[index % len(lmdb_keys)]
            else:
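                # Training: draw a random file from each domain so the
                # pairing of A and B samples changes every epoch.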
                keys[data_type] = random.choice(lmdb_keys)
        return keys

    def __getitem__(self, index):
        r"""Gets selected files.

        Args:
            index (int): Index into dataset.
        Returns:
            data (dict): Dict with all chosen data_types.
        """
        # Select a sample from the available data.
        keys_per_data_type = self._sample_keys(index)

        # Get keys and lmdbs.
        keys, lmdbs = {}, {}
        for data_type in self.dataset_data_types:
            # Unpack keys.
            lmdb_idx = keys_per_data_type[data_type]['lmdb_idx']
            sequence_name = keys_per_data_type[data_type]['sequence_name']
            filename = keys_per_data_type[data_type]['filename']
            keys[data_type] = '%s/%s' % (sequence_name, filename)
            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]

        # Load all data for this index.
        data = self.load_from_dataset(keys, lmdbs)

        # Apply ops pre augmentation.
        data = self.apply_ops(data, self.pre_aug_ops)

        # Do augmentations for images.
        data, is_flipped = self.perform_augmentation(
            data, paired=False, augment_ops=self.augmentor.augment_ops)

        # Apply ops post augmentation.
        data = self.apply_ops(data, self.post_aug_ops)
        data = self.apply_ops(
            data, self.full_data_post_aug_ops, full_data=True)

        # Convert images to tensor.
        data = self.to_tensor(data)

        # Remove any extra dimensions.
        for data_type in self.image_data_types:
            data[data_type] = data[data_type][0]

        # Package output.
        data['is_flipped'] = is_flipped
        data['key'] = keys_per_data_type

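        # data now holds one image tensor per data type (e.g. 'images_a'
        # and 'images_b' in a typical two-domain MUNIT config; actual
        # names come from cfg), plus the 'is_flipped' flag and the
        # sampled 'key' metadata.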
        return data
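

# A minimal usage sketch (not part of the library). It assumes a standard
# imaginaire MUNIT config; the YAML path is a placeholder, and 'images_a' /
# 'images_b' are the data type names used by typical MUNIT configs.
#
#     from imaginaire.config import Config
#     from imaginaire.datasets.unpaired_images import Dataset
#
#     cfg = Config('configs/projects/munit/your_experiment.yaml')
#     dataset = Dataset(cfg, is_inference=False)
#     sample = dataset[0]
#     sample['images_a'].shape, sample['images_b'].shape  # typically (C, H, W)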