白鹭先生
init
abd2a81
raw
history blame
26.8 kB
import fnmatch
import os.path
import pathlib
import sys
import time
import shapely.geometry
import multiprocess
import itertools
import skimage.io
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data
import torchvision
from lydorn_utils import run_utils, image_utils, polygon_utils, geo_utils
from lydorn_utils import print_utils
from lydorn_utils import python_utils
from torch_lydorn.torchvision.datasets import utils
# Compact per-city table: (city name, fold, per-channel image mean, per-channel image std).
# NOTE(review): mean/std look like RGB statistics of images scaled to [0, 1] -- confirm.
_CITY_STATS = (
    ("bloomington", "test", [0.44583929, 0.46205078, 0.35783887], [0.18212699, 0.17152641, 0.16157062]),
    ("bellingham", "test", [0.3766195, 0.391402, 0.32659722], [0.18134978, 0.16412577, 0.16369793]),
    ("innsbruck", "test", [0.41375683, 0.41818116, 0.38940192], [0.16616156, 0.14364722, 0.13317743]),
    ("sfo", "test", [0.59388761, 0.61522012, 0.54348289], [0.25730708, 0.23301019, 0.23707742]),
    ("tyrol-e", "test", [0.44171042, 0.48147037, 0.44642358], [0.1808623, 0.15437789, 0.15102051]),
    ("austin", "train", [0.39584444, 0.40599795, 0.38298687], [0.17341954, 0.16856597, 0.16360443]),
    ("chicago", "train", [0.4055142, 0.42844002, 0.38229637], [0.2133328, 0.20827106, 0.20132315]),
    ("kitsap", "train", [0.34717916, 0.37854108, 0.32571001], [0.17048794, 0.14537676, 0.13466496]),
    ("tyrol-w", "train", [0.39704218, 0.4545488, 0.4321427], [0.19484766, 0.1742585, 0.15186383]),
    ("vienna", "train", [0.47861977, 0.46878486, 0.44043111], [0.22614806, 0.19949128, 0.19524506]),
)

# Metadata for every city of the dataset, keyed by city name.
# Every city shares the same pixel size (0.3) and the same tile numbers (1 to 36 inclusive);
# each entry gets its own "numbers" list object.
CITY_METADATA_DICT = {
    city_name: {
        "fold": city_fold,
        "pixelsize": 0.3,
        "numbers": list(range(1, 37)),
        "mean": channel_mean,
        "std": channel_std,
    }
    for city_name, city_fold, channel_mean, channel_std in _CITY_STATS
}

