import itertools
import os
from functools import partial

import fiona
import numpy as np
import rasterio
import skimage.io
import torch
from multiprocess import Pool
from pyproj import Proj, transform
from tqdm import tqdm

import torch_lydorn.torch.utils.data
from frame_field_learning import data_transforms
from lydorn_utils import run_utils
from lydorn_utils import print_utils
from lydorn_utils import python_utils
from lydorn_utils import polygon_utils
from lydorn_utils import ogr2ogr
from lydorn_utils import image_utils

class LuxcartaBuildings(torch_lydorn.torch.utils.data.Dataset):
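    """Luxcarta buildings dataset: aerial tiles with building footprint annotations.

    Expects raw data under <root>/raw/<fold>/<tile_dirname>/ containing an
    "Ortho*_crop.tif" image, a "Building*.shp" ground-truth shapefile and an
    optional "zone*.kml" mask. Pre-processed samples are cached as
    "data.*.pt" files under <root>/processed/<fold>/.
    """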
def __init__(self, root, transform=None, pre_transform=None, fold="train", patch_size=None, patch_stride=None,
pool_size=1):
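        """
        :param root: Dataset root directory (contains raw/ and processed/).
        :param transform: Online transform applied each time a sample is fetched.
        :param pre_transform: Offline transform applied once during pre-processing.
        :param fold: Either "train" or "test".
        :param patch_size: If not None, tiles are split into square patches of this size.
        :param patch_stride: Stride between patches (defaults to patch_size).
        :param pool_size: Number of parallel workers (not used yet, see TODO below).
        """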
        assert fold in {"train", "test"}, "fold should be either train or test"
self.fold = fold
self.patch_size = patch_size
self.patch_stride = patch_stride
self.pool_size = pool_size
self._processed_filepaths = None
# TODO: implement pool_size option
super(LuxcartaBuildings, self).__init__(root, transform, pre_transform)
@property
def raw_dir(self):
return os.path.join(self.root, 'raw', self.fold)
@property
def processed_dir(self):
return os.path.join(self.root, 'processed', self.fold)
@property
def raw_file_names(self):
return []
@property
def raw_sample_metadata_list(self):
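        """Scan raw_dir and return one metadata dict (dirname, image and
        ground-truth filepaths, optional mask filepath) per tile directory
        that contains both an image and an annotation shapefile."""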
returned_list = []
for dirname in os.listdir(self.raw_dir):
dirpath = os.path.join(self.raw_dir, dirname)
if os.path.isdir(dirpath):
image_filepath_list = python_utils.get_filepaths(dirpath, endswith_str="_crop.tif",
startswith_str="Ortho",
not_endswith_str=".tif_crop.tif")
gt_polygons_filepath_list = python_utils.get_filepaths(dirpath, endswith_str=".shp",
startswith_str="Building")
mask_filepath_list = python_utils.get_filepaths(dirpath, endswith_str=".kml", startswith_str="zone")
if len(image_filepath_list) and len(gt_polygons_filepath_list):
metadata = {
"dirname": dirname,
"image_filepath": image_filepath_list[0],
"gt_polygons_filepath": gt_polygons_filepath_list[0],
}
if len(mask_filepath_list):
metadata["mask_filepath"] = mask_filepath_list[0]
returned_list.append(metadata)
return returned_list
@property
def processed_filepaths(self):
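        """Lazily collect the filepaths of all pre-processed "data.*.pt" samples."""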
if self._processed_filepaths is None:
self._processed_filepaths = []
for dirname in os.listdir(self.processed_dir):
dirpath = os.path.join(self.processed_dir, dirname)
if os.path.isdir(dirpath):
filepath_list = python_utils.get_filepaths(dirpath, endswith_str=".pt", startswith_str="data.")
self._processed_filepaths.extend(filepath_list)
return self._processed_filepaths
def __len__(self):
return len(self.processed_filepaths)
def _download(self):
pass
def download(self):
pass
def _get_mask_multi_polygon(self, mask_filepath, shp_srs):
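        """Convert the .kml mask to a shapefile with ogr2ogr, read it and
        return its polygons re-projected into the annotation shapefile's CRS."""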
# Read mask and convert to shapefile's crs:
mask_shp_filepath = os.path.join(os.path.dirname(mask_filepath), "mask.shp")
# TODO: see what's up with the warnings:
ogr2ogr.main(["", "-f", "ESRI Shapefile", mask_shp_filepath,
mask_filepath]) # Convert .kml into .shp
mask = fiona.open(mask_shp_filepath, "r")
mask_srs = Proj(mask.crs["init"])
mask_multi_polygon = []
for feat in mask:
mask_polygon = []
for point in feat['geometry']['coordinates'][0]:
long, lat = point[:2] # one 2D point of the LinearRing
x, y = transform(mask_srs, shp_srs, long, lat, always_xy=True) # transform the point
mask_polygon.append((x, y))
mask_multi_polygon.append(mask_polygon)
# mask_polygon is now in UTM proj, ready to compare to shapefile proj
mask.close()
return mask_multi_polygon
def _read_annotations(self, annotations_filepath, raster, mask_filepath=None):
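        """Read building polygons from the shapefile, optionally crop them with
        the mask, and convert them to pixel coordinates of the raster."""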
shapefile = fiona.open(annotations_filepath, "r")
shp_srs = Proj(shapefile.crs["init"])
raster_srs = Proj(raster.crs)
# --- Read and crop shapefile with mask if specified --- #
if mask_filepath is not None:
mask_multi_polygon = self._get_mask_multi_polygon(mask_filepath, shp_srs)
else:
mask_multi_polygon = None
process_feat_partial = partial(process_feat, shp_srs=shp_srs, raster_srs=raster_srs,
raster_transform=raster.transform, mask_multi_polygon=mask_multi_polygon)
with Pool() as pool:
out_polygons = list(
tqdm(pool.imap(process_feat_partial, shapefile), desc="Process shp feature", total=len(shapefile),
leave=True))
out_polygons = list(itertools.chain.from_iterable(out_polygons))
shapefile.close()
return out_polygons
def process(self, metadata_list):
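        """Pre-process each tile: read the raster, convert annotations to pixel
        space, compute the tile's image mean and std, optionally split the tile
        into patches, and save each sample as a .pt file."""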
progress_bar = tqdm(metadata_list, desc="Pre-process")
for metadata in progress_bar:
progress_bar.set_postfix(image=metadata["dirname"], status="Loading image")
# Load image
            raster = rasterio.open(metadata["image_filepath"])
progress_bar.set_postfix(image=metadata["dirname"], status="Process shapefile")
mask_filepath = metadata["mask_filepath"] if "mask_filepath" in metadata else None
gt_polygons = self._read_annotations(metadata["gt_polygons_filepath"], raster, mask_filepath=mask_filepath)
# Compute image mean and std
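            # raster.read() returns an array of shape (bands, height, width); unpacking assumes a 3-band RGB image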
b1, b2, b3 = raster.read()
image = np.stack([b1, b2, b3], axis=-1)
progress_bar.set_postfix(image=metadata["dirname"], status="Compute mean and std")
image_float = image / 255
mean = np.mean(image_float.reshape(-1, image_float.shape[-1]), axis=0)
std = np.std(image_float.reshape(-1, image_float.shape[-1]), axis=0)
if self.patch_size is not None:
# Patch the tile
progress_bar.set_postfix(image=metadata["dirname"], status="Patching")
patch_stride = self.patch_stride if self.patch_stride is not None else self.patch_size
patch_boundingboxes = image_utils.compute_patch_boundingboxes(image.shape[0:2],
stride=patch_stride,
patch_res=self.patch_size)
for i, patch_boundingbox in enumerate(tqdm(patch_boundingboxes, desc="Process patches", leave=False)):
patch_gt_polygons = polygon_utils.crop_polygons_to_patch_if_touch(gt_polygons, patch_boundingbox)
                    if len(patch_gt_polygons) == 0:
                        # Skip patches without any polygon  # TODO: keep empty patches?
                        continue
patch_image = image[patch_boundingbox[0]:patch_boundingbox[2],
patch_boundingbox[1]:patch_boundingbox[3], :]
sample = {
"dirname": metadata["dirname"],
"image": patch_image,
"image_mean": mean,
"image_std": std,
"gt_polygons": patch_gt_polygons,
"image_filepath": metadata["image_filepath"],
}
if self.pre_transform:
sample = self.pre_transform(sample)
filepath = os.path.join(self.processed_dir, sample["dirname"], "data.{:06d}.pt".format(i))
os.makedirs(os.path.dirname(filepath), exist_ok=True)
torch.save(sample, filepath)
else:
# Tile is saved as is
sample = {
"dirname": metadata["dirname"],
"image": image,
"image_mean": mean,
"image_std": std,
"gt_polygons": gt_polygons,
"image_filepath": metadata["image_filepath"],
}
if self.pre_transform:
sample = self.pre_transform(sample)
filepath = os.path.join(self.processed_dir, sample["dirname"], "data.{:06d}.pt".format(0))
os.makedirs(os.path.dirname(filepath), exist_ok=True)
torch.save(sample, filepath)
flag_filepath = os.path.join(self.processed_dir, metadata["dirname"], "flag.json")
flag = python_utils.load_json(flag_filepath)
flag["done"] = True
python_utils.save_json(flag_filepath, flag)
raster.close()
def _process(self):
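        """Check per-tile "flag.json" files to find tiles that still need
        pre-processing, and process only those."""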
to_process_metadata_list = []
for metadata in self.raw_sample_metadata_list:
flag_filepath = os.path.join(self.processed_dir, metadata["dirname"], "flag.json")
if os.path.exists(flag_filepath):
flag = python_utils.load_json(flag_filepath)
if not flag["done"]:
to_process_metadata_list.append(metadata)
else:
flag = {
"done": False
}
python_utils.save_json(flag_filepath, flag)
to_process_metadata_list.append(metadata)
if len(to_process_metadata_list) == 0:
return
print('Processing...')
torch_lydorn.torch.utils.data.makedirs(self.processed_dir)
self.process(to_process_metadata_list)
path = os.path.join(self.processed_dir, 'pre_transform.pt')
torch.save(torch_lydorn.torch.utils.data.__repr__(self.pre_transform), path)
path = os.path.join(self.processed_dir, 'pre_filter.pt')
torch.save(torch_lydorn.torch.utils.data.__repr__(self.pre_filter), path)
print('Done!')
def get(self, idx):
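        """Load the idx-th pre-processed sample and name it "<tile dirname>.<patch index>"."""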
filepath = self.processed_filepaths[idx]
data = torch.load(filepath)
data["patch_bbox"] = torch.tensor([0, 0, 0, 0]) # TODO: implement in pre-processing
tile_name = os.path.basename(os.path.dirname(filepath))
patch_name = os.path.basename(filepath)
patch_name = patch_name[len("data."):-len(".pt")]
data["name"] = tile_name + "." + patch_name
return data


def process_feat(feat, shp_srs, raster_srs, raster_transform, mask_multi_polygon=None):
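    """Convert one shapefile feature (Polygon or MultiPolygon) into a list of
    polygons in pixel coordinates of the raster."""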
out_polygons = []
if feat['geometry']["type"] == "Polygon":
poly = feat['geometry']['coordinates']
polygons = process_polygon_feat(poly, shp_srs, raster_srs, raster_transform, mask_multi_polygon)
out_polygons.extend(polygons)
elif feat['geometry']["type"] == "MultiPolygon":
for poly in feat['geometry']["coordinates"]:
polygons = process_polygon_feat(poly, shp_srs, raster_srs, raster_transform, mask_multi_polygon)
out_polygons.extend(polygons)
return out_polygons


def process_polygon_feat(in_polygon, shp_srs, raster_srs, raster_transform, mask_multi_polygon=None):
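    """Re-project a polygon's exterior ring from the shapefile's CRS to the
    raster's CRS, optionally intersect it with the mask, and convert the result
    to (row, col) pixel coordinates."""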
out_polygons = []
points = in_polygon[0] # TODO: handle holes
# Intersect with mask_polygon if specified
if mask_multi_polygon is not None:
multi_polygon_simple = polygon_utils.intersect_polygons(points, mask_multi_polygon)
if multi_polygon_simple is None:
return out_polygons
else:
multi_polygon_simple = [points]
for polygon_simple in multi_polygon_simple:
new_poly = []
        for point in polygon_simple:
            x, y = point[:2]  # one 2D point in the shapefile's CRS
            x, y = transform(shp_srs, raster_srs, x, y)  # re-project the point into the raster's CRS
            j, i = ~raster_transform * (x, y)  # inverse affine transform: world coordinates -> pixel (col, row)
            new_poly.append((i, j))  # store as (row, col)
out_polygons.append(np.array(new_poly))
return out_polygons


def get_seg_display(seg):
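    """Build an RGBA display image from a segmentation map: a 2D map is drawn
    in red with alpha equal to the map itself; a multi-channel map puts each
    channel in R/G/B with alpha the clipped sum of all channels."""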
    seg_display = np.zeros([seg.shape[0], seg.shape[1], 4], dtype=np.float64)  # np.float was removed from NumPy, use np.float64
    if seg.ndim == 2:
seg_display[..., 0] = seg
seg_display[..., 3] = seg
else:
for i in range(seg.shape[-1]):
seg_display[..., i] = seg[..., i]
seg_display[..., 3] = np.clip(np.sum(seg, axis=-1), 0, 1)
return seg_display


def main():
# --- Params --- #
config_name = "config.luxcarta_dataset"
# --- --- #
# Load config
config = run_utils.load_config(config_name)
if config is None:
print_utils.print_error(
"ERROR: cannot continue without a config file. Exiting now...")
exit()
# Find data_dir
data_dir = python_utils.choose_first_existing_path(config["data_dir_candidates"])
if data_dir is None:
print_utils.print_error("ERROR: Data directory not found!")
exit()
else:
print_utils.print_info("Using data from {}".format(data_dir))
root_dir = os.path.join(data_dir, config["dataset_params"]["root_dirname"])
# --- Transforms: --- #
# --- pre-processing transform (done once then saved on disk):
train_pre_transform = data_transforms.get_offline_transform(config,
augmentations=config["data_aug_params"]["enable"])
eval_pre_transform = data_transforms.get_offline_transform(config, augmentations=False)
# --- Online transform done on the host (CPU):
train_online_cpu_transform = data_transforms.get_online_cpu_transform(config,
augmentations=config["data_aug_params"][
"enable"])
eval_online_cpu_transform = data_transforms.get_online_cpu_transform(config, augmentations=False)
# --- Online transform performed on the device (GPU):
train_online_cuda_transform = data_transforms.get_online_cuda_transform(config,
augmentations=config["data_aug_params"][
"enable"])
eval_online_cuda_transform = data_transforms.get_online_cuda_transform(config, augmentations=False)
# --- --- #
data_patch_size = config["dataset_params"]["data_patch_size"] if config["data_aug_params"]["enable"] else config["dataset_params"]["input_patch_size"]
fold = "test"
if fold == "train":
dataset = LuxcartaBuildings(root_dir,
transform=train_online_cpu_transform,
patch_size=data_patch_size,
patch_stride=config["dataset_params"]["input_patch_size"],
pre_transform=data_transforms.get_offline_transform_patch(),
fold="train",
pool_size=config["num_workers"])
elif fold == "test":
dataset = LuxcartaBuildings(root_dir,
                                    transform=eval_online_cpu_transform,  # test fold: no augmentations
pre_transform=data_transforms.get_offline_transform_patch(),
fold="test",
pool_size=config["num_workers"])
print("# --- Sample 0 --- #")
sample = dataset[0]
print(sample["image"].shape)
print(sample["gt_polygons_image"].shape)
print("# --- Samples --- #")
# for data in tqdm(dataset):
# pass
data_loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
print("# --- Batches --- #")
for batch in tqdm(data_loader):
print(batch["image"].shape)
print(batch["gt_polygons_image"].shape)
# Save output to visualize
seg = np.array(batch["gt_polygons_image"][0]) / 255 # First batch
seg = np.moveaxis(seg, 0, -1)
seg_display = get_seg_display(seg)
seg_display = (seg_display * 255).astype(np.uint8)
skimage.io.imsave("gt_seg.png", seg_display)
im = np.array(batch["image"][0])
im = np.moveaxis(im, 0, -1)
skimage.io.imsave('im.png', im)
input("Enter to continue...")


if __name__ == '__main__':
main()