# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import os

from PIL import Image
from torch.utils.data import Dataset

from datasets.transforms import get_pair_transforms


def load_image(impath):
    return Image.open(impath)


def load_pairs_from_cache_file(fname, root=""):
    # Each line of the cache file holds two whitespace-separated image paths,
    # stored relative to root.
    assert os.path.isfile(
        fname
    ), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, "r") as fid:
        lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, l.split()[0]), os.path.join(root, l.split()[1]))
        for l in lines
    ]
    return pairs


def load_pairs_from_list_file(fname, root=""):
    # Each non-comment line holds a pair prefix; the two images of the pair are
    # <prefix>_1.jpg and <prefix>_2.jpg.
    assert os.path.isfile(
        fname
    ), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, "r") as fid:
        lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, l + "_1.jpg"), os.path.join(root, l + "_2.jpg"))
        for l in lines
        if not l.startswith("#")
    ]
    return pairs


def write_cache_file(fname, pairs, root=""):
    # Write one "<im1> <im2>" line per pair, stripping the root prefix so the
    # cache stays valid if the dataset folder is moved.
    if len(root) > 0:
        if not root.endswith("/"):
            root += "/"
        assert os.path.isdir(root)
    s = ""
    for im1, im2 in pairs:
        if len(root) > 0:
            assert im1.startswith(root), im1
            assert im2.startswith(root), im2
        s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :])
    with open(fname, "w") as fid:
        fid.write(s[:-1])  # drop the trailing newline


def parse_and_cache_all_pairs(dname, data_dir="./data/"):
    if dname == "habitat_release":
        dirname = os.path.join(data_dir, "habitat_release")
        assert os.path.isdir(dirname), (
            "cannot find folder for habitat_release pairs: " + dirname
        )
        cache_file = os.path.join(dirname, "pairs.txt")
        assert not os.path.isfile(cache_file), (
            "cache file already exists: " + cache_file
        )
        print("Parsing pairs for dataset: " + dname)
        pairs = []
        for root, dirs, files in os.walk(dirname):
            if "val" in root:
                continue
            dirs.sort()  # sort in-place so os.walk visits subfolders deterministically
            pairs += [
                (
                    os.path.join(root, f),
                    os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"),
                )
                for f in sorted(files)
                if f.endswith("_1.jpeg")
            ]
        print("Found {:,} pairs".format(len(pairs)))
        print("Writing cache to: " + cache_file)
        write_cache_file(cache_file, pairs, root=dirname)
    else:
        raise NotImplementedError("Unknown dataset: " + dname)


def dnames_to_image_pairs(dnames, data_dir="./data/"):
    """
    dnames: list of datasets with image pairs, separated by +
    """
    all_pairs = []
    for dname in dnames.split("+"):
        if dname == "habitat_release":
            dirname = os.path.join(data_dir, "habitat_release")
            assert os.path.isdir(dirname), (
                "cannot find folder for habitat_release pairs: " + dirname
            )
            cache_file = os.path.join(dirname, "pairs.txt")
            assert os.path.isfile(cache_file), (
                "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "
                + cache_file
            )
            pairs = load_pairs_from_cache_file(cache_file, root=dirname)
        elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]:
            dirname = os.path.join(data_dir, dname + "_crops")
            assert os.path.isdir(
                dirname
            ), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
            list_file = os.path.join(dirname, "listing.txt")
            assert os.path.isfile(
                list_file
            ), "cannot find list file for {:s} pairs, see instructions. {:s}".format(
                dname, list_file
            )
            pairs = load_pairs_from_list_file(list_file, root=dirname)
        else:
            # fail early with a clear message instead of a NameError on `pairs`
            raise NotImplementedError("Unknown dataset: " + dname)
        print(" {:s}: {:,} pairs".format(dname, len(pairs)))
        all_pairs += pairs
    if "+" in dnames:
        print(" Total: {:,} pairs".format(len(all_pairs)))
    return all_pairs


class PairsDataset(Dataset):
    def __init__(
        self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/"
    ):
        super().__init__()
        self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
        self.transforms = get_pair_transforms(
            transform_str=trfs, totensor=totensor, normalize=normalize
        )

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, index):
        im1path, im2path = self.image_pairs[index]
        im1 = load_image(im1path)
        im2 = load_image(im2path)
        if self.transforms is not None:
            # pair transforms are applied jointly so both views stay consistent
            im1, im2 = self.transforms(im1, im2)
        return im1, im2


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        prog="Computing and caching list of pairs for a given dataset"
    )
    parser.add_argument(
        "--data_dir", default="./data/", type=str, help="path where data are stored"
    )
    parser.add_argument(
        "--dataset", default="habitat_release", type=str, help="name of the dataset"
    )
    args = parser.parse_args()
    parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)