# Name of the sub-directory holding the images of a fold.
IMAGE_DIRNAME = "images"
# Tile name template, filled with the city name and the tile number.
IMAGE_NAME_FORMAT = "{city}{number}"
# Image file name template: the tile name followed by the ".tif" extension.
IMAGE_FILENAME_FORMAT = f"{IMAGE_NAME_FORMAT}.tif"
class InriaAerial(torch.utils.data.Dataset):
    """
    Inria Aerial Image Dataset

    Two serving modes are supported:
    - pre_process=True: every tile of the fold is cut into patches once, the patches are saved
      on disk as "data.NNNNNN.pt" files and samples are those patches.
    - pre_process=False: samples are whole raw tiles loaded on-the-fly.
    """

    def __init__(self, root: str, fold: str="train", pre_process: bool=True, tile_filter=None, patch_size: int=None, patch_stride: int=None,
                 pre_transform=None, transform=None, small: bool=False, pool_size: int=1, raw_dirname: str="raw", processed_dirname: str="processed",
                 gt_source: str="disk", gt_type: str="npy", gt_dirname: str="gt_polygons", mask_only: bool=False):
        """
        @param root: Root directory of the dataset (contains the raw and processed directories)
        @param fold: Fold to use, matched against the "fold" field of CITY_METADATA_DICT
        @param pre_process: If True, the dataset will be pre-processed first, saving training patches on disk. If False, data will be served on-the-fly without any patching.
        @param tile_filter: Function to call on tile_info, if returns True, include that tile. If returns False, exclude that tile. Does not affect pre-processing.
        @param patch_size: Size of the patches. Must be set when pre_process is True (it is part of the processed directory name)
        @param patch_stride: Stride between patches, defaults to patch_size when None
        @param pre_transform: Callable applied to every patch before it is saved on disk (used when pre_process is True)
        @param transform: Callable applied to every sample when it is served. If None, samples are returned untransformed
        @param small: If True, use a small subset of the dataset (for testing)
        @param pool_size: Number of worker processes used for pre-processing
        @param raw_dirname: Name of the directory (inside root) holding the raw data
        @param processed_dirname: Prefix of the name of the directory (inside root) holding the pre-processed data
        @param gt_source: Can be "disk" for annotation that are on disk or "osm" to download from OSM (not implemented)
        @param gt_type: Type of annotation files on disk: can be "npy", "geojson" or "tif"
        @param gt_dirname: Name of directory with annotation files
        @param mask_only: If True, discard the RGB image, sample's "image" field is a single-channel binary mask of the polygons and there is no ground truth segmentation.
            This is to allow learning only the frame field from binary masks in order to polygonize binary masks
        """
        assert gt_source in {"disk", "osm"}, "gt_source should be disk or osm"
        assert gt_type in {"npy", "geojson", "tif"}, f"gt_type should be npy, geojson or tif, not {gt_type}"
        self.root = root
        self.fold = fold
        self.pre_process = pre_process
        self.tile_filter = tile_filter
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.pre_transform = pre_transform
        self.transform = transform
        self.small = small
        if self.small:
            print_utils.print_info("INFO: Using small version of the Inria dataset.")
        self.pool_size = pool_size
        self.raw_dirname = raw_dirname
        self.gt_source = gt_source
        self.gt_type = gt_type
        self.gt_dirname = gt_dirname
        self.mask_only = mask_only

        if self.gt_source == "disk":
            print_utils.print_info("INFO: annotations will be loaded from disk")
        elif self.gt_source == "osm":
            print_utils.print_info("INFO: annotations will be downloaded from OSM. "
                                   "Make sure you have an internet connection to the OSM servers!")

        if self.pre_process:
            # The processed directory name encodes every option that changes the saved patches,
            # so that different configurations do not overwrite each other.
            processed_dirname_extention = f"{processed_dirname}.source_{self.gt_source}.type_{self.gt_type}"
            if self.gt_dirname is not None:
                processed_dirname_extention += f".dirname_{self.gt_dirname}"
            if self.mask_only:
                processed_dirname_extention += f".mask_only_{int(self.mask_only)}"
            processed_dirname_extention += f".patch_size_{int(self.patch_size)}"
            self.processed_dirpath = os.path.join(self.root, processed_dirname_extention, self.fold)
            self.stats_filepath = os.path.join(self.processed_dirpath, "stats-small.pt" if self.small else "stats.pt")
            self.processed_flag_filepath = os.path.join(self.processed_dirpath,
                                                        "processed_flag-small" if self.small else "processed_flag")

            # Check if dataset has finished pre-processing by checking flag:
            if os.path.exists(self.processed_flag_filepath):
                # Process done, load stats
                self.stats = torch.load(self.stats_filepath)
            else:
                # Pre-process not finished, launch it on all tiles (the tile filter is not applied here):
                tile_info_list = self.get_tile_info_list(tile_filter=None)
                self.stats = self.process(tile_info_list)
                # Per-tile directories are created while processing; make sure the parent exists
                # even when there was no tile to process.
                os.makedirs(self.processed_dirpath, exist_ok=True)
                # Save stats
                torch.save(self.stats, self.stats_filepath)
                # Mark dataset as processed with flag
                pathlib.Path(self.processed_flag_filepath).touch()

            # Get processed_relative_paths with filter
            tile_info_list = self.get_tile_info_list(tile_filter=self.tile_filter)
            self.processed_relative_paths = self.get_processed_relative_paths(tile_info_list)
        else:
            # Setup data sample list
            self.tile_info_list = self.get_tile_info_list(tile_filter=self.tile_filter)

    def get_tile_info_list(self, tile_filter=None):
        """
        Build the list of tile_info dicts of the current fold.

        @param tile_filter: Optional predicate called on each tile_info, tiles for which it returns a falsy value are dropped
        @return: List of dicts with keys "city", "number", "pixelsize", "mean" and "std"
        """
        tile_info_list = []
        for city, info in CITY_METADATA_DICT.items():
            if info["fold"] != self.fold:
                continue
            if self.small:
                # Small version: the first 5 tile numbers and the last one
                numbers = [*info["numbers"][:5], info["numbers"][-1]]
            else:
                numbers = info["numbers"]
            for number in numbers:
                tile_info_list.append({
                    "city": city,
                    "number": number,
                    "pixelsize": info["pixelsize"],
                    "mean": np.array(info["mean"]),
                    "std": np.array(info["std"]),
                })
        if tile_filter is not None:
            # Use the filter that was passed in, not self.tile_filter
            tile_info_list = [tile_info for tile_info in tile_info_list if tile_filter(tile_info)]
        return tile_info_list

    def get_processed_relative_paths(self, tile_info_list):
        """
        List the pre-processed sample files of the given tiles.

        @param tile_info_list: Tiles to look up, each must have the "city" and "number" keys
        @return: Sorted list of "data.*.pt" file paths, relative to self.processed_dirpath
        """
        processed_relative_paths = []
        for tile_info in tile_info_list:
            processed_tile_relative_dirpath = os.path.join(tile_info['city'], f"{tile_info['number']:02d}")
            processed_tile_dirpath = os.path.join(self.processed_dirpath, processed_tile_relative_dirpath)
            sample_filenames = fnmatch.filter(os.listdir(processed_tile_dirpath), "data.*.pt")
            processed_relative_paths.extend(
                os.path.join(processed_tile_relative_dirpath, sample_filename) for sample_filename in sample_filenames)
        return sorted(processed_relative_paths)

    def process(self, tile_info_list):
        """
        Pre-process all given tiles with a pool of self.pool_size workers and aggregate their stats.

        @param tile_info_list: Tiles to pre-process
        @return: Dict of stats. When per-tile stats are available (and mask_only is False) it holds "class_freq":
            the mean of the per-tile class frequencies, weighted by the number of patches of each tile
        """
        with multiprocess.Pool(self.pool_size) as p:
            stats_all = list(
                tqdm(p.imap(self._process_one, tile_info_list), total=len(tile_info_list), desc="Process"))

        stats = {}
        if not self.mask_only:
            # Group stats by key, ignoring tiles that returned nothing.
            # (An explicit None test is used: filtering with None.__ne__ relies on the truthiness of
            # NotImplemented, which is deprecated.)
            stat_lists = {}
            for stats_one in stats_all:
                if stats_one is None:
                    continue
                for key, stat in stats_one.items():
                    stat_lists.setdefault(key, []).append(stat)

            # Aggregate stats
            if "class_freq" in stat_lists and "num" in stat_lists:
                class_freq_array = np.stack(stat_lists["class_freq"], axis=0)
                num_array = np.stack(stat_lists["num"], axis=0)
                if num_array.min() == 0:
                    raise ZeroDivisionError("num_array has some zeros values, cannot divide!")
                stats["class_freq"] = np.sum(class_freq_array*num_array[:, None], axis=0) / np.sum(num_array)

        return stats

    def load_raw_data(self, tile_info):
        """
        Load the image and the annotations of one tile from the raw directory.

        @param tile_info: Dict with at least the "city" and "number" keys
        @return: Dict with "image_filepath", "image" and, depending on gt_type, "gt_polygons" (list of shapely polygons)
            or "gt_polygons_image" (mask with a trailing channel axis). If the annotation file does not exist,
            "gt_polygons" is an empty list whatever the gt_type is.
        """
        raw_data = {}

        # Image:
        raw_data["image_filepath"] = os.path.join(self.root, self.raw_dirname, self.fold, IMAGE_DIRNAME,
                                                  IMAGE_FILENAME_FORMAT.format(city=tile_info["city"], number=tile_info["number"]))
        raw_data["image"] = skimage.io.imread(raw_data["image_filepath"])
        assert len(raw_data["image"].shape) == 3 and raw_data["image"].shape[2] == 3, f"image should have shape (H, W, 3), not {raw_data['image'].shape}..."

        # Annotations:
        if self.gt_source == "disk":
            gt_base_filepath = os.path.join(self.root, self.raw_dirname, self.fold, self.gt_dirname,
                                            IMAGE_NAME_FORMAT.format(city=tile_info["city"],
                                                                     number=tile_info["number"]))
            gt_filepath = gt_base_filepath + "." + self.gt_type
            if not os.path.exists(gt_filepath):
                raw_data["gt_polygons"] = []
                return raw_data
            if self.gt_type == "npy":
                np_gt_polygons = np.load(gt_filepath, allow_pickle=True)
                gt_polygons = []
                for np_gt_polygon in np_gt_polygons:
                    try:
                        # The coordinate columns are reversed before building the polygon
                        # (presumably (row, col) -> (x, y) -- TODO confirm)
                        gt_polygons.append(shapely.geometry.Polygon(np_gt_polygon[:, ::-1]))
                    except ValueError:
                        # Invalid polygon, continue without it
                        continue
                raw_data["gt_polygons"] = gt_polygons
            elif self.gt_type == "geojson":
                geojson = python_utils.load_json(gt_filepath)
                raw_data["gt_polygons"] = list(shapely.geometry.shape(geojson))
            elif self.gt_type == "tif":
                raw_data["gt_polygons_image"] = skimage.io.imread(gt_filepath)[:, :, None]
                assert len(raw_data["gt_polygons_image"].shape) == 3 and raw_data["gt_polygons_image"].shape[2] == 1, \
                    f"Mask should have shape (H, W, 1), not {raw_data['gt_polygons_image'].shape}..."
        elif self.gt_source == "osm":
            raise NotImplementedError(
                "Downloading from OSM is not implemented (takes too long to download, better download to disk first...).")

        return raw_data

    @staticmethod
    def _get_worker_index():
        """
        Return the numeric suffix of the current process name (e.g. 12 for "ForkPoolWorker-12"),
        or 0 when the name has no numeric suffix (e.g. "MainProcess").
        Used to give each worker its own progress bar position.
        """
        suffix = multiprocess.current_process().name.rpartition("-")[2]
        return int(suffix) if suffix.isdigit() else 0

    def _process_one(self, tile_info):
        """
        Pre-process one tile: cut it into patches, apply self.pre_transform and save each patch on disk.

        @param tile_info: Dict with at least the "city" and "number" keys
        @return: Dict of stats of the tile: "class_freq" (mean class frequency over the patches) and "num" (number of patches).
            Empty if mask_only is True or if the tile has no patch.
        @raise NotImplementedError: If patch_size is None, or if stats cannot be computed for the current gt_type
        """
        process_id = self._get_worker_index()

        # --- Init
        tile_name = IMAGE_NAME_FORMAT.format(city=tile_info["city"], number=tile_info["number"])
        processed_tile_relative_dirpath = os.path.join(tile_info['city'], f"{tile_info['number']:02d}")
        processed_tile_dirpath = os.path.join(self.processed_dirpath, processed_tile_relative_dirpath)
        processed_flag_filepath = os.path.join(processed_tile_dirpath, "processed_flag")
        stats_filepath = os.path.join(processed_tile_dirpath, "stats.pt")
        os.makedirs(processed_tile_dirpath, exist_ok=True)
        stats = {}

        # --- Check if tile has been processed already
        if os.path.exists(processed_flag_filepath):
            if not self.mask_only:
                stats = torch.load(stats_filepath)
            return stats

        # --- Only patched pre-processing is implemented. Fail before doing the expensive tile loading.
        if self.patch_size is None:
            raise NotImplementedError("patch_size is None")

        # --- Read data:
        raw_data = self.load_raw_data(tile_info)

        # --- Patch tiles
        patch_stride = self.patch_stride if self.patch_stride is not None else self.patch_size
        patch_boundingboxes = image_utils.compute_patch_boundingboxes(raw_data["image"].shape[0:2],
                                                                      stride=patch_stride,
                                                                      patch_res=self.patch_size)
        has_polygon_gt = self.gt_type == "npy" or self.gt_type == "geojson"
        class_freq_list = []
        for i, bbox in enumerate(tqdm(patch_boundingboxes, desc=f"Patching {tile_name}", leave=False, position=process_id)):
            # bbox is used as (rowmin, colmin, rowmax, colmax)
            sample = {
                "image_filepath": raw_data["image_filepath"],
                "name": f"{tile_name}.rowmin_{bbox[0]}_colmin_{bbox[1]}_rowmax_{bbox[2]}_colmax_{bbox[3]}",
                "bbox": bbox,
                "city": tile_info["city"],
                "number": tile_info["number"],
            }
            if has_polygon_gt:
                patch_gt_polygons = polygon_utils.patch_polygons(raw_data["gt_polygons"], minx=bbox[1], miny=bbox[0],
                                                                 maxx=bbox[3], maxy=bbox[2])
                sample["gt_polygons"] = patch_gt_polygons
            elif self.gt_type == "tif":
                patch_gt_mask = raw_data["gt_polygons_image"][bbox[0]:bbox[2], bbox[1]:bbox[3], :]
                sample["gt_polygons_image"] = patch_gt_mask
            sample["image"] = raw_data["image"][bbox[0]:bbox[2], bbox[1]:bbox[3], :]

            sample = self.pre_transform(sample)  # Needs "image" to infer shape even if mask_only is True
            if self.mask_only:
                del sample["image"]  # Don't need RGB image anymore

            relative_filepath = os.path.join(processed_tile_relative_dirpath, "data.{:06d}.pt".format(i))
            filepath = os.path.join(self.processed_dirpath, relative_filepath)
            torch.save(sample, filepath)

            # Compute stats
            if not self.mask_only:
                if has_polygon_gt:
                    # "gt_polygons_image" is expected to have been added by pre_transform, with values in [0, 255]
                    # (assumption from the division by 255 -- TODO confirm)
                    class_freq_list.append(np.mean(sample["gt_polygons_image"], axis=(0, 1)) / 255)
                elif self.gt_type == "mask":
                    raise NotImplementedError("mask class freq")
                else:
                    raise NotImplementedError(f"gt_type={self.gt_type} not implemented for computing stats")

        # Aggregate stats
        if not self.mask_only:
            if len(class_freq_list):
                class_freq_array = np.stack(class_freq_list, axis=0)
                stats["class_freq"] = np.mean(class_freq_array, axis=0)
                stats["num"] = len(class_freq_list)
            else:
                print("Empty tile:", tile_info["city"], tile_info["number"], "polygons:", len(raw_data["gt_polygons"]))

        # Save stats
        if not self.mask_only:
            torch.save(stats, stats_filepath)

        # Mark tile as processed with flag
        pathlib.Path(processed_flag_filepath).touch()

        return stats

    def __len__(self):
        """Return the number of patches (pre_process=True) or of tiles (pre_process=False)."""
        if self.pre_process:
            return len(self.processed_relative_paths)
        return len(self.tile_info_list)

    def __getitem__(self, idx):
        """
        Load sample idx: a saved patch (pre_process=True) or a raw tile (pre_process=False).
        "image_mean" and "image_std" are added to the sample, then self.transform is applied if it is set.
        """
        if self.pre_process:
            filepath = os.path.join(self.processed_dirpath, self.processed_relative_paths[idx])
            data = torch.load(filepath)
            if self.mask_only:
                data["image"] = np.repeat(data["gt_polygons_image"][:, :, 0:1], 3, axis=-1)  # Fill image slot
                data["image_mean"] = np.array([0.5, 0.5, 0.5])
                data["image_std"] = np.array([1, 1, 1])
            else:
                data["image_mean"] = np.array(CITY_METADATA_DICT[data["city"]]["mean"])
                data["image_std"] = np.array(CITY_METADATA_DICT[data["city"]]["std"])
                data["class_freq"] = self.stats["class_freq"]
        else:
            tile_info = self.tile_info_list[idx]
            # Load raw data
            data = self.load_raw_data(tile_info)
            data["name"] = IMAGE_NAME_FORMAT.format(city=tile_info["city"], number=tile_info["number"])
            data["image_mean"] = np.array(tile_info["mean"])
            data["image_std"] = np.array(tile_info["std"])
        # transform defaults to None: only apply it when one was given
        if self.transform is not None:
            data = self.transform(data)
        return data
