|
import fnmatch |
|
import os.path |
|
import pathlib |
|
import sys |
|
import time |
|
|
|
import shapely.geometry |
|
import multiprocess |
|
import itertools |
|
import skimage.io |
|
import numpy as np |
|
|
|
from tqdm import tqdm |
|
|
|
import torch |
|
import torch.utils.data |
|
import torchvision |
|
|
|
from lydorn_utils import run_utils, image_utils, polygon_utils, geo_utils |
|
from lydorn_utils import print_utils |
|
from lydorn_utils import python_utils |
|
|
|
from torch_lydorn.torchvision.datasets import utils |
|
|
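# Per-city metadata: train/test fold, ground resolution in meters per pixel, tile numbers,
# and per-channel RGB mean/std used to normalize images in __getitem__.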
|
CITY_METADATA_DICT = { |
|
"bloomington": { |
|
"fold": "test", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.44583929, 0.46205078, 0.35783887], |
|
"std": [0.18212699, 0.17152641, 0.16157062], |
|
}, |
|
"bellingham": { |
|
"fold": "test", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.3766195, 0.391402, 0.32659722], |
|
"std": [0.18134978, 0.16412577, 0.16369793], |
|
}, |
|
"innsbruck": { |
|
"fold": "test", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.41375683, 0.41818116, 0.38940192], |
|
"std": [0.16616156, 0.14364722, 0.13317743], |
|
}, |
|
"sfo": { |
|
"fold": "test", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.59388761, 0.61522012, 0.54348289], |
|
"std": [0.25730708, 0.23301019, 0.23707742], |
|
}, |
|
"tyrol-e": { |
|
"fold": "test", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.44171042, 0.48147037, 0.44642358], |
|
"std": [0.1808623, 0.15437789, 0.15102051], |
|
}, |
|
"austin": { |
|
"fold": "train", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.39584444, 0.40599795, 0.38298687], |
|
"std": [0.17341954, 0.16856597, 0.16360443], |
|
}, |
|
"chicago": { |
|
"fold": "train", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.4055142, 0.42844002, 0.38229637], |
|
"std": [0.2133328, 0.20827106, 0.20132315], |
|
}, |
|
"kitsap": { |
|
"fold": "train", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.34717916, 0.37854108, 0.32571001], |
|
"std": [0.17048794, 0.14537676, 0.13466496], |
|
}, |
|
"tyrol-w": { |
|
"fold": "train", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.39704218, 0.4545488, 0.4321427], |
|
"std": [0.19484766, 0.1742585, 0.15186383], |
|
}, |
|
"vienna": { |
|
"fold": "train", |
|
"pixelsize": 0.3, |
|
"numbers": list(range(1, 37)), |
|
"mean": [0.47861977, 0.46878486, 0.44043111], |
|
"std": [0.22614806, 0.19949128, 0.19524506], |
|
}, |
|
} |
|
|
|
IMAGE_DIRNAME = "images" |
|
IMAGE_NAME_FORMAT = "{city}{number}" |
|
IMAGE_FILENAME_FORMAT = IMAGE_NAME_FORMAT + ".tif" |
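# e.g. IMAGE_FILENAME_FORMAT.format(city="austin", number=1) -> "austin1.tif"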
|
|
|
|
|
class InriaAerial(torch.utils.data.Dataset): |
|
""" |
|
Inria Aerial Image Dataset |
|
""" |
|
|
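    # Minimal usage sketch (the path and transform names below are illustrative, not part of this module):
    #   dataset = InriaAerial("/path/to/AerialImageDataset", fold="train",
    #                         pre_process=False, transform=my_cpu_transform)
    #   sample = dataset[0]  # dict with "image", "gt_polygons" or "gt_polygons_image", "image_mean", "image_std", ...
    # See main() below for a complete example including pre-processing and a DataLoader.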
|
def __init__(self, root: str, fold: str="train", pre_process: bool=True, tile_filter=None, patch_size: int=None, patch_stride: int=None, |
|
pre_transform=None, transform=None, small: bool=False, pool_size: int=1, raw_dirname: str="raw", processed_dirname: str="processed", |
|
gt_source: str="disk", gt_type: str="npy", gt_dirname: str="gt_polygons", mask_only: bool=False): |
|
""" |
|
|
|
        @param root: Root directory of the dataset (contains the raw data and the processed patches).
        @param fold: Dataset fold to load: "train" or "test".
        @param pre_process: If True, the dataset will be pre-processed first, saving training patches on disk. If False, data will be served on-the-fly without any patching.
        @param tile_filter: Function called on each tile_info dict; a tile is kept only if it returns True. Does not affect pre-processing.
        @param patch_size: Size in pixels of the square patches cut out of each tile during pre-processing.
        @param patch_stride: Stride in pixels between consecutive patches (defaults to patch_size).
        @param pre_transform: Transform applied once to each patch during pre-processing.
        @param transform: Transform applied to each sample when it is loaded.
        @param small: If True, use a small subset of the dataset (for testing).
        @param pool_size: Number of worker processes used for pre-processing.
        @param raw_dirname: Name of the directory holding the raw data.
        @param processed_dirname: Name of the directory where pre-processed patches are saved.
        @param gt_source: Can be "disk" for annotations that are on disk or "osm" to download from OSM (not implemented).
        @param gt_type: Type of annotation files on disk: can be "npy", "geojson" or "tif".
        @param gt_dirname: Name of the directory with annotation files.
        @param mask_only: If True, the RGB image is discarded: the sample's "image" field becomes a single-channel binary mask of the polygons and there is no ground-truth segmentation.
            This allows learning only the frame field from binary masks, in order to polygonize binary masks.
|
""" |
|
assert gt_source in {"disk", "osm"}, "gt_source should be disk or osm" |
|
assert gt_type in {"npy", "geojson", "tif"}, f"gt_type should be npy, geojson or tif, not {gt_type}" |
|
self.root = root |
|
self.fold = fold |
|
self.pre_process = pre_process |
|
self.tile_filter = tile_filter |
|
self.patch_size = patch_size |
|
self.patch_stride = patch_stride |
|
self.pre_transform = pre_transform |
|
self.transform = transform |
|
self.small = small |
|
if self.small: |
|
print_utils.print_info("INFO: Using small version of the Inria dataset.") |
|
self.pool_size = pool_size |
|
self.raw_dirname = raw_dirname |
|
self.gt_source = gt_source |
|
self.gt_type = gt_type |
|
self.gt_dirname = gt_dirname |
|
self.mask_only = mask_only |
|
|
|
|
|
if self.gt_source == "disk": |
|
print_utils.print_info("INFO: annotations will be loaded from disk") |
|
elif self.gt_source == "osm": |
|
print_utils.print_info("INFO: annotations will be downloaded from OSM. " |
|
"Make sure you have an internet connection to the OSM servers!") |
|
|
|
if self.pre_process: |
|
|
|
            processed_dirname_extension = f"{processed_dirname}.source_{self.gt_source}.type_{self.gt_type}"
            if self.gt_dirname is not None:
                processed_dirname_extension += f".dirname_{self.gt_dirname}"
            if self.mask_only:
                processed_dirname_extension += f".mask_only_{int(self.mask_only)}"
            processed_dirname_extension += f".patch_size_{int(self.patch_size)}"
            self.processed_dirpath = os.path.join(self.root, processed_dirname_extension, self.fold)
|
self.stats_filepath = os.path.join(self.processed_dirpath, "stats-small.pt" if self.small else "stats.pt") |
|
self.processed_flag_filepath = os.path.join(self.processed_dirpath, |
|
"processed_flag-small" if self.small else "processed_flag") |
|
|
|
|
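            # A flag file marks this fold as fully pre-processed: load the cached stats if it exists,
            # otherwise process every tile (ignoring tile_filter) and cache the resulting stats.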
|
if os.path.exists(self.processed_flag_filepath): |
|
|
|
self.stats = torch.load(self.stats_filepath) |
|
else: |
|
|
|
tile_info_list = self.get_tile_info_list(tile_filter=None) |
|
self.stats = self.process(tile_info_list) |
|
|
|
torch.save(self.stats, self.stats_filepath) |
|
|
|
pathlib.Path(self.processed_flag_filepath).touch() |
|
|
|
|
|
tile_info_list = self.get_tile_info_list(tile_filter=self.tile_filter) |
|
self.processed_relative_paths = self.get_processed_relative_paths(tile_info_list) |
|
else: |
|
|
|
self.tile_info_list = self.get_tile_info_list(tile_filter=self.tile_filter) |
|
|
|
def get_tile_info_list(self, tile_filter=None): |
|
tile_info_list = [] |
|
for city, info in CITY_METADATA_DICT.items(): |
|
            if info["fold"] != self.fold:
|
continue |
|
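            # "small" mode keeps only the first 5 tile numbers plus the last one of each city.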
if self.small: |
|
numbers = [*info["numbers"][:5], info["numbers"][-1]] |
|
else: |
|
numbers = info["numbers"] |
|
for number in numbers: |
|
image_info = { |
|
"city": city, |
|
"number": number, |
|
"pixelsize": info["pixelsize"], |
|
"mean": np.array(info["mean"]), |
|
"std": np.array(info["std"]), |
|
} |
|
tile_info_list.append(image_info) |
|
        if tile_filter is not None:
            tile_info_list = list(filter(tile_filter, tile_info_list))
|
return tile_info_list |
|
|
|
def get_processed_relative_paths(self, tile_info_list): |
|
processed_relative_paths = [] |
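        # Each processed tile directory holds one file per patch, named "data.<index>.pt" (see _process_one).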
|
for tile_info in tile_info_list: |
|
processed_tile_relative_dirpath = os.path.join(tile_info['city'], f"{tile_info['number']:02d}") |
|
processed_tile_dirpath = os.path.join(self.processed_dirpath, processed_tile_relative_dirpath) |
|
sample_filenames = fnmatch.filter(os.listdir(processed_tile_dirpath), "data.*.pt") |
|
processed_tile_relative_paths = [os.path.join(processed_tile_relative_dirpath, sample_filename) for sample_filename |
|
in sample_filenames] |
|
processed_relative_paths.extend(processed_tile_relative_paths) |
|
return sorted(processed_relative_paths) |
|
|
|
def process(self, tile_info_list): |
|
|
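        # Pre-process all tiles in parallel; each call to _process_one returns a per-tile stats dict
        # (empty when mask_only is set or when the tile has no polygons).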
|
with multiprocess.Pool(self.pool_size) as p: |
|
stats_all = list( |
|
tqdm(p.imap(self._process_one, tile_info_list), total=len(tile_info_list), desc="Process")) |
|
|
|
stats = {} |
|
if not self.mask_only: |
|
            stats_all = [stats_one for stats_one in stats_all if stats_one is not None]
|
stat_lists = {} |
|
for stats_one in stats_all: |
|
for key, stat in stats_one.items(): |
|
if key in stat_lists: |
|
stat_lists[key].append(stat) |
|
else: |
|
stat_lists[key] = [stat] |
|
|
|
|
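            # Combine per-tile class frequencies into a dataset-wide frequency, weighting each tile by its patch count:
            #   class_freq = sum_i(freq_i * num_i) / sum_i(num_i)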
|
if "class_freq" in stat_lists and "num" in stat_lists: |
|
class_freq_array = np.stack(stat_lists["class_freq"], axis=0) |
|
num_array = np.stack(stat_lists["num"], axis=0) |
|
if num_array.min() == 0: |
|
                    raise ZeroDivisionError("num_array has some zero values, cannot divide!")
|
stats["class_freq"] = np.sum(class_freq_array*num_array[:, None], axis=0) / np.sum(num_array) |
|
|
|
return stats |
|
|
|
def load_raw_data(self, tile_info): |
|
raw_data = {} |
|
|
|
|
|
raw_data["image_filepath"] = os.path.join(self.root, self.raw_dirname, self.fold, IMAGE_DIRNAME, |
|
IMAGE_FILENAME_FORMAT.format(city=tile_info["city"], number=tile_info["number"])) |
|
raw_data["image"] = skimage.io.imread(raw_data["image_filepath"]) |
|
assert len(raw_data["image"].shape) == 3 and raw_data["image"].shape[2] == 3, f"image should have shape (H, W, 3), not {raw_data['image'].shape}..." |
|
|
|
|
|
if self.gt_source == "disk": |
|
gt_base_filepath = os.path.join(self.root, self.raw_dirname, self.fold, self.gt_dirname, |
|
IMAGE_NAME_FORMAT.format(city=tile_info["city"], |
|
number=tile_info["number"])) |
|
gt_filepath = gt_base_filepath + "." + self.gt_type |
|
if not os.path.exists(gt_filepath): |
|
raw_data["gt_polygons"] = [] |
|
return raw_data |
|
if self.gt_type == "npy": |
|
np_gt_polygons = np.load(gt_filepath, allow_pickle=True) |
|
gt_polygons = [] |
|
for np_gt_polygon in np_gt_polygons: |
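                    # Coordinate order is flipped here: the saved arrays are assumed to store (row, col), shapely expects (x, y).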
|
try: |
|
gt_polygons.append(shapely.geometry.Polygon(np_gt_polygon[:, ::-1])) |
|
except ValueError: |
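                        # Skip polygons that shapely cannot construct (e.g. fewer than 3 valid coordinates).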
|
|
|
continue |
|
raw_data["gt_polygons"] = gt_polygons |
|
elif self.gt_type == "geojson": |
|
geojson = python_utils.load_json(gt_filepath) |
|
raw_data["gt_polygons"] = list(shapely.geometry.shape(geojson)) |
|
elif self.gt_type == "tif": |
|
raw_data["gt_polygons_image"] = skimage.io.imread(gt_filepath)[:, :, None] |
|
assert len(raw_data["gt_polygons_image"].shape) == 3 and raw_data["gt_polygons_image"].shape[2] == 1, \ |
|
f"Mask should have shape (H, W, 1), not {raw_data['gt_polygons_image'].shape}..." |
|
elif self.gt_source == "osm": |
|
            raise NotImplementedError(
                "Downloading from OSM is not implemented (it takes too long; download the annotations to disk first).")
|
|
|
|
|
return raw_data |
|
|
|
def _process_one(self, tile_info): |
|
        # Worker index, used only to position this tile's tqdm progress bar; pool workers are named "...-<n>".
        process_id = int(multiprocess.current_process().name.split("-")[-1])
|
|
|
|
|
|
|
tile_name = IMAGE_NAME_FORMAT.format(city=tile_info["city"], number=tile_info["number"]) |
|
processed_tile_relative_dirpath = os.path.join(tile_info['city'], f"{tile_info['number']:02d}") |
|
processed_tile_dirpath = os.path.join(self.processed_dirpath, processed_tile_relative_dirpath) |
|
processed_flag_filepath = os.path.join(processed_tile_dirpath, "processed_flag") |
|
stats_filepath = os.path.join(processed_tile_dirpath, "stats.pt") |
|
os.makedirs(processed_tile_dirpath, exist_ok=True) |
|
stats = {} |
|
|
|
|
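        # Resume support: skip tiles that were fully processed in a previous run.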
|
if os.path.exists(processed_flag_filepath): |
|
if not self.mask_only: |
|
stats = torch.load(stats_filepath) |
|
return stats |
|
|
|
|
|
raw_data = self.load_raw_data(tile_info) |
|
|
|
|
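        # Cut the tile into square patches (possibly overlapping if patch_stride < patch_size)
        # and save each patch as an individual sample on disk.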
|
if self.patch_size is not None: |
|
patch_stride = self.patch_stride if self.patch_stride is not None else self.patch_size |
|
patch_boundingboxes = image_utils.compute_patch_boundingboxes(raw_data["image"].shape[0:2], |
|
stride=patch_stride, |
|
patch_res=self.patch_size) |
|
class_freq_list = [] |
|
for i, bbox in enumerate(tqdm(patch_boundingboxes, desc=f"Patching {tile_name}", leave=False, position=process_id)): |
|
sample = { |
|
"image_filepath": raw_data["image_filepath"], |
|
"name": f"{tile_name}.rowmin_{bbox[0]}_colmin_{bbox[1]}_rowmax_{bbox[2]}_colmax_{bbox[3]}", |
|
"bbox": bbox, |
|
"city": tile_info["city"], |
|
"number": tile_info["number"], |
|
} |
|
|
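                # Crop the ground truth to the patch: clip polygons to the bounding box, or slice the raster mask.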
|
if self.gt_type == "npy" or self.gt_type == "geojson": |
|
patch_gt_polygons = polygon_utils.patch_polygons(raw_data["gt_polygons"], minx=bbox[1], miny=bbox[0], |
|
maxx=bbox[3], maxy=bbox[2]) |
|
sample["gt_polygons"] = patch_gt_polygons |
|
elif self.gt_type == "tif": |
|
patch_gt_mask = raw_data["gt_polygons_image"][bbox[0]:bbox[2], bbox[1]:bbox[3], :] |
|
sample["gt_polygons_image"] = patch_gt_mask |
|
|
|
sample["image"] = raw_data["image"][bbox[0]:bbox[2], bbox[1]:bbox[3], :] |
|
|
|
sample = self.pre_transform(sample) |
|
if self.mask_only: |
|
del sample["image"] |
|
|
|
relative_filepath = os.path.join(processed_tile_relative_dirpath, "data.{:06d}.pt".format(i)) |
|
filepath = os.path.join(self.processed_dirpath, relative_filepath) |
|
torch.save(sample, filepath) |
|
|
|
|
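                # Track this patch's per-channel class frequency (mean of the rasterized ground truth, scaled to [0, 1]).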
|
if not self.mask_only: |
|
if self.gt_type == "npy" or self.gt_type == "geojson": |
|
class_freq_list.append(np.mean(sample["gt_polygons_image"], axis=(0, 1)) / 255) |
|
                    elif self.gt_type == "tif":
                        raise NotImplementedError("class freq is not computed for gt_type='tif'")
|
else: |
|
raise NotImplementedError(f"gt_type={self.gt_type} not implemented for computing stats") |
|
|
|
|
|
if not self.mask_only: |
|
if len(class_freq_list): |
|
class_freq_array = np.stack(class_freq_list, axis=0) |
|
stats["class_freq"] = np.mean(class_freq_array, axis=0) |
|
stats["num"] = len(class_freq_list) |
|
else: |
|
print("Empty tile:", tile_info["city"], tile_info["number"], "polygons:", len(raw_data["gt_polygons"])) |
|
else: |
|
            raise NotImplementedError("patch_size is None")
|
|
|
|
|
if not self.mask_only: |
|
torch.save(stats, stats_filepath) |
|
|
|
|
|
pathlib.Path(processed_flag_filepath).touch() |
|
|
|
return stats |
|
|
|
def __len__(self): |
|
if self.pre_process: |
|
return len(self.processed_relative_paths) |
|
else: |
|
return len(self.tile_info_list) |
|
|
|
def __getitem__(self, idx): |
|
if self.pre_process: |
|
filepath = os.path.join(self.processed_dirpath, self.processed_relative_paths[idx]) |
|
data = torch.load(filepath) |
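            # In mask_only mode the binary mask stands in for the RGB image (repeated across 3 channels)
            # with neutral normalization statistics; otherwise use the per-city mean/std.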
|
if self.mask_only: |
|
data["image"] = np.repeat(data["gt_polygons_image"][:, :, 0:1], 3, axis=-1) |
|
data["image_mean"] = np.array([0.5, 0.5, 0.5]) |
|
data["image_std"] = np.array([1, 1, 1]) |
|
else: |
|
data["image_mean"] = np.array(CITY_METADATA_DICT[data["city"]]["mean"]) |
|
data["image_std"] = np.array(CITY_METADATA_DICT[data["city"]]["std"]) |
|
data["class_freq"] = self.stats["class_freq"] |
|
else: |
|
tile_info = self.tile_info_list[idx] |
|
|
|
data = self.load_raw_data(tile_info) |
|
data["name"] = IMAGE_NAME_FORMAT.format(city=tile_info["city"], number=tile_info["number"]) |
|
data["image_mean"] = np.array(tile_info["mean"]) |
|
data["image_std"] = np.array(tile_info["std"]) |
|
data = self.transform(data) |
|
return data |
|
|
|
|
|
def main(): |
|
|
|
from frame_field_learning import data_transforms |
|
|
|
config = { |
|
"data_dir_candidates": [ |
|
"/data/titane/user/nigirard/data", |
|
"~/data", |
|
"/data" |
|
], |
|
"dataset_params": { |
|
"root_dirname": "AerialImageDataset", |
|
"pre_process": False, |
|
"gt_source": "disk", |
|
"gt_type": "tif", |
|
"gt_dirname": "gt", |
|
"mask_only": False, |
|
"small": True, |
|
"data_patch_size": 425, |
|
"input_patch_size": 300, |
|
|
|
"train_fraction": 0.75 |
|
}, |
|
"num_workers": 8, |
|
"data_aug_params": { |
|
"enable": True, |
|
"vflip": True, |
|
"affine": True, |
|
"scaling": [0.9, 1.1], |
|
"color_jitter": True, |
|
"device": "cuda" |
|
} |
|
} |
|
|
|
|
|
data_dir = python_utils.choose_first_existing_path(config["data_dir_candidates"]) |
|
if data_dir is None: |
|
print_utils.print_error("ERROR: Data directory not found!") |
|
        sys.exit()
|
else: |
|
print_utils.print_info("Using data from {}".format(data_dir)) |
|
root_dir = os.path.join(data_dir, config["dataset_params"]["root_dirname"]) |
|
|
|
|
|
|
|
|
|
online_cpu_transform = data_transforms.get_online_cpu_transform(config, |
|
augmentations=config["data_aug_params"]["enable"]) |
|
train_online_cuda_transform = data_transforms.get_online_cuda_transform(config, augmentations=config["data_aug_params"]["enable"]) |
|
mask_only = config["dataset_params"]["mask_only"] |
|
kwargs = { |
|
"pre_process": config["dataset_params"]["pre_process"], |
|
"transform": online_cpu_transform, |
|
"patch_size": config["dataset_params"]["data_patch_size"], |
|
"patch_stride": config["dataset_params"]["input_patch_size"], |
|
"pre_transform": data_transforms.get_offline_transform_patch(distances=not mask_only, sizes=not mask_only), |
|
"small": config["dataset_params"]["small"], |
|
"pool_size": config["num_workers"], |
|
"gt_source": config["dataset_params"]["gt_source"], |
|
"gt_type": config["dataset_params"]["gt_type"], |
|
"gt_dirname": config["dataset_params"]["gt_dirname"], |
|
"mask_only": config["dataset_params"]["mask_only"], |
|
} |
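    # Split each training city's 36 tiles into train and val by tile number
    # (train_fraction=0.75 -> tiles 1-27 for train, 28-36 for val).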
|
train_val_split_point = config["dataset_params"]["train_fraction"] * 36 |
|
def train_tile_filter(tile): return tile["number"] <= train_val_split_point |
|
def val_tile_filter(tile): return train_val_split_point < tile["number"] |
|
|
|
fold = "train" |
|
if fold == "train": |
|
dataset = InriaAerial(root_dir, fold="train", tile_filter=train_tile_filter, **kwargs) |
|
elif fold == "val": |
|
dataset = InriaAerial(root_dir, fold="train", tile_filter=val_tile_filter, **kwargs) |
|
elif fold == "test": |
|
dataset = InriaAerial(root_dir, fold="test", **kwargs) |
|
|
|
print(f"dataset has {len(dataset)} samples.") |
|
print("# --- Sample 0 --- #") |
|
sample = dataset[0] |
|
for key, item in sample.items(): |
|
print("{}: {}".format(key, type(item))) |
|
|
|
print("# --- Samples --- #") |
|
|
|
|
|
|
|
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=config["num_workers"]) |
|
print("# --- Batches --- #") |
|
for batch in tqdm(data_loader): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("----") |
|
print(batch["name"]) |
|
|
|
print("image:", batch["image"].shape, batch["image"].min().item(), batch["image"].max().item()) |
|
im = np.array(batch["image"][0]) |
|
im = np.moveaxis(im, 0, -1) |
|
skimage.io.imsave('im.png', im) |
|
|
|
if "gt_polygons_image" in batch: |
|
print("gt_polygons_image:", batch["gt_polygons_image"].shape, batch["gt_polygons_image"].min().item(), |
|
batch["gt_polygons_image"].max().item()) |
|
seg = np.array(batch["gt_polygons_image"][0]) / 255 |
|
seg = np.moveaxis(seg, 0, -1) |
|
seg_display = utils.get_seg_display(seg) |
|
seg_display = (seg_display * 255).astype(np.uint8) |
|
skimage.io.imsave("gt_seg.png", seg_display) |
|
|
|
if "gt_crossfield_angle" in batch: |
|
print("gt_crossfield_angle:", batch["gt_crossfield_angle"].shape, batch["gt_crossfield_angle"].min().item(), |
|
batch["gt_crossfield_angle"].max().item()) |
|
gt_crossfield_angle = np.array(batch["gt_crossfield_angle"][0]) |
|
gt_crossfield_angle = np.moveaxis(gt_crossfield_angle, 0, -1) |
|
skimage.io.imsave('gt_crossfield_angle.png', gt_crossfield_angle) |
|
|
|
if "distances" in batch: |
|
print("distances:", batch["distances"].shape, batch["distances"].min().item(), batch["distances"].max().item()) |
|
distances = np.array(batch["distances"][0]) |
|
distances = np.moveaxis(distances, 0, -1) |
|
skimage.io.imsave('distances.png', distances) |
|
|
|
if "sizes" in batch: |
|
print("sizes:", batch["sizes"].shape, batch["sizes"].min().item(), batch["sizes"].max().item()) |
|
sizes = np.array(batch["sizes"][0]) |
|
sizes = np.moveaxis(sizes, 0, -1) |
|
skimage.io.imsave('sizes.png', sizes) |
|
|
|
|
|
|
|
|
|
|
|
        print("Apply online transform:")
|
batch = utils.batch_to_cuda(batch) |
|
batch = train_online_cuda_transform(batch) |
|
batch = utils.batch_to_cpu(batch) |
|
|
|
print("image:", batch["image"].shape, batch["image"].min().item(), batch["image"].max().item()) |
|
print("gt_polygons_image:", batch["gt_polygons_image"].shape, batch["gt_polygons_image"].min().item(), batch["gt_polygons_image"].max().item()) |
|
print("gt_crossfield_angle:", batch["gt_crossfield_angle"].shape, batch["gt_crossfield_angle"].min().item(), batch["gt_crossfield_angle"].max().item()) |
|
|
|
|
|
|
|
|
|
seg = np.array(batch["gt_polygons_image"][0]) |
|
seg = np.moveaxis(seg, 0, -1) |
|
seg_display = utils.get_seg_display(seg) |
|
seg_display = (seg_display * 255).astype(np.uint8) |
|
skimage.io.imsave("gt_seg.png", seg_display) |
|
|
|
im = np.array(batch["image"][0]) |
|
im = np.moveaxis(im, 0, -1) |
|
skimage.io.imsave('im.png', im) |
|
|
|
gt_crossfield_angle = np.array(batch["gt_crossfield_angle"][0]) |
|
gt_crossfield_angle = np.moveaxis(gt_crossfield_angle, 0, -1) |
|
skimage.io.imsave('gt_crossfield_angle.png', gt_crossfield_angle) |
|
|
|
distances = np.array(batch["distances"][0]) |
|
distances = np.moveaxis(distances, 0, -1) |
|
skimage.io.imsave('distances.png', distances) |
|
|
|
sizes = np.array(batch["sizes"][0]) |
|
sizes = np.moveaxis(sizes, 0, -1) |
|
skimage.io.imsave('sizes.png', sizes) |
|
|
|
|
|
|
|
|
|
|
|
input("Press enter to continue...") |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|