import skimage.io
from functools import partial
from multiprocess import Pool
import itertools

import numpy as np
import rasterio
import fiona
from pyproj import Proj, transform

import torch_lydorn.torch.utils.data

import os

from tqdm import tqdm

import torch

from frame_field_learning import data_transforms

from lydorn_utils import run_utils
from lydorn_utils import print_utils
from lydorn_utils import python_utils
from lydorn_utils import polygon_utils
from lydorn_utils import ogr2ogr
from lydorn_utils import image_utils


class LuxcartaBuildings(torch_lydorn.torch.utils.data.Dataset):
    def __init__(self, root, transform=None, pre_transform=None, fold="train", patch_size=None, patch_stride=None,
                 pool_size=1):
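        """Torch dataset of Luxcarta ortho-images with ground-truth building polygons.

        :param root: Root directory of the dataset, containing "raw" and "processed" sub-directories
        :param transform: Online transform applied every time a sample is loaded
        :param pre_transform: Offline transform applied once when samples are pre-processed
        :param fold: Either "train" or "test"
        :param patch_size: If set, images are cut into patches of this size during pre-processing
        :param patch_stride: Stride between consecutive patches (defaults to patch_size)
        :param pool_size: Number of worker processes used for pre-processing
        """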
        assert fold in {"train", "test"}, "fold should be either train or test"
        self.fold = fold
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.pool_size = pool_size
        self._processed_filepaths = None

        super(LuxcartaBuildings, self).__init__(root, transform, pre_transform)

    @property
    def raw_dir(self):
        return os.path.join(self.root, 'raw', self.fold)

    @property
    def processed_dir(self):
        return os.path.join(self.root, 'processed', self.fold)

    @property
    def raw_file_names(self):
        return []

    @property
    def raw_sample_metadata_list(self):
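        """Lists metadata for every raw sample directory that contains both an
        "Ortho*_crop.tif" image and a "Building*.shp" annotation shapefile, with an
        optional "zone*.kml" mask."""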
        returned_list = []
        for dirname in os.listdir(self.raw_dir):
            dirpath = os.path.join(self.raw_dir, dirname)
            if os.path.isdir(dirpath):
                image_filepath_list = python_utils.get_filepaths(dirpath, endswith_str="_crop.tif",
                                                                 startswith_str="Ortho",
                                                                 not_endswith_str=".tif_crop.tif")
                gt_polygons_filepath_list = python_utils.get_filepaths(dirpath, endswith_str=".shp",
                                                                       startswith_str="Building")
                mask_filepath_list = python_utils.get_filepaths(dirpath, endswith_str=".kml", startswith_str="zone")
                if len(image_filepath_list) and len(gt_polygons_filepath_list):
                    metadata = {
                        "dirname": dirname,
                        "image_filepath": image_filepath_list[0],
                        "gt_polygons_filepath": gt_polygons_filepath_list[0],
                    }
                    if len(mask_filepath_list):
                        metadata["mask_filepath"] = mask_filepath_list[0]
                    returned_list.append(metadata)
        return returned_list

    @property
    def processed_filepaths(self):
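        """Lazily builds and caches the list of processed "data.*.pt" sample files."""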
        if self._processed_filepaths is None:
            self._processed_filepaths = []
            for dirname in os.listdir(self.processed_dir):
                dirpath = os.path.join(self.processed_dir, dirname)
                if os.path.isdir(dirpath):
                    filepath_list = python_utils.get_filepaths(dirpath, endswith_str=".pt", startswith_str="data.")
                    self._processed_filepaths.extend(filepath_list)
        return self._processed_filepaths

    def __len__(self):
        return len(self.processed_filepaths)

    def _download(self):
        pass

    def download(self):
        pass

    def _get_mask_multi_polygon(self, mask_filepath, shp_srs):
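        """Converts the KML mask to a shapefile with ogr2ogr, then reads its polygons
        and reprojects them from the mask's CRS into shp_srs."""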
        mask_shp_filepath = os.path.join(os.path.dirname(mask_filepath), "mask.shp")
        # Convert the KML mask into a shapefile so it can be read with fiona:
        ogr2ogr.main(["", "-f", "ESRI Shapefile", mask_shp_filepath, mask_filepath])
        mask = fiona.open(mask_shp_filepath, "r")
        mask_srs = Proj(mask.crs["init"])

        mask_multi_polygon = []
        for feat in mask:
            mask_polygon = []
            # Use the exterior ring of each mask polygon:
            for point in feat['geometry']['coordinates'][0]:
                long, lat = point[:2]
                x, y = transform(mask_srs, shp_srs, long, lat, always_xy=True)
                mask_polygon.append((x, y))
            mask_multi_polygon.append(mask_polygon)

        mask.close()

        return mask_multi_polygon

    def _read_annotations(self, annotations_filepath, raster, mask_filepath=None):
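        """Reads ground-truth polygons from a shapefile and converts them to pixel
        coordinates of the given raster, optionally clipping them with a KML mask."""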
        shapefile = fiona.open(annotations_filepath, "r")
        shp_srs = Proj(shapefile.crs["init"])
        raster_srs = Proj(raster.crs)

        if mask_filepath is not None:
            mask_multi_polygon = self._get_mask_multi_polygon(mask_filepath, shp_srs)
        else:
            mask_multi_polygon = None

        process_feat_partial = partial(process_feat, shp_srs=shp_srs, raster_srs=raster_srs,
                                       raster_transform=raster.transform, mask_multi_polygon=mask_multi_polygon)
        # Process shapefile features in parallel with the configured pool size:
        with Pool(self.pool_size) as pool:
            out_polygons = list(
                tqdm(pool.imap(process_feat_partial, shapefile), desc="Process shp feature", total=len(shapefile),
                     leave=True))
        out_polygons = list(itertools.chain.from_iterable(out_polygons))

        shapefile.close()

        return out_polygons

    def process(self, metadata_list):
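        """Pre-processes each raw sample: reads the image and its annotations, computes
        per-channel mean and std, optionally cuts the image into patches, applies
        pre_transform and saves the resulting sample(s) to processed_dir."""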
        progress_bar = tqdm(metadata_list, desc="Pre-process")
        for metadata in progress_bar:
            progress_bar.set_postfix(image=metadata["dirname"], status="Loading image")
            raster = rasterio.open(metadata["image_filepath"])

            progress_bar.set_postfix(image=metadata["dirname"], status="Process shapefile")
            mask_filepath = metadata["mask_filepath"] if "mask_filepath" in metadata else None
            gt_polygons = self._read_annotations(metadata["gt_polygons_filepath"], raster, mask_filepath=mask_filepath)

            # Read the 3 bands and stack them into an HWC image:
            b1, b2, b3 = raster.read()
            image = np.stack([b1, b2, b3], axis=-1)
            progress_bar.set_postfix(image=metadata["dirname"], status="Compute mean and std")
            image_float = image / 255
            mean = np.mean(image_float.reshape(-1, image_float.shape[-1]), axis=0)
            std = np.std(image_float.reshape(-1, image_float.shape[-1]), axis=0)

            if self.patch_size is not None:
                progress_bar.set_postfix(image=metadata["dirname"], status="Patching")
                patch_stride = self.patch_stride if self.patch_stride is not None else self.patch_size
                patch_boundingboxes = image_utils.compute_patch_boundingboxes(image.shape[0:2],
                                                                              stride=patch_stride,
                                                                              patch_res=self.patch_size)
                for i, patch_boundingbox in enumerate(tqdm(patch_boundingboxes, desc="Process patches", leave=False)):
                    patch_gt_polygons = polygon_utils.crop_polygons_to_patch_if_touch(gt_polygons, patch_boundingbox)
                    if len(patch_gt_polygons) == 0:
                        # Skip patches that contain no annotations
                        continue
                    patch_image = image[patch_boundingbox[0]:patch_boundingbox[2],
                                        patch_boundingbox[1]:patch_boundingbox[3], :]
                    sample = {
                        "dirname": metadata["dirname"],
                        "image": patch_image,
                        "image_mean": mean,
                        "image_std": std,
                        "gt_polygons": patch_gt_polygons,
                        "image_filepath": metadata["image_filepath"],
                    }

                    if self.pre_transform:
                        sample = self.pre_transform(sample)

                    filepath = os.path.join(self.processed_dir, sample["dirname"], "data.{:06d}.pt".format(i))
                    os.makedirs(os.path.dirname(filepath), exist_ok=True)
                    torch.save(sample, filepath)
            else:
                sample = {
                    "dirname": metadata["dirname"],
                    "image": image,
                    "image_mean": mean,
                    "image_std": std,
                    "gt_polygons": gt_polygons,
                    "image_filepath": metadata["image_filepath"],
                }

                if self.pre_transform:
                    sample = self.pre_transform(sample)

                filepath = os.path.join(self.processed_dir, sample["dirname"], "data.{:06d}.pt".format(0))
                os.makedirs(os.path.dirname(filepath), exist_ok=True)
                torch.save(sample, filepath)

            # Mark this sample as done so it is not pre-processed again:
            flag_filepath = os.path.join(self.processed_dir, metadata["dirname"], "flag.json")
            flag = python_utils.load_json(flag_filepath)
            flag["done"] = True
            python_utils.save_json(flag_filepath, flag)

            raster.close()

    def _process(self):
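        """Collects the raw samples that still need pre-processing, using a per-sample
        "flag.json" marker file, and runs self.process() on them."""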
        to_process_metadata_list = []
        for metadata in self.raw_sample_metadata_list:
            flag_filepath = os.path.join(self.processed_dir, metadata["dirname"], "flag.json")
            if os.path.exists(flag_filepath):
                flag = python_utils.load_json(flag_filepath)
                if not flag["done"]:
                    to_process_metadata_list.append(metadata)
            else:
                flag = {
                    "done": False
                }
                # Create the sample's processed directory so the flag file can be saved:
                os.makedirs(os.path.dirname(flag_filepath), exist_ok=True)
                python_utils.save_json(flag_filepath, flag)
                to_process_metadata_list.append(metadata)

        if len(to_process_metadata_list) == 0:
            return

        print('Processing...')

        torch_lydorn.torch.utils.data.makedirs(self.processed_dir)
        self.process(to_process_metadata_list)

        path = os.path.join(self.processed_dir, 'pre_transform.pt')
        torch.save(torch_lydorn.torch.utils.data.__repr__(self.pre_transform), path)
        path = os.path.join(self.processed_dir, 'pre_filter.pt')
        torch.save(torch_lydorn.torch.utils.data.__repr__(self.pre_filter), path)

        print('Done!')

    def get(self, idx):
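        """Loads the processed sample at the given index, adding a placeholder
        patch bounding box and a "<tile>.<patch>" name."""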
        filepath = self.processed_filepaths[idx]
        data = torch.load(filepath)
        data["patch_bbox"] = torch.tensor([0, 0, 0, 0])
        tile_name = os.path.basename(os.path.dirname(filepath))
        patch_name = os.path.basename(filepath)
        patch_name = patch_name[len("data."):-len(".pt")]
        data["name"] = tile_name + "." + patch_name
        return data


def process_feat(feat, shp_srs, raster_srs, raster_transform, mask_multi_polygon=None):
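    """Converts a single shapefile feature (Polygon or MultiPolygon) into a list of
    polygons in raster pixel coordinates."""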
    out_polygons = []

    if feat['geometry']["type"] == "Polygon":
        poly = feat['geometry']['coordinates']
        polygons = process_polygon_feat(poly, shp_srs, raster_srs, raster_transform, mask_multi_polygon)
        out_polygons.extend(polygons)
    elif feat['geometry']["type"] == "MultiPolygon":
        for poly in feat['geometry']["coordinates"]:
            polygons = process_polygon_feat(poly, shp_srs, raster_srs, raster_transform, mask_multi_polygon)
            out_polygons.extend(polygons)

    return out_polygons


def process_polygon_feat(in_polygon, shp_srs, raster_srs, raster_transform, mask_multi_polygon=None):
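    """Converts one polygon's exterior ring from shapefile coordinates to raster pixel
    coordinates, first intersecting it with mask_multi_polygon if one is given."""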
    out_polygons = []
    # Only the exterior ring of the polygon is used (holes are ignored):
    points = in_polygon[0]

    if mask_multi_polygon is not None:
        multi_polygon_simple = polygon_utils.intersect_polygons(points, mask_multi_polygon)
        if multi_polygon_simple is None:
            return out_polygons
    else:
        multi_polygon_simple = [points]

    for polygon_simple in multi_polygon_simple:
        new_poly = []
        for point in polygon_simple:
            x, y = point[:2]
            x, y = transform(shp_srs, raster_srs, x, y)
            # The inverse raster transform maps world coordinates to (col, row) pixel coordinates:
            j, i = ~raster_transform * (x, y)
            new_poly.append((i, j))
        out_polygons.append(np.array(new_poly))

    return out_polygons


def get_seg_display(seg):
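    """Builds an RGBA display image from a segmentation map: the segmentation channels
    are copied into RGB and the alpha channel is derived from the segmentation itself."""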
    seg_display = np.zeros([seg.shape[0], seg.shape[1], 4], dtype=np.float64)
    if len(seg.shape) == 2:
        seg_display[..., 0] = seg
        seg_display[..., 3] = seg
    else:
        for i in range(seg.shape[-1]):
            seg_display[..., i] = seg[..., i]
        seg_display[..., 3] = np.clip(np.sum(seg, axis=-1), 0, 1)
    return seg_display


def main():
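    """Smoke test: loads the dataset config, builds a LuxcartaBuildings dataset and
    iterates over samples and batches, saving "gt_seg.png" and "im.png" for inspection."""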
    config_name = "config.luxcarta_dataset"
    config = run_utils.load_config(config_name)
    if config is None:
        print_utils.print_error("ERROR: cannot continue without a config file. Exiting now...")
        exit()

    data_dir = python_utils.choose_first_existing_path(config["data_dir_candidates"])
    if data_dir is None:
        print_utils.print_error("ERROR: Data directory not found!")
        exit()
    else:
        print_utils.print_info("Using data from {}".format(data_dir))
    root_dir = os.path.join(data_dir, config["dataset_params"]["root_dirname"])

    train_pre_transform = data_transforms.get_offline_transform(config,
                                                                augmentations=config["data_aug_params"]["enable"])
    eval_pre_transform = data_transforms.get_offline_transform(config, augmentations=False)

    train_online_cpu_transform = data_transforms.get_online_cpu_transform(config,
                                                                          augmentations=config["data_aug_params"]["enable"])
    eval_online_cpu_transform = data_transforms.get_online_cpu_transform(config, augmentations=False)

    train_online_cuda_transform = data_transforms.get_online_cuda_transform(config,
                                                                            augmentations=config["data_aug_params"]["enable"])
    eval_online_cuda_transform = data_transforms.get_online_cuda_transform(config, augmentations=False)

    data_patch_size = config["dataset_params"]["data_patch_size"] if config["data_aug_params"]["enable"] \
        else config["dataset_params"]["input_patch_size"]
    fold = "test"
    if fold == "train":
        dataset = LuxcartaBuildings(root_dir,
                                    transform=train_online_cpu_transform,
                                    patch_size=data_patch_size,
                                    patch_stride=config["dataset_params"]["input_patch_size"],
                                    pre_transform=data_transforms.get_offline_transform_patch(),
                                    fold="train",
                                    pool_size=config["num_workers"])
    elif fold == "test":
        dataset = LuxcartaBuildings(root_dir,
                                    transform=train_online_cpu_transform,
                                    pre_transform=data_transforms.get_offline_transform_patch(),
                                    fold="test",
                                    pool_size=config["num_workers"])

    print("# --- Sample 0 --- #")
    sample = dataset[0]
    print(sample["image"].shape)
    print(sample["gt_polygons_image"].shape)
    print("# --- Samples --- #")

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
    print("# --- Batches --- #")
    for batch in tqdm(data_loader):
        print(batch["image"].shape)
        print(batch["gt_polygons_image"].shape)

        # Save the ground-truth segmentation of the first sample in the batch as an RGBA image:
        seg = np.array(batch["gt_polygons_image"][0]) / 255
        seg = np.moveaxis(seg, 0, -1)
        seg_display = get_seg_display(seg)
        seg_display = (seg_display * 255).astype(np.uint8)
        skimage.io.imsave("gt_seg.png", seg_display)

        # Save the corresponding input image:
        im = np.array(batch["image"][0])
        im = np.moveaxis(im, 0, -1)
        skimage.io.imsave('im.png', im)

        input("Press Enter to continue...")


if __name__ == '__main__':
    main()