# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
import os

from PIL import Image
from torch.utils.data import Dataset

from datasets.transforms import get_pair_transforms


def load_image(impath):
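    """Open an image from disk with PIL (decoding is deferred until the data is used)."""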
return Image.open(impath)


def load_pairs_from_cache_file(fname, root=""):
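    """Read image pairs from a cache file: one pair per line, two
    whitespace-separated paths interpreted relative to root."""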
assert os.path.isfile(
fname
), "cannot parse pairs from {:s}, file does not exist".format(fname)
with open(fname, "r") as fid:
lines = fid.read().strip().splitlines()
    pairs = []
    for line in lines:
        im1, im2 = line.split()[:2]
        pairs.append((os.path.join(root, im1), os.path.join(root, im2)))
return pairs


def load_pairs_from_list_file(fname, root=""):
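    """Read image pairs from a listing file: each non-comment line is a prefix,
    and the pair is (<prefix>_1.jpg, <prefix>_2.jpg) under root."""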
assert os.path.isfile(
fname
), "cannot parse pairs from {:s}, file does not exist".format(fname)
with open(fname, "r") as fid:
lines = fid.read().strip().splitlines()
    pairs = [
        (os.path.join(root, line + "_1.jpg"), os.path.join(root, line + "_2.jpg"))
        for line in lines
        if not line.startswith("#")
    ]
return pairs


def write_cache_file(fname, pairs, root=""):
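    """Write (im1, im2) pairs to fname, one pair per line, with the paths
    stored relative to root."""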
if len(root) > 0:
if not root.endswith("/"):
root += "/"
assert os.path.isdir(root)
s = ""
for im1, im2 in pairs:
if len(root) > 0:
assert im1.startswith(root), im1
assert im2.startswith(root), im2
s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :])
with open(fname, "w") as fid:
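        # write without the trailing newline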
fid.write(s[:-1])


def parse_and_cache_all_pairs(dname, data_dir="./data/"):
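    """Scan the dataset folder for image pairs and write them to a pairs.txt
    cache file (currently only implemented for 'habitat_release')."""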
if dname == "habitat_release":
dirname = os.path.join(data_dir, "habitat_release")
assert os.path.isdir(dirname), (
"cannot find folder for habitat_release pairs: " + dirname
)
cache_file = os.path.join(dirname, "pairs.txt")
assert not os.path.isfile(cache_file), (
"cache file already exists: " + cache_file
)
print("Parsing pairs for dataset: " + dname)
pairs = []
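        # walk the dataset tree and pair each '*_1.jpeg' with its '*_2.jpeg' sibling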
for root, dirs, files in os.walk(dirname):
if "val" in root:
continue
dirs.sort()
pairs += [
(
os.path.join(root, f),
os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"),
)
for f in sorted(files)
if f.endswith("_1.jpeg")
]
print("Found {:,} pairs".format(len(pairs)))
print("Writing cache to: " + cache_file)
write_cache_file(cache_file, pairs, root=dirname)
else:
raise NotImplementedError("Unknown dataset: " + dname)


def dnames_to_image_pairs(dnames, data_dir="./data/"):
"""
dnames: list of datasets with image pairs, separated by +
"""
all_pairs = []
for dname in dnames.split("+"):
if dname == "habitat_release":
dirname = os.path.join(data_dir, "habitat_release")
assert os.path.isdir(dirname), (
"cannot find folder for habitat_release pairs: " + dirname
)
cache_file = os.path.join(dirname, "pairs.txt")
assert os.path.isfile(cache_file), (
"cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "
+ cache_file
)
pairs = load_pairs_from_cache_file(cache_file, root=dirname)
elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]:
dirname = os.path.join(data_dir, dname + "_crops")
assert os.path.isdir(
dirname
), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
list_file = os.path.join(dirname, "listing.txt")
assert os.path.isfile(
list_file
), "cannot find list file for {:s} pairs, see instructions. {:s}".format(
dname, list_file
)
            pairs = load_pairs_from_list_file(list_file, root=dirname)
        else:
            raise NotImplementedError("Unknown dataset: " + dname)
        print(" {:s}: {:,} pairs".format(dname, len(pairs)))
all_pairs += pairs
if "+" in dnames:
print(" Total: {:,} pairs".format(len(all_pairs)))
return all_pairs


class PairsDataset(Dataset):
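    """Torch dataset over image pairs from one or more pair datasets.

    Example (a minimal sketch, assuming the habitat_release cache file has
    already been created by running this script with --dataset habitat_release):
        dataset = PairsDataset("habitat_release", data_dir="./data/")
        im1, im2 = dataset[0]
    """
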
def __init__(
self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/"
):
super().__init__()
self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
self.transforms = get_pair_transforms(
transform_str=trfs, totensor=totensor, normalize=normalize
)

    def __len__(self):
return len(self.image_pairs)

    def __getitem__(self, index):
im1path, im2path = self.image_pairs[index]
im1 = load_image(im1path)
im2 = load_image(im2path)
if self.transforms is not None:
im1, im2 = self.transforms(im1, im2)
return im1, im2


if __name__ == "__main__":
import argparse
    parser = argparse.ArgumentParser(
        description="Compute and cache the list of pairs for a given dataset"
    )
parser.add_argument(
"--data_dir", default="./data/", type=str, help="path where data are stored"
)
parser.add_argument(
"--dataset", default="habitat_release", type=str, help="name of the dataset"
)
args = parser.parse_args()
parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)