File size: 3,154 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md

from imaginaire.datasets.paired_videos import Dataset as VideoDataset


class Dataset(VideoDataset):
    r"""Paired image dataset for use in pix2pixHD, SPADE.

    Built on top of the paired video dataset with a fixed sequence length
    of 1: every filename of every sequence becomes its own sample.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
        is_test (bool): Test-mode flag, forwarded to the parent video
            dataset.
    """

    def __init__(self, cfg, is_inference=False, is_test=False):
        # Set before the parent constructor runs (kept in the original
        # order; presumably the parent reads it during init — TODO confirm).
        self.paired = True
        super().__init__(cfg, is_inference,
                         sequence_length=1,
                         is_test=is_test)
        # Overrides whatever the parent set, marking this as an image dataset.
        self.is_video_dataset = False

    def _create_mapping(self):
        r"""Creates a flat mapping from sample idx to a single-file key in LMDB.

        Walks ``self.sequence_lists`` (one entry per LMDB root, each mapping
        a sequence name to its filenames) and emits one key per filename.
        Also stores the result on ``self.mapping`` and ``self.epoch_length``.

        Returns:
            (tuple):
              - self.mapping (list): List mapping idx to key. Each key is a
                dict with ``lmdb_root``, ``lmdb_idx``, ``sequence_name`` and
                ``filenames`` (a one-element list).
              - self.epoch_length (int): Number of samples in an epoch,
                i.e. the total number of filenames across all LMDBs.
        """
        idx_to_key = []
        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
            for sequence_name, filenames in sequence_list.items():
                for filename in filenames:
                    idx_to_key.append({
                        'lmdb_root': self.lmdb_roots[lmdb_idx],
                        'lmdb_idx': lmdb_idx,
                        'sequence_name': sequence_name,
                        # Wrapped in a list to match the video key format.
                        'filenames': [filename],
                    })
        self.mapping = idx_to_key
        self.epoch_length = len(self.mapping)
        return self.mapping, self.epoch_length

    def _sample_keys(self, index):
        r"""Gets files to load for this sample.

        Args:
            index (int): Index in [0, len(dataset)).
        Returns:
            key (dict): The stored mapping entry itself (not a copy):
              - lmdb_root: Root of the chosen LMDB dataset.
              - lmdb_idx (int): Chosen LMDB dataset root index.
              - sequence_name (str): Chosen sequence in chosen dataset.
              - filenames (list of str): Single chosen filename in chosen
                sequence.
        Raises:
            AssertionError: If ``self.sequence_length`` is not 1.
        """
        # Internal invariant: image samples are always single frames.
        assert self.sequence_length == 1, \
            'Image dataset can only have sequence length = 1, not %d' % (
                self.sequence_length)
        return self.mapping[index]

    def set_sequence_length(self, sequence_length):
        r"""Set the length of sequence you want as output from dataloader.

        No-op for this image loader: samples always have sequence length 1,
        and ``self.sequence_length`` is left unchanged.

        Args:
            sequence_length (int): Length of output sequences (ignored).
        """
        pass

    def set_inference_sequence_idx(self, index):
        r"""Get frames from this sequence during inference.

        Overridden from super as this is not applicable for images.

        Args:
            index (int): Index of inference sequence.
        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError('Image dataset does not have sequences.')

    def num_inference_sequences(self):
        r"""Number of sequences available for inference.

        Overridden from super as this is not applicable for images.

        Returns:
            (int)
        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError('Image dataset does not have sequences.')