def _print_tensor_range(batch, key):
    """Print the shape, the min and the max of batch[key]."""
    tensor = batch[key]
    print(f"{key}:", tensor.shape, tensor.min().item(), tensor.max().item())


def _save_first_sample(batch, key, filename):
    """Save the first sample of batch[key] as an image, after moving its first axis last."""
    array = np.moveaxis(np.array(batch[key][0]), 0, -1)
    skimage.io.imsave(filename, array)


def _save_seg_display(batch, filename, divisor=None):
    """Save a display image of the first "gt_polygons_image" of the batch, optionally divided by divisor first."""
    seg = np.array(batch["gt_polygons_image"][0])
    if divisor is not None:
        seg = seg / divisor
    seg = np.moveaxis(seg, 0, -1)
    seg_display = utils.get_seg_display(seg)
    seg_display = (seg_display * 255).astype(np.uint8)
    skimage.io.imsave(filename, seg_display)


def main():
    """
    Manual test of the dataset: build it with the transforms of the frame_field_learning project,
    then iterate over batches, printing value ranges and saving images before and after the online
    CUDA transform. Waits for the user to press enter after each batch.
    """
    # Test using transforms from the frame_field_learning project:
    from frame_field_learning import data_transforms

    config = {
        "data_dir_candidates": [
            "/data/titane/user/nigirard/data",
            "~/data",
            "/data"
        ],
        "dataset_params": {
            "root_dirname": "AerialImageDataset",
            "pre_process": False,
            "gt_source": "disk",
            "gt_type": "tif",
            "gt_dirname": "gt",
            "mask_only": False,
            "small": True,
            "data_patch_size": 425,
            "input_patch_size": 300,
            "train_fraction": 0.75
        },
        "num_workers": 8,
        "data_aug_params": {
            "enable": True,
            "vflip": True,
            "affine": True,
            "scaling": [0.9, 1.1],
            "color_jitter": True,
            "device": "cuda"
        }
    }

    # Find data_dir
    data_dir = python_utils.choose_first_existing_path(config["data_dir_candidates"])
    if data_dir is None:
        print_utils.print_error("ERROR: Data directory not found!")
        # sys.exit instead of the interactive-only exit() helper, with a non-zero status as this is an error
        sys.exit(1)
    print_utils.print_info("Using data from {}".format(data_dir))
    root_dir = os.path.join(data_dir, config["dataset_params"]["root_dirname"])

    # --- Transforms: --- #
    # --- Online transform done on the host (CPU):
    online_cpu_transform = data_transforms.get_online_cpu_transform(config,
                                                                    augmentations=config["data_aug_params"]["enable"])
    train_online_cuda_transform = data_transforms.get_online_cuda_transform(config, augmentations=config["data_aug_params"]["enable"])
    mask_only = config["dataset_params"]["mask_only"]
    kwargs = {
        "pre_process": config["dataset_params"]["pre_process"],
        "transform": online_cpu_transform,
        "patch_size": config["dataset_params"]["data_patch_size"],
        "patch_stride": config["dataset_params"]["input_patch_size"],
        # --- pre-processing transform (done once then saved on disk):
        "pre_transform": data_transforms.get_offline_transform_patch(distances=not mask_only, sizes=not mask_only),
        "small": config["dataset_params"]["small"],
        "pool_size": config["num_workers"],
        "gt_source": config["dataset_params"]["gt_source"],
        "gt_type": config["dataset_params"]["gt_type"],
        "gt_dirname": config["dataset_params"]["gt_dirname"],
        "mask_only": config["dataset_params"]["mask_only"],
    }

    # The train fold is split into train and val by tile number (36 tiles per city)
    train_val_split_point = config["dataset_params"]["train_fraction"] * 36

    def train_tile_filter(tile): return tile["number"] <= train_val_split_point

    def val_tile_filter(tile): return train_val_split_point < tile["number"]
    # --- --- #

    fold = "train"
    if fold == "train":
        dataset = InriaAerial(root_dir, fold="train", tile_filter=train_tile_filter, **kwargs)
    elif fold == "val":
        dataset = InriaAerial(root_dir, fold="train", tile_filter=val_tile_filter, **kwargs)
    elif fold == "test":
        dataset = InriaAerial(root_dir, fold="test", **kwargs)
    else:
        raise ValueError(f"Unknown fold: {fold}")

    print(f"dataset has {len(dataset)} samples.")
    print("# --- Sample 0 --- #")
    sample = dataset[0]
    for key, item in sample.items():
        print("{}: {}".format(key, type(item)))

    print("# --- Samples --- #")
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=config["num_workers"])
    print("# --- Batches --- #")
    # Fields other than "image" and "gt_polygons_image" that are visualized when present
    optional_keys = ("gt_crossfield_angle", "distances", "sizes")
    for batch in tqdm(data_loader):
        print("----")
        print(batch["name"])

        # --- Before the online CUDA transform
        _print_tensor_range(batch, "image")
        _save_first_sample(batch, "image", 'im.png')

        if "gt_polygons_image" in batch:
            _print_tensor_range(batch, "gt_polygons_image")
            _save_seg_display(batch, "gt_seg.png", divisor=255)

        for key in optional_keys:
            if key in batch:
                _print_tensor_range(batch, key)
                _save_first_sample(batch, key, f'{key}.png')

        print("Apply online tranform:")
        batch = utils.batch_to_cuda(batch)
        batch = train_online_cuda_transform(batch)
        batch = utils.batch_to_cpu(batch)

        # --- After the online CUDA transform
        # Optional fields are guarded like above: which ones exist depends on the dataset configuration
        # (e.g. "distances" and "sizes" are requested from the offline pre-transform only).
        _print_tensor_range(batch, "image")
        if "gt_polygons_image" in batch:
            _print_tensor_range(batch, "gt_polygons_image")
        if "gt_crossfield_angle" in batch:
            _print_tensor_range(batch, "gt_crossfield_angle")

        # Save output to visualize
        if "gt_polygons_image" in batch:
            _save_seg_display(batch, "gt_seg.png")
        _save_first_sample(batch, "image", 'im.png')
        for key in optional_keys:
            if key in batch:
                _save_first_sample(batch, key, f'{key}.png')

        input("Press enter to continue...")


if __name__ == '__main__':
    main()