# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import os

from datasets.transforms import get_pair_transforms
from PIL import Image
from torch.utils.data import Dataset


def load_image(impath):
    """Open the image at impath as a PIL Image."""
    return Image.open(impath)


def load_pairs_from_cache_file(fname, root=""):
    assert os.path.isfile(
        fname
    ), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, "r") as fid:
        lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, l.split()[0]), os.path.join(root, l.split()[1]))
        for l in lines
    ]
    return pairs


def load_pairs_from_list_file(fname, root=""):
    assert os.path.isfile(
        fname
    ), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, "r") as fid:
        lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, l + "_1.jpg"), os.path.join(root, l + "_2.jpg"))
        for l in lines
        if not l.startswith("#")
    ]
    return pairs

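# A hypothetical listing.txt excerpt for the "_crops" datasets (file names
# made up for illustration): comment lines are skipped and each remaining
# line is a pair prefix,
#
#   # MegaDepth crops
#   0001/crop_000042
#   0001/crop_000043
#
# which load_pairs_from_list_file expands to pairs such as
# ("<root>/0001/crop_000042_1.jpg", "<root>/0001/crop_000042_2.jpg").
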

def write_cache_file(fname, pairs, root=""):
    if len(root) > 0:
        if not root.endswith("/"):
            root += "/"
        assert os.path.isdir(root)
    s = ""
    for im1, im2 in pairs:
        if len(root) > 0:
            assert im1.startswith(root), im1
            assert im2.startswith(root), im2
        s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :])
    with open(fname, "w") as fid:
        fid.write(s[:-1])

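# A minimal sketch of the round trip between write_cache_file and
# load_pairs_from_cache_file (the paths below are made up):
#
#   pairs = [("/data/scene0/img_1.jpeg", "/data/scene0/img_2.jpeg")]
#   write_cache_file("/data/pairs.txt", pairs, root="/data/")
#   # /data/pairs.txt now holds the line "scene0/img_1.jpeg scene0/img_2.jpeg"
#   assert load_pairs_from_cache_file("/data/pairs.txt", root="/data/") == pairs
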

def parse_and_cache_all_pairs(dname, data_dir="./data/"):
    if dname == "habitat_release":
        dirname = os.path.join(data_dir, "habitat_release")
        assert os.path.isdir(dirname), (
            "cannot find folder for habitat_release pairs: " + dirname
        )
        cache_file = os.path.join(dirname, "pairs.txt")
        assert not os.path.isfile(cache_file), (
            "cache file already exists: " + cache_file
        )

        print("Parsing pairs for dataset: " + dname)
        pairs = []
        for root, dirs, files in os.walk(dirname):
            if "val" in root:
                continue
            dirs.sort()
            pairs += [
                (
                    os.path.join(root, f),
                    os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"),
                )
                for f in sorted(files)
                if f.endswith("_1.jpeg")
            ]
        print("Found {:,} pairs".format(len(pairs)))
        print("Writing cache to: " + cache_file)
        write_cache_file(cache_file, pairs, root=dirname)

    else:
        raise NotImplementedError("Unknown dataset: " + dname)


def dnames_to_image_pairs(dnames, data_dir="./data/"):
    """
    dnames: list of datasets with image pairs, separated by +
    """
    all_pairs = []
    for dname in dnames.split("+"):
        if dname == "habitat_release":
            dirname = os.path.join(data_dir, "habitat_release")
            assert os.path.isdir(dirname), (
                "cannot find folder for habitat_release pairs: " + dirname
            )
            cache_file = os.path.join(dirname, "pairs.txt")
            assert os.path.isfile(cache_file), (
                "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "
                + cache_file
            )
            pairs = load_pairs_from_cache_file(cache_file, root=dirname)
        elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]:
            dirname = os.path.join(data_dir, dname + "_crops")
            assert os.path.isdir(
                dirname
            ), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
            list_file = os.path.join(dirname, "listing.txt")
            assert os.path.isfile(
                list_file
            ), "cannot find list file for {:s} pairs, see instructions. {:s}".format(
                dname, list_file
            )
            pairs = load_pairs_from_list_file(list_file, root=dirname)
        else:
            # without this branch, an unknown dname would leave `pairs` unbound
            # (or silently reuse the previous dataset's pairs)
            raise NotImplementedError("Unknown dataset: " + dname)
        print("  {:s}: {:,} pairs".format(dname, len(pairs)))
        all_pairs += pairs
    if "+" in dnames:
        print(" Total: {:,} pairs".format(len(all_pairs)))
    return all_pairs


class PairsDataset(Dataset):
    def __init__(
        self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/"
    ):
        super().__init__()
        self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
        self.transforms = get_pair_transforms(
            transform_str=trfs, totensor=totensor, normalize=normalize
        )

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, index):
        im1path, im2path = self.image_pairs[index]
        im1 = load_image(im1path)
        im2 = load_image(im2path)
        if self.transforms is not None:
            im1, im2 = self.transforms(im1, im2)
        return im1, im2

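# Example usage, a minimal sketch (assumes the habitat_release cache file has
# already been built, e.g. via the command-line entry point below):
#
#   from torch.utils.data import DataLoader
#   dataset = PairsDataset("habitat_release", data_dir="./data/")
#   loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
#   for im1, im2 in loader:
#       ...  # im1 and im2 are batches of normalized image tensors
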

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute and cache the list of pairs for a given dataset"
    )
    parser.add_argument(
        "--data_dir", default="./data/", type=str, help="path where data are stored"
    )
    parser.add_argument(
        "--dataset", default="habitat_release", type=str, help="name of the dataset"
    )
    args = parser.parse_args()
    parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
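
# Example invocation (assuming this file lives at datasets/pairs_dataset.py,
# which is not stated here):
#   python -m datasets.pairs_dataset --data_dir ./data/ --dataset habitat_release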