|
import gradio as gr |
|
import numpy as np |
|
import os |
|
from PIL import Image |
|
from math import ceil, floor |
|
from numpy import ndarray |
|
from typing import Callable, List, Optional, Tuple
|
import scipy.signal |
|
import onnxruntime as ort |
|
from tqdm import tqdm |
|
|
|
|
|
os.environ["GRADIO_TEMP_DIR"] = ".tmp" |
|
|
|
WINDOW_CACHE = dict()  # cache of computed spline windows, keyed by "{size}_{power}"
|
|
|
|
|
def _spline_window(window_size: int, power: int = 2) -> np.ndarray:
    """Generates a 1-dimensional spline of order `power` (typically 2) over the
    given window, used to weight tile predictions so that overlapping tiles
    blend smoothly.

    Args:
        window_size (int): size of the window
        power (int, optional): order of the spline. Defaults to 2.

    Returns:
        np.ndarray: 1D spline, normalized to unit mean
    """
    intersection = int(window_size / 4)
    # Outer part of the spline: rises from zero at the window borders
    wind_outer = (
        abs(2 * (scipy.signal.windows.triang(window_size))) ** power) / 2
    wind_outer[intersection:-intersection] = 0
    # Inner part: flat-topped bump over the center of the window
    wind_inner = (
        1 - (abs(2 * (scipy.signal.windows.triang(window_size) - 1)) ** power) / 2
    )
    wind_inner[:intersection] = 0
    wind_inner[-intersection:] = 0
    wind = wind_inner + wind_outer
    wind = wind / np.average(wind)
    return wind
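
# Illustrative sanity check (a sketch, not executed here): the spline is
# normalized to unit mean, so overlap-summed tile weights stay balanced.
#
#   w = _spline_window(256)
#   assert abs(float(np.average(w)) - 1.0) < 1e-6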
|
|
|
|
|
def _spline_2d(window_size: int, power: int = 2) -> ndarray:
    """Makes a 1D window spline function, then combines it into a 2D window.
    The 2D window is used to smoothly interpolate between overlapping patches.

    Args:
        window_size (int): size of the window (patch)
        power (int, optional): order of the spline. Defaults to 2.

    Returns:
        np.ndarray: 2D spline window, normalized to a maximum of 1
    """
    wind = _spline_window(window_size, power)
    # Outer product of the 1D spline with itself, then rescale to peak at 1
    wind2 = wind[:, None] * wind[None, :]
    wind2 = wind2 / np.max(wind2)
    return wind2
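
# A minimal shape sketch (illustrative): the 2D window is the outer product of
# the 1D spline with itself, rescaled so its peak value is exactly 1.
#
#   w2 = _spline_2d(256)
#   assert w2.shape == (256, 256) and float(w2.max()) == 1.0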
|
|
|
|
|
def _spline_4d(
    window_size: int,
    power: int = 2,
    batch_size: int = 1,
    channels: int = 1,
) -> ndarray:
    """Makes a 4D window spline function.
    Same as the 2D version, but broadcast across the batch and channel axes."""
    global WINDOW_CACHE
    key = f"{window_size}_{power}"
    if key in WINDOW_CACHE:
        wind2 = WINDOW_CACHE[key]
    else:
        wind2 = _spline_2d(window_size, power)
        WINDOW_CACHE[key] = wind2
    # Cache only the 2D window, so cached entries stay valid for any batch
    # size or channel count; broadcast to 4D on every call.
    wind4 = wind2[None, None, :, :] * np.ones((batch_size, channels, 1, 1))
    return wind4
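
# Broadcast sketch (illustrative): the 4D window lines up with a channels-first
# prediction batch of shape (batch, channels, tile, tile).
#
#   w4 = _spline_4d(256, batch_size=1, channels=1)
#   assert w4.shape == (1, 1, 256, 256)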
|
|
|
|
|
def pad_image(
    image: np.ndarray, tile_size: int, subdivisions: int
) -> Tuple[np.ndarray, List[int]]:
    """Adds reflected borders to the given image so that every tile of size
    `tile_size`, with the overlap implied by `subdivisions`, fits exactly.

    Args:
        image (np.ndarray): input image, 3D channels-first array
        tile_size (int): size of a single patch, used to compute padding
        subdivisions (int): amount of overlap, used to compute padding

    Returns:
        Tuple[np.ndarray, List[int]]: the reflect-padded image and the applied
        paddings as [top, bottom, left, right]
    """
    step = tile_size // subdivisions
    _, in_h, in_w = image.shape
    # Extra padding so that height and width become multiples of the step
    pad_h = step - (in_h % step)
    pad_w = step - (in_w % step)
    pad_h_l = pad_h // 2
    pad_h_r = (pad_h // 2) + (pad_h % 2)
    pad_w_l = pad_w // 2
    pad_w_r = (pad_w // 2) + (pad_w % 2)
    # Base padding so that border pixels are covered by full tiles
    pad = int(round(tile_size * (1 - 1.0 / subdivisions)))
    image = np.pad(
        image,
        ((0, 0), (pad + pad_h_l, pad + pad_h_r), (pad + pad_w_l, pad + pad_w_r)),
        mode="reflect",
    )
    return image, [pad + pad_h_l, pad + pad_h_r, pad + pad_w_l, pad + pad_w_r]
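
# A usage sketch (illustrative values): with tile_size=256 and subdivisions=2,
# tiles of 256 pixels with a 128-pixel step cover the padded image exactly.
#
#   img = np.zeros((3, 300, 500))
#   padded, pads = pad_image(img, tile_size=256, subdivisions=2)
#   assert (padded.shape[1] - 256) % 128 == 0
#   assert (padded.shape[2] - 256) % 128 == 0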
|
|
|
|
|
def unpad_image(padded_image: ndarray, pads) -> ndarray:
    """Reverts the changes made by `pad_image` by removing the same paddings.

    Args:
        padded_image (np.ndarray): image with padding still applied
        pads (List[int]): paddings returned by `pad_image`, as
            [top, bottom, left, right]

    Returns:
        np.ndarray: image without padding, channels-first if 3D
    """
    pad_top, pad_bottom, pad_left, pad_right = pads
    n_dims = len(padded_image.shape)
    if n_dims == 2:
        result = padded_image[pad_top:-pad_bottom, pad_left:-pad_right]
    elif n_dims == 3:
        result = padded_image[:, pad_top:-pad_bottom, pad_left:-pad_right]
    else:
        raise ValueError(
            f"padded_image has {n_dims} dimensions, expected 2 or 3.")
    return result
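
# Round-trip sketch (illustrative): unpadding with the pads returned by
# pad_image recovers the original spatial shape.
#
#   img = np.zeros((3, 300, 500))
#   padded, pads = pad_image(img, tile_size=256, subdivisions=2)
#   assert unpad_image(padded, pads).shape == img.shape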
|
|
|
|
|
def windowed_generator(
    padded_image: ndarray,
    window_size: int,
    subdivisions: int,
    batch_size: Optional[int] = None,
):
    """Generator that yields tiles grouped by batch size.

    Args:
        padded_image (np.ndarray): input image to be processed (already padded), channels-first
        window_size (int): size of a single patch
        subdivisions (int): subdivision count on each patch, used to compute the step
        batch_size (int, optional): amount of patches in each batch. Defaults to None (treated as 1).

    Yields:
        Tuple[List[tuple], np.ndarray]: list of coordinates and respective patches as a single batch array
    """
    step = window_size // subdivisions
    _, height, width = padded_image.shape
    batch_size = batch_size or 1
    batch = []
    coords = []
    # x indexes rows (height), y indexes columns (width)
    for x in range(0, height - window_size + 1, step):
        for y in range(0, width - window_size + 1, step):
            coords.append((x, y))
            tile = padded_image[:, x: x + window_size, y: y + window_size]
            batch.append(tile)
            if len(batch) == batch_size:
                yield coords, np.stack(batch)
                # Rebind rather than clear, so the yielded list is not mutated
                coords = []
                batch = []
    # Yield the last, possibly incomplete batch
    if len(batch) > 0:
        yield coords, np.stack(batch)
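
# A minimal usage sketch (illustrative, `padded` is a hypothetical padded
# channels-first image): iterate 256-pixel tiles with a 128-pixel step,
# four tiles per batch.
#
#   for coords, batch in windowed_generator(padded, 256, 2, batch_size=4):
#       ...  # batch has shape (<=4, channels, 256, 256)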
|
|
|
|
|
def reconstruct(
    canvas: ndarray, tile_size: int, coords: List[tuple], predictions: ndarray
) -> ndarray:
    """Helper function that accumulates a batch of predicted patches onto the
    given canvas, reconstructing the final result batch after batch.

    Args:
        canvas (np.ndarray): container for the final image
        tile_size (int): size of a single patch
        coords (List[tuple]): list of pixel coordinates corresponding to the batch items
        predictions (np.ndarray): patch predictions, shape (batch, num_classes, tile_size, tile_size)

    Returns:
        np.ndarray: the updated canvas, shape (num_classes, padded_h, padded_w)
    """
    n_dims = len(canvas.shape)
    for (x, y), patch in zip(coords, predictions):
        if n_dims == 2:
            canvas[x: x + tile_size, y: y + tile_size] += patch
        elif n_dims == 3:
            canvas[:, x: x + tile_size, y: y + tile_size] += patch
        else:
            raise ValueError(
                f"Canvas has {n_dims} dimensions, expected 2 or 3.")
    return canvas
|
|
|
|
|
def predict_smooth_windowing(
    image: ndarray,
    tile_size: int,
    subdivisions: int,
    prediction_fn: Callable,
    batch_size: int = 1,
    out_dim: int = 1,
) -> np.ndarray:
    """Predicts a large image in one go by dividing it into square, fixed-size
    tiles and smoothly interpolating over them to produce a single, coherent
    output with the same spatial dimensions.

    Args:
        image (np.ndarray): input image, 3D channels-first array
        tile_size (int): size of each square tile
        subdivisions (int): number of subdivisions over a single tile, controlling the overlap
        prediction_fn (Callable): callback that takes an input batch and returns an output batch
        batch_size (int, optional): size of each batch. Defaults to 1.
        out_dim (int, optional): number of output channels. Defaults to 1.

    Returns:
        np.ndarray: array of shape (out_dim, h, w) containing smooth predictions
    """
    img, pads = pad_image(image=image, tile_size=tile_size,
                          subdivisions=subdivisions)
    spline = _spline_4d(window_size=tile_size, power=2)

    canvas = np.zeros((out_dim, img.shape[1], img.shape[2]))
    loop = tqdm(windowed_generator(
        padded_image=img,
        window_size=tile_size,
        subdivisions=subdivisions,
        batch_size=batch_size,
    ))
    for coords, batch in loop:
        pred_batch = prediction_fn(batch)
        # Weight each predicted tile by the spline window before accumulating,
        # so overlapping tiles blend smoothly instead of producing seams
        pred_batch = pred_batch * spline
        canvas = reconstruct(
            canvas, tile_size=tile_size, coords=coords, predictions=pred_batch
        )
    prediction = unpad_image(canvas, pads=pads)
    return prediction
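
# An end-to-end sketch (illustrative, with a dummy model): a callback that
# averages the input channels stands in for a real network.
#
#   def dummy_fn(batch):
#       return batch.mean(axis=1, keepdims=True)
#
#   image = np.random.rand(3, 600, 800).astype(np.float32)
#   out = predict_smooth_windowing(image, tile_size=256, subdivisions=2,
#                                  prediction_fn=dummy_fn, batch_size=2)
#   assert out.shape == (1, 600, 800)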
|
|
|
|
|
def center_pad(x, padding, div_factor=32, mode="reflect"):
    """Pads a 4D (batch, channels, height, width) array so that each spatial
    dimension grows by at least `2 * padding` and becomes a multiple of
    `div_factor`, splitting the padding evenly around the center.

    Returns the padded array and the applied paddings as
    (pad_top, pad_bottom, pad_left, pad_right).
    """
    size_x = x.shape[3]  # width
    size_y = x.shape[2]  # height

    min_size_x = size_x + 2 * padding
    min_size_y = size_y + 2 * padding

    new_size_x = int(ceil(min_size_x / div_factor) * div_factor)
    new_size_y = int(ceil(min_size_y / div_factor) * div_factor)

    pad_x = new_size_x - size_x
    pad_y = new_size_y - size_y
    pad_left = int(floor(pad_x / 2))
    pad_right = int(ceil(pad_x / 2))
    pad_top = int(floor(pad_y / 2))
    pad_bottom = int(ceil(pad_y / 2))
    if pad_x > size_x or pad_y > size_y:
        # "reflect" cannot pad by more than the input size on a side, so grow
        # each side by as much as allowed first, then pad the remainder.
        grow_y = (min(pad_top, size_y - 1), min(pad_bottom, size_y - 1))
        grow_x = (min(pad_left, size_x - 1), min(pad_right, size_x - 1))
        grown = np.pad(
            x,
            ((0, 0), (0, 0), grow_y, grow_x),
            mode=mode,
        )
        rest_y = (pad_top - grow_y[0], pad_bottom - grow_y[1])
        rest_x = (pad_left - grow_x[0], pad_right - grow_x[1])
        padded = np.pad(
            grown,
            ((0, 0), (0, 0), rest_y, rest_x),
            mode=mode,
        )
    else:
        padded = np.pad(
            x,
            ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)),
            mode=mode,
        )
    paddings = (pad_top, pad_bottom, pad_left, pad_right)
    return padded, paddings
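
# A usage sketch (illustrative): pad a (1, 3, 100, 150) array by at least 16
# pixels on every side, rounding each spatial size up to a multiple of 32.
#
#   arr = np.zeros((1, 3, 100, 150))
#   padded, pads = center_pad(arr, padding=16)
#   assert padded.shape[2] % 32 == 0 and padded.shape[3] % 32 == 0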
|
|
|
|
|
class ChangeDetectionModel:
    def __init__(self):
        path = "assets/models/change_detection.onnx"
        self.model = ort.InferenceSession(path)
        self.size = 256
        self.subdivisions = 2
        self.batch_size = 2
        self.out_dim = 1

    def forward(self, x):
        assert x.ndim == 3, "Expected 3D tensor"
        # Normalize to [0, 1] and cast to the precision expected by the ONNX model
        x = x / 255
        x = x.astype(np.float32)
        pred = predict_smooth_windowing(
            image=x,
            tile_size=self.size,
            subdivisions=self.subdivisions,
            prediction_fn=self.callback,
            batch_size=self.batch_size,
            out_dim=self.out_dim,
        )
        # Sigmoid over the raw logits, then map [0, 1] onto the four ordinal
        # damage classes {0, 1, 2, 3} by scaling and rounding
        pred = 1 / (1 + np.exp(-pred))
        pred = pred * 3
        pred = np.round(pred)
        return pred[0]

    def callback(self, x: ndarray) -> ndarray:
        # Run the ONNX session on a single batch and return its first output
        out = self.model.run(None, {"input": x})[0]
        return out
|
|
|
|
|
class LocalizationModel:
    def __init__(self):
        path = "assets/models/localization.onnx"
        self.model = ort.InferenceSession(path)
        self.size = 384
        self.subdivisions = 2
        self.batch_size = 2
        self.out_dim = 3

    def forward(self, x):
        assert x.ndim == 3, "Expected 3D tensor"
        # Normalize to [0, 1] and cast to the precision expected by the ONNX model
        x = x / 255
        x = x.astype(np.float32)
        pred = predict_smooth_windowing(
            image=x,
            tile_size=self.size,
            subdivisions=self.subdivisions,
            prediction_fn=self.callback,
            batch_size=self.batch_size,
            out_dim=self.out_dim,
        )
        # Per-pixel class: background (0), road (1), or building (2)
        pred = np.argmax(pred, axis=0)
        return pred

    def callback(self, x: ndarray) -> ndarray:
        # Run the ONNX session on a single batch and return its first output
        out = self.model.run(None, {"input": x})[0]
        return out
|
|
|
|
|
def infer(image1, image2):
    assert isinstance(image1, Image.Image), "image1 is not a PIL Image"
    assert isinstance(image2, Image.Image), "image2 is not a PIL Image"
    localization_model = LocalizationModel()
    change_detection_model = ChangeDetectionModel()

    # Align the pre-disaster image to the post-disaster one, then run both
    # at half resolution for faster inference
    image1 = image1.resize(image2.size)
    image1 = image1.resize((image1.width // 2, image1.height // 2))
    image2 = image2.resize((image2.width // 2, image2.height // 2))

    image1 = np.array(image1)
    image2 = np.array(image2)

    # Channels-last (H, W, C) to channels-first (C, H, W)
    image1_array = np.transpose(image1, (2, 0, 1))
    image2_array = np.transpose(image2, (2, 0, 1))
    output_image1 = localization_model.forward(image1_array)

    # Change detection takes the pre/post pair stacked along the channel axis
    cat_image_array = np.concatenate([image1_array, image2_array], axis=0)
    output_image2 = change_detection_model.forward(cat_image_array)

    # Colorize localization classes: background (black), road (grey), building (red)
    output_image1_color = np.zeros(
        (output_image1.shape[0], output_image1.shape[1], 3))
    output_image1_color[output_image1 == 0] = [0, 0, 0]
    output_image1_color[output_image1 == 1] = [150, 150, 150]
    output_image1_color[output_image1 == 2] = [200, 0, 0]

    # Blend the localization mask with the pre-disaster image
    output_image1_color = output_image1_color * 0.5 + image1 * 0.5
    output_image1 = Image.fromarray(output_image1_color.astype(np.uint8))

    # Colorize change classes: no change (black), minor (green),
    # major (yellow), destroyed (red), then blend with the post-disaster image
    output_image2_color = np.zeros(
        (output_image2.shape[0], output_image2.shape[1], 3))
    output_image2_color[output_image2 == 0] = [0, 0, 0]
    output_image2_color[output_image2 == 1] = [0, 255, 0]
    output_image2_color[output_image2 == 2] = [255, 255, 0]
    output_image2_color[output_image2 == 3] = [255, 0, 0]
    output_image2_color = output_image2_color * 0.5 + image2 * 0.5
    output_image2 = Image.fromarray(output_image2_color.astype(np.uint8))
    return output_image1, output_image2
|
|
|
|
|
|
|
sample_images = [
    ["assets/data/bata_1_pre.png", "assets/data/bata_1_post.png"],
    ["assets/data/bata_2_pre.png", "assets/data/bata_2_post.png"],
    ["assets/data/beirut_1_pre.png", "assets/data/beirut_1_post.png"],
]


for pair in sample_images:
    for file in pair:
        assert os.path.exists(file), f"File not found: {file}"
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## Infrastructure Damage Assessment")

    gr.Markdown(
        "This is a demo for infrastructure damage assessment using satellite images. "
        "It contains two models: one for localization and one for change detection. "
        "The localization model segments the image into three classes: "
        "background (black), road (grey), and buildings (red). "
        "The change detection model detects changes between the two images; its output "
        "is colored as follows: no change (black), minor change (green), "
        "major change (yellow), and destroyed (red). "
        "The output of the localization model (on the left) is blended with the "
        "pre-disaster image to highlight the areas of interest. "
        "The output of the change detection model (on the right) is blended with the "
        "post-disaster image to highlight the changes. "
        "You can upload your own images or use the sample images provided below."
    )
    gr.Markdown(
        "Note: the models run at half resolution for faster inference, "
        "so the outputs will be less accurate than those of the full-resolution models. "
        "Inference still takes a few minutes, so please be patient."
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image1 = gr.Image(label="Pre-disaster Image", type="pil")
        with gr.Column(scale=1):
            input_image2 = gr.Image(label="Post-disaster Image", type="pil")
    with gr.Row():
        output_image1 = gr.Image(label="Roads and buildings localization", type="pil")
        output_image2 = gr.Image(label="Change detection", type="pil")
    submit_button = gr.Button("Run Inference")
    examples = gr.Examples(
        examples=sample_images,
        inputs=[input_image1, input_image2],
    )
    submit_button.click(
        infer,
        inputs=[input_image1, input_image2],
        outputs=[output_image1, output_image2],
    )
|
|
|
demo.launch() |
|
|