import os
import random
from copy import deepcopy
from math import ceil, tanh
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.v2.functional as TF
from unik3d.utils.geometric import downsample
def euler_to_rotation_matrix(angles):
"""
Convert Euler angles to a 3x3 rotation matrix.
Args:
angles (torch.Tensor): Euler angles [roll, pitch, yaw].
Returns:
torch.Tensor: 3x3 rotation matrix.
"""
phi, theta, psi = angles
cos_phi, sin_phi = torch.cos(phi), torch.sin(phi)
cos_theta, sin_theta = torch.cos(theta), torch.sin(theta)
cos_psi, sin_psi = torch.cos(psi), torch.sin(psi)
# Rotation matrices
Rx = torch.tensor([[1, 0, 0], [0, cos_phi, -sin_phi], [0, sin_phi, cos_phi]])
Ry = torch.tensor(
[[cos_theta, 0, sin_theta], [0, 1, 0], [-sin_theta, 0, cos_theta]]
)
Rz = torch.tensor([[cos_psi, -sin_psi, 0], [sin_psi, cos_psi, 0], [0, 0, 1]])
return Rz @ Ry @ Rx
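# Note: building Rx/Ry/Rz with torch.tensor() detaches gradients and drops the
# input's device/dtype, which is fine for data augmentation but not for
# learnable angles. Minimal usage sketch (illustrative values):
# >>> R = euler_to_rotation_matrix(torch.tensor([0.1, 0.2, 0.3]))
# >>> torch.allclose(R @ R.T, torch.eye(3), atol=1e-6)  # rotations are orthonormal
# True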
def compute_grid(H, W):
meshgrid = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="xy")
    id_coords = torch.stack(meshgrid, dim=0).to(torch.float32)
id_coords = id_coords.reshape(2, -1)
id_coords = torch.cat(
[id_coords, torch.ones(1, id_coords.shape[-1])], dim=0
) # 3 HW
id_coords = id_coords.unsqueeze(0)
return id_coords
def lexsort(keys):
    # Composed stable argsort over multiple keys: keys[0] is the primary key
    # (it is applied last), later keys only break ties.
    sorted_indices = torch.arange(keys[0].size(0), device=keys[0].device)
    for key in reversed(keys):
        sorted_indices = sorted_indices[torch.argsort(key[sorted_indices], stable=True)]
    return sorted_indices
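# Minimal usage sketch (illustrative values). The loop above applies keys[0]
# last, so keys[0] is the primary sort key:
# >>> a, b = torch.tensor([1, 1, 0]), torch.tensor([2, 1, 9])
# >>> lexsort([a, b])  # sort by a, ties broken by b
# tensor([2, 1, 0])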
def masked_bilinear_interpolation(input, mask, target_size):
B, C, H, W = input.shape
target_H, target_W = target_size
mask = mask.float()
# Generate a grid of coordinates in the target space
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(0, H - 1, target_H),
        torch.linspace(0, W - 1, target_W),
        indexing="ij",
    )
grid_y = grid_y.to(input.device)
grid_x = grid_x.to(input.device)
# Calculate the floor and ceil of the grid coordinates to get the bounding box
x0 = torch.floor(grid_x).long().clamp(0, W - 1)
x1 = (x0 + 1).clamp(0, W - 1)
y0 = torch.floor(grid_y).long().clamp(0, H - 1)
y1 = (y0 + 1).clamp(0, H - 1)
# Gather depth values at the four corners
Ia = input[..., y0, x0]
Ib = input[..., y1, x0]
Ic = input[..., y0, x1]
Id = input[..., y1, x1]
# Gather corresponding mask values
ma = mask[..., y0, x0]
mb = mask[..., y1, x0]
mc = mask[..., y0, x1]
md = mask[..., y1, x1]
# Calculate the areas (weights) for bilinear interpolation
wa = (x1.float() - grid_x) * (y1.float() - grid_y)
wb = (x1.float() - grid_x) * (grid_y - y0.float())
wc = (grid_x - x0.float()) * (y1.float() - grid_y)
wd = (grid_x - x0.float()) * (grid_y - y0.float())
wa = wa.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
wb = wb.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
wc = wc.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
wd = wd.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
# Only consider valid points for interpolation
weights_sum = (wa * ma) + (wb * mb) + (wc * mc) + (wd * md)
weights_sum = torch.clamp(weights_sum, min=1e-5)
# Perform the interpolation
interpolated_depth = (
wa * Ia * ma + wb * Ib * mb + wc * Ic * mc + wd * Id * md
) / weights_sum
return interpolated_depth, (ma + mb + mc + md) > 0
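# Minimal usage sketch (hypothetical shapes): upsample a sparse depth map while
# ignoring invalid pixels; the second output marks where at least one of the
# four neighbors was valid.
# >>> depth = torch.rand(1, 1, 24, 32)
# >>> up, up_valid = masked_bilinear_interpolation(depth, depth > 0.5, (48, 64))
# >>> up.shape
# torch.Size([1, 1, 48, 64])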
def masked_nearest_interpolation(input, mask, target_size):
B, C, H, W = input.shape
target_H, target_W = target_size
mask = mask.float()
# Generate a grid of coordinates in the target space
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H - 1, target_H),
torch.linspace(0, W - 1, target_W),
indexing="ij",
)
grid_y = grid_y.to(input.device)
grid_x = grid_x.to(input.device)
# Calculate the floor and ceil of the grid coordinates to get the bounding box
x0 = torch.floor(grid_x).long().clamp(0, W - 1)
x1 = (x0 + 1).clamp(0, W - 1)
y0 = torch.floor(grid_y).long().clamp(0, H - 1)
y1 = (y0 + 1).clamp(0, H - 1)
# Gather depth values at the four corners
Ia = input[..., y0, x0]
Ib = input[..., y1, x0]
Ic = input[..., y0, x1]
Id = input[..., y1, x1]
# Gather corresponding mask values
ma = mask[..., y0, x0]
mb = mask[..., y1, x0]
mc = mask[..., y0, x1]
md = mask[..., y1, x1]
# Calculate distances to each neighbor
# The distances are calculated from the center (grid_x, grid_y) to each corner
dist_a = (grid_x - x0.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-left
dist_b = (grid_x - x0.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-left
dist_c = (grid_x - x1.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-right
dist_d = (grid_x - x1.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-right
# Stack the neighbors, their masks, and distances
stacked_values = torch.stack(
[Ia, Ib, Ic, Id], dim=-1
) # Shape: (B, C, target_H, target_W, 4)
stacked_masks = torch.stack(
[ma, mb, mc, md], dim=-1
) # Shape: (B, 1, target_H, target_W, 4)
stacked_distances = torch.stack(
[dist_a, dist_b, dist_c, dist_d], dim=-1
) # Shape: (target_H, target_W, 4)
stacked_distances = (
stacked_distances.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1, 1)
) # Shape: (B, 1, target_H, target_W, 4)
# Set distances to infinity for invalid neighbors (so that invalid neighbors are never chosen)
stacked_distances[stacked_masks == 0] = float("inf")
# Find the index of the nearest valid neighbor (the one with the smallest distance)
    nearest_indices = stacked_distances.argmin(
        dim=-1, keepdim=True
    )  # Shape: (B, 1, target_H, target_W, 1)
# Select the corresponding depth value using the nearest valid neighbor index
interpolated_depth = torch.gather(
stacked_values, dim=-1, index=nearest_indices.repeat(1, C, 1, 1, 1)
).squeeze(-1)
# Set depth to zero where no valid neighbors were found
interpolated_depth = interpolated_depth * stacked_masks.sum(dim=-1).clip(
min=0.0, max=1.0
)
return interpolated_depth
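# Minimal usage sketch (hypothetical shapes): nearest-valid-neighbor resize,
# used below by ContextCrop for ground-truth and mask fields.
# >>> gt = torch.rand(2, 1, 24, 32)
# >>> out = masked_nearest_interpolation(gt, gt > 0.3, (48, 64))  # (2, 1, 48, 64)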
def masked_nxn_interpolation(input, mask, target_size, N=2):
B, C, H, W = input.shape
target_H, target_W = target_size
# Generate a grid of coordinates in the target space
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H - 1, target_H),
torch.linspace(0, W - 1, target_W),
indexing="ij",
)
grid_y = grid_y.to(input.device)
grid_x = grid_x.to(input.device)
# Calculate the top-left corner of the NxN grid
half_N = (N - 1) // 2
y0 = torch.floor(grid_y - half_N).long().clamp(0, H - 1)
x0 = torch.floor(grid_x - half_N).long().clamp(0, W - 1)
# Prepare to gather NxN neighborhoods
input_patches = []
mask_patches = []
weights = []
for i in range(N):
for j in range(N):
yi = (y0 + i).clamp(0, H - 1)
xi = (x0 + j).clamp(0, W - 1)
# Gather depth and mask values
input_patches.append(input[..., yi, xi])
mask_patches.append(mask[..., yi, xi])
# Compute bilinear weights
weight_y = 1 - torch.abs(grid_y - yi.float()) / N
weight_x = 1 - torch.abs(grid_x - xi.float()) / N
weight = (
(weight_y * weight_x)
.reshape(1, 1, target_H, target_W)
.repeat(B, C, 1, 1)
)
weights.append(weight)
input_patches = torch.stack(input_patches)
mask_patches = torch.stack(mask_patches)
weights = torch.stack(weights)
# Calculate weighted sum and normalize by the sum of weights
weighted_sum = (input_patches * mask_patches * weights).sum(dim=0)
weights_sum = (mask_patches * weights).sum(dim=0)
interpolated_tensor = weighted_sum / torch.clamp(weights_sum, min=1e-8)
if N != 2:
interpolated_tensor_2x2, mask_sum_2x2 = masked_bilinear_interpolation(
input, mask, target_size
)
interpolated_tensor = torch.where(
mask_sum_2x2, interpolated_tensor_2x2, interpolated_tensor
)
return interpolated_tensor
class PanoCrop:
def __init__(self, crop_v=0.15):
self.crop_v = crop_v
def _crop_data(self, results, crop_size):
offset_w, offset_h = crop_size
left, top, right, bottom = offset_w[0], offset_h[0], offset_w[1], offset_h[1]
H, W = results["image"].shape[-2:]
for key in results.get("image_fields", ["image"]):
img = results[key][..., top : H - bottom, left : W - right]
results[key] = img
results["image_shape"] = tuple(img.shape)
for key in results.get("gt_fields", []):
results[key] = results[key][..., top : H - bottom, left : W - right]
for key in results.get("mask_fields", []):
results[key] = results[key][..., top : H - bottom, left : W - right]
results["camera"] = results["camera"].crop(left, top, right, bottom)
return results
def __call__(self, results):
H, W = results["image"].shape[-2:]
crop_w = (0, 0)
crop_h = (int(H * self.crop_v), int(H * self.crop_v))
results = self._crop_data(results, (crop_w, crop_h))
return results
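# Minimal usage sketch (assuming a `results` dict as produced by the loading
# pipeline, with "image", "camera" and the usual *_fields entries):
# pano_crop = PanoCrop(crop_v=0.15)  # trims 15% of the height at top and bottom
# results = pano_crop(results)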
class PanoRoll:
def __init__(self, test_mode, roll=[-0.5, 0.5]):
self.roll = roll
self.test_mode = test_mode
def __call__(self, results):
if self.test_mode:
return results
W = results["image"].shape[-1]
roll = random.randint(int(W * self.roll[0]), int(W * self.roll[1]))
for key in results.get("image_fields", ["image"]):
img = results[key]
img = torch.roll(img, roll, dims=-1)
results[key] = img
for key in results.get("gt_fields", []):
results[key] = torch.roll(results[key], roll, dims=-1)
for key in results.get("mask_fields", []):
results[key] = torch.roll(results[key], roll, dims=-1)
return results
class RandomFlip:
def __init__(self, direction="horizontal", prob=0.5, consistent=False, **kwargs):
self.flip_ratio = prob
valid_directions = ["horizontal", "vertical", "diagonal"]
if isinstance(direction, str):
assert direction in valid_directions
elif isinstance(direction, list):
assert set(direction).issubset(set(valid_directions))
else:
raise ValueError("direction must be either str or list of str")
self.direction = direction
self.consistent = consistent
def __call__(self, results):
if "flip" not in results:
# None means non-flip
if isinstance(self.direction, list):
direction_list = self.direction + [None]
else:
direction_list = [self.direction, None]
if isinstance(self.flip_ratio, list):
non_flip_ratio = 1 - sum(self.flip_ratio)
flip_ratio_list = self.flip_ratio + [non_flip_ratio]
else:
non_flip_ratio = 1 - self.flip_ratio
# exclude non-flip
single_ratio = self.flip_ratio / (len(direction_list) - 1)
flip_ratio_list = [single_ratio] * (len(direction_list) - 1) + [
non_flip_ratio
]
cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
results["flip"] = cur_dir is not None
if "flip_direction" not in results:
results["flip_direction"] = cur_dir
if results["flip"]:
# flip image
if results["flip_direction"] != "vertical":
for key in results.get("image_fields", ["image"]):
results[key] = TF.hflip(results[key])
for key in results.get("mask_fields", []):
results[key] = TF.hflip(results[key])
for key in results.get("gt_fields", []):
results[key] = TF.hflip(results[key])
if "flow" in key: # flip u direction
results[key][:, 0] = -results[key][:, 0]
H, W = results["image"].shape[-2:]
results["camera"] = results["camera"].flip(
H=H, W=W, direction="horizontal"
)
flip_transform = torch.tensor(
[[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
dtype=torch.float32,
).unsqueeze(0)
repeats = (results["cam2w"].shape[0],) + (1,) * (
results["cam2w"].ndim - 1
)
results["cam2w"] = flip_transform.repeat(*repeats) @ results["cam2w"]
if results["flip_direction"] != "horizontal":
for key in results.get("image_fields", ["image"]):
results[key] = TF.vflip(results[key])
for key in results.get("mask_fields", []):
results[key] = TF.vflip(results[key])
for key in results.get("gt_fields", []):
results[key] = TF.vflip(results[key])
results["K"][..., 1, 2] = (
results["image"].shape[-2] - results["K"][..., 1, 2]
)
results["flip"] = [results["flip"]] * len(results["image"])
return results
class Crop:
def __init__(
self,
crop_size,
crop_type="absolute",
crop_offset=(0, 0),
):
if crop_type not in [
"relative_range",
"relative",
"absolute",
"absolute_range",
]:
raise ValueError(f"Invalid crop_type {crop_type}.")
if crop_type in ["absolute", "absolute_range"]:
assert crop_size[0] > 0 and crop_size[1] > 0
assert isinstance(crop_size[0], int) and isinstance(crop_size[1], int)
else:
assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
self.crop_size = crop_size
self.crop_type = crop_type
self.offset_h, self.offset_w = (
crop_offset[: len(crop_offset) // 2],
crop_offset[len(crop_offset) // 2 :],
)
def _get_crop_size(self, image_shape):
h, w = image_shape
if self.crop_type == "absolute":
return (min(self.crop_size[0], h), min(self.crop_size[1], w))
elif self.crop_type == "absolute_range":
assert self.crop_size[0] <= self.crop_size[1]
crop_h = np.random.randint(
min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1
)
crop_w = np.random.randint(
min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1
)
return crop_h, crop_w
elif self.crop_type == "relative":
crop_h, crop_w = self.crop_size
return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
elif self.crop_type == "relative_range":
crop_size = np.asarray(self.crop_size, dtype=np.float32)
crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
def _crop_data(self, results, crop_size):
assert crop_size[0] > 0 and crop_size[1] > 0
for key in results.get("image_fields", ["image"]):
img = results[key]
img = TF.crop(
img, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
)
results[key] = img
results["image_shape"] = tuple(img.shape)
for key in results.get("gt_fields", []):
gt = results[key]
results[key] = TF.crop(
gt, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
)
# crop semantic seg
for key in results.get("mask_fields", []):
mask = results[key]
results[key] = TF.crop(
mask, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
)
results["K"][..., 0, 2] = results["K"][..., 0, 2] - self.offset_w[0]
results["K"][..., 1, 2] = results["K"][..., 1, 2] - self.offset_h[0]
return results
def __call__(self, results):
image_shape = results["image"].shape[-2:]
crop_size = self._get_crop_size(image_shape)
results = self._crop_data(results, crop_size)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(crop_size={self.crop_size}, "
repr_str += f"crop_type={self.crop_type}, "
return repr_str
class KittiCrop:
def __init__(self, crop_size):
self.crop_size = crop_size
def _crop_data(self, results, crop_size):
"""Function to randomly crop images, bounding boxes, masks, semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
crop_size (tuple): Expected absolute size after cropping, (h, w).
allow_negative_crop (bool): Whether to allow a crop that does not
contain any bbox area. Default to False.
Returns:
dict: Randomly cropped results, 'image_shape' key in result dict is
updated according to crop size.
"""
assert crop_size[0] > 0 and crop_size[1] > 0
for key in results.get("image_fields", ["image"]):
img = results[key]
h, w = img.shape[-2:]
offset_h, offset_w = int(h - self.crop_size[0]), int(
(w - self.crop_size[1]) / 2
)
# crop the image
img = TF.crop(img, offset_h, offset_w, crop_size[0], crop_size[1])
results[key] = img
results["image_shape"] = tuple(img.shape)
for key in results.get("gt_fields", []):
gt = results[key]
results[key] = TF.crop(gt, offset_h, offset_w, crop_size[0], crop_size[1])
# crop semantic seg
for key in results.get("mask_fields", []):
mask = results[key]
results[key] = TF.crop(mask, offset_h, offset_w, crop_size[0], crop_size[1])
results["camera"] = results["camera"].crop(offset_w, offset_h)
return results
def __call__(self, results):
"""Call function to randomly crop images, bounding boxes, masks,
semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Randomly cropped results, 'image_shape' key in result dict is
updated according to crop size.
"""
results = self._crop_data(results, self.crop_size)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(crop_size={self.crop_size}, "
return repr_str
class RandomMasking:
def __init__(
self,
mask_ratio,
mask_patch=16,
prob=0.5,
warmup_steps=50000,
sampling="random",
curriculum=False,
):
self.mask_patch = mask_patch
self.prob = prob
self.mask_ratio = mask_ratio
self.warmup_steps = max(1, warmup_steps)
self.hard_bound = 1
self.idx = 0
self.curriculum = curriculum
self.sampling = sampling
self.low_bound = 0.0
self.up_bound = 0.0
def __call__(self, results):
B, _, H, W = results["image"].shape
device = results["image"].device
down_size = H // self.mask_patch, W // self.mask_patch
if np.random.random() > self.prob: # fill with dummy
return self._nop(results, down_size, device)
validity_mask = results["validity_mask"].float().reshape(B, -1, H, W)
validity_mask = F.interpolate(validity_mask, size=down_size).bool()
validity_mask = validity_mask.reshape(B, 1, *down_size)
is_random = self.is_warmup or results.get("guidance") is None
if not is_random:
guidance = F.interpolate(results["guidance"], size=(H, W), mode="bilinear")
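            # -max_pool2d(-x) is a min-pool: each patch is summarized by its
            # lowest (most conservative) guidance value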
results["guidance"] = -F.max_pool2d(
-guidance, kernel_size=self.mask_patch, stride=self.mask_patch
)
if is_random and self.sampling == "inverse":
sampling = self.inverse_sampling
elif is_random and self.sampling == "random":
sampling = self.random_sampling
else:
sampling = self.guided_sampling
mask_ratio = np.random.uniform(self.low_bound, self.up_bound)
for key in results.get("image_fields", ["image"]):
mask = sampling(results, mask_ratio, down_size, validity_mask, device)
results[key + "_mask"] = mask
return results
def _nop(self, results, down_size, device):
B = results["image"].shape[0]
for key in results.get("image_fields", ["image"]):
mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device)
results[key + "_mask"] = mask_blocks
return results
def random_sampling(self, results, mask_ratio, down_size, validity_mask, device):
B = results["image"].shape[0]
prob_blocks = torch.rand(size=(B, 1, *down_size), device=device)
mask_blocks = torch.logical_and(prob_blocks < mask_ratio, validity_mask)
return mask_blocks
def inverse_sampling(self, results, mask_ratio, down_size, validity_mask, device):
def area_sample(depth, fx, fy):
dtype = depth.dtype
B = depth.shape[0]
H, W = down_size
depth = downsample(depth, depth.shape[-2] // H)
depth[depth > 200] = 50 # set sky as if depth 50 meters
pixel_area3d = depth / torch.sqrt(fx * fy)
# Set invalid as -1 (no div problem) -> then clip to 0.0
pixel_area3d[depth == 0.0] = -1
prob_density = (1 / pixel_area3d).clamp(min=0.0).square()
prob_density = prob_density / prob_density.sum(
dim=(-1, -2), keepdim=True
).clamp(min=1e-5)
# Sample locations based on prob_density
prob_density_flat = prob_density.view(B, -1)
            # Count the valid locations; of those, mask a mask_ratio fraction
valid_locations = (prob_density_flat > 0).to(dtype).sum(dim=1)
masks = []
for i in range(B):
num_samples = int(valid_locations[i] * mask_ratio)
mask = torch.zeros_like(prob_density_flat[i])
# Sample indices
if num_samples > 0:
sampled_indices_flat = torch.multinomial(
prob_density_flat[i], num_samples, replacement=False
)
mask.scatter_(0, sampled_indices_flat, 1)
masks.append(mask)
return torch.stack(masks).bool().view(B, 1, H, W)
def random_sample(validity_mask):
prob_blocks = torch.rand(
size=(validity_mask.shape[0], 1, *down_size), device=device
)
mask = torch.logical_and(prob_blocks < mask_ratio, validity_mask)
return mask
fx = results["K"][..., 0, 0].view(-1, 1, 1, 1) / self.mask_patch
fy = results["K"][..., 1, 1].view(-1, 1, 1, 1) / self.mask_patch
valid = ~results["ssi"] & ~results["si"] & results["valid_camera"]
mask_blocks = torch.zeros_like(validity_mask)
if valid.any():
out = area_sample(results["depth"][valid], fx[valid], fy[valid])
mask_blocks[valid] = out
if (~valid).any():
mask_blocks[~valid] = random_sample(validity_mask[~valid])
return mask_blocks
def guided_sampling(self, results, mask_ratio, down_size, validity_mask, device):
        # Mask the patches whose guidance falls in the
        # [hard_bound - mask_ratio, hard_bound] quantile band of the valid patches
B = results["image"].shape[0]
guidance = results["guidance"]
mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device)
for b in range(B):
low_bound = torch.quantile(
guidance[b][validity_mask[b]], max(0.0, self.hard_bound - mask_ratio)
)
up_bound = torch.quantile(
guidance[b][validity_mask[b]], min(1.0, self.hard_bound)
)
mask_blocks[b] = torch.logical_and(
guidance[b] < up_bound, guidance[b] > low_bound
)
mask_blocks = torch.logical_and(mask_blocks, validity_mask)
return mask_blocks
def step(self):
self.idx += 1
        # anneal hard_bound from 1.0 down to self.mask_ratio
if self.curriculum:
step = max(0, self.idx / self.warmup_steps / 2 - 0.5)
self.hard_bound = 1 - (1 - self.mask_ratio) * tanh(step)
self.up_bound = self.mask_ratio * tanh(step)
self.low_bound = 0.1 * tanh(step)
@property
def is_warmup(self):
return self.idx < self.warmup_steps
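# Minimal usage sketch (illustrative numbers, not from the original configs):
# with curriculum=True, step() anneals the sampling bounds via tanh, so early
# steps mask almost nothing and late steps approach the target mask_ratio.
# masking = RandomMasking(mask_ratio=0.5, warmup_steps=100, curriculum=True)
# for _ in range(1000):
#     masking.step()
# # now masking.low_bound ~ 0.1 and masking.up_bound ~ 0.5 (= mask_ratio)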
class Resize:
def __init__(self, image_scale=None, image_shape=None, keep_original=False):
assert (image_scale is None) ^ (image_shape is None)
if isinstance(image_scale, (float, int)):
image_scale = (image_scale, image_scale)
if isinstance(image_shape, (float, int)):
image_shape = (int(image_shape), int(image_shape))
self.image_scale = image_scale
self.image_shape = image_shape
self.keep_original = keep_original
def _resize_img(self, results):
for key in results.get("image_fields", ["image"]):
img = TF.resize(
results[key],
results["resized_shape"],
interpolation=TF.InterpolationMode.BILINEAR,
antialias=True,
)
results[key] = img
def _resize_masks(self, results):
for key in results.get("mask_fields", []):
            mask = TF.resize(
                results[key],
                results["resized_shape"],
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
            )
results[key] = mask
def _resize_gt(self, results):
for key in results.get("gt_fields", []):
            gt = TF.resize(
                results[key],
                results["resized_shape"],
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
            )
results[key] = gt
def __call__(self, results):
h, w = results["image"].shape[-2:]
results["K_original"] = results["K"].clone()
if self.image_scale:
image_shape = (
int(h * self.image_scale[0] + 0.5),
int(w * self.image_scale[1] + 0.5),
)
image_scale = self.image_scale
elif self.image_shape:
image_scale = (self.image_shape[0] / h, self.image_shape[1] / w)
image_shape = self.image_shape
else:
            raise ValueError(
                f"In {self.__class__.__name__}: either image_scale or image_shape must be set"
            )
results["resized_shape"] = tuple(image_shape)
results["resize_factor"] = tuple(image_scale)
results["K"][..., 0, 2] = (results["K"][..., 0, 2] - 0.5) * image_scale[1] + 0.5
results["K"][..., 1, 2] = (results["K"][..., 1, 2] - 0.5) * image_scale[0] + 0.5
results["K"][..., 0, 0] = results["K"][..., 0, 0] * image_scale[1]
results["K"][..., 1, 1] = results["K"][..., 1, 1] * image_scale[0]
self._resize_img(results)
if not self.keep_original:
self._resize_masks(results)
self._resize_gt(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
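# Worked example of the intrinsics update above (illustrative numbers): scaling
# a 480x640 image by 0.5 halves fx and fy, while cx = 320.0 becomes
# (320.0 - 0.5) * 0.5 + 0.5 = 160.25, i.e. the principal point is rescaled in
# pixel-center coordinates rather than simply halved.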
class Rotate:
def __init__(
self, angle, center=None, img_fill_val=(123.68, 116.28, 103.53), prob=0.5
):
if isinstance(img_fill_val, (float, int)):
img_fill_val = tuple([float(img_fill_val)] * 3)
elif isinstance(img_fill_val, tuple):
            assert len(img_fill_val) == 3, (
                "img_fill_val as tuple must "
                f"have 3 elements, got {len(img_fill_val)}."
            )
img_fill_val = tuple([float(val) for val in img_fill_val])
else:
raise ValueError("image_fill_val must be float or tuple with 3 elements.")
        assert np.all(
            [0 <= val <= 255 for val in img_fill_val]
        ), f"all elements of img_fill_val should be in range [0, 255], got {img_fill_val}."
        assert 0 <= prob <= 1.0, f"The probability should be in range [0, 1], got {prob}."
self.center = center
self.img_fill_val = img_fill_val
self.prob = prob
self.random = not isinstance(angle, (float, int))
self.angle = angle
def _rotate(self, results, angle, center=None, fill_val=0.0):
for key in results.get("image_fields", ["image"]):
img = results[key]
img_rotated = TF.rotate(
img,
angle,
center=center,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=self.img_fill_val,
)
results[key] = img_rotated.to(img.dtype)
results["image_shape"] = results[key].shape
for key in results.get("mask_fields", []):
results[key] = TF.rotate(
results[key],
angle,
center=center,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=fill_val,
)
for key in results.get("gt_fields", []):
results[key] = TF.rotate(
results[key],
angle,
center=center,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=fill_val,
)
def __call__(self, results):
"""Call function to rotate images, bounding boxes, masks and semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Rotated results.
"""
if np.random.random() > self.prob:
return results
        angle = (
            (self.angle[1] - self.angle[0]) * np.random.rand() + self.angle[0]
            if self.random
            else float(np.random.choice([-1, 1])) * self.angle
        )
self._rotate(results, angle, None, fill_val=0.0)
results["rotation"] = angle
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(angle={self.angle}, "
repr_str += f"center={self.center}, "
repr_str += f"image_fill_val={self.img_fill_val}, "
repr_str += f"prob={self.prob}, "
return repr_str
class RandomColor:
"""Apply Color transformation to image. The bboxes, masks, and
segmentations are not modified.
Args:
level (int | float): Should be in range [0,_MAX_LEVEL].
prob (float): The probability for performing Color transformation.
"""
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_color_img(self, results, factor=1.0):
"""Apply Color transformation to image."""
for key in results.get("image_fields", ["image"]):
results[key] = TF.adjust_hue(results[key], factor) # .to(img.dtype)
def __call__(self, results):
"""Call function for Color transformation.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Colored results.
"""
if np.random.random() > self.prob:
return results
factor = (
((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else self.level
)
self._adjust_color_img(results, factor)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(level={self.level}, "
repr_str += f"prob={self.prob})"
return repr_str
class RandomSaturation:
"""Apply Color transformation to image. The bboxes, masks, and
segmentations are not modified.
Args:
level (int | float): Should be in range [0,_MAX_LEVEL].
prob (float): The probability for performing Color transformation.
"""
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_saturation_img(self, results, factor=1.0):
"""Apply Color transformation to image."""
for key in results.get("image_fields", ["image"]):
            # NOTE: torchvision's adjust_saturation assumes RGB channel order
results[key] = TF.adjust_saturation(results[key], factor) # .to(img.dtype)
def __call__(self, results):
"""Call function for Color transformation.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Colored results.
"""
if np.random.random() > self.prob:
return results
factor = (
2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else 2**self.level
)
self._adjust_saturation_img(results, factor)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(level={self.level}, "
repr_str += f"prob={self.prob})"
return repr_str
class RandomSharpness:
"""Apply Color transformation to image. The bboxes, masks, and
segmentations are not modified.
Args:
level (int | float): Should be in range [0,_MAX_LEVEL].
prob (float): The probability for performing Color transformation.
"""
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
    def _adjust_sharpness_img(self, results, factor=1.0):
"""Apply Color transformation to image."""
for key in results.get("image_fields", ["image"]):
results[key] = TF.adjust_sharpness(results[key], factor) # .to(img.dtype)
def __call__(self, results):
"""Call function for Color transformation.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Colored results.
"""
if np.random.random() > self.prob:
return results
factor = (
2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else 2**self.level
)
        self._adjust_sharpness_img(results, factor)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(level={self.level}, "
repr_str += f"prob={self.prob})"
return repr_str
class RandomSolarize:
"""Apply Color transformation to image. The bboxes, masks, and
segmentations are not modified.
Args:
level (int | float): Should be in range [0,_MAX_LEVEL].
prob (float): The probability for performing Color transformation.
"""
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_solarize_img(self, results, factor=255.0):
"""Apply Color transformation to image."""
for key in results.get("image_fields", ["image"]):
results[key] = TF.solarize(results[key], factor) # .to(img.dtype)
def __call__(self, results):
"""Call function for Color transformation.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Colored results.
"""
if np.random.random() > self.prob:
return results
factor = (
((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else self.level
)
self._adjust_solarize_img(results, factor)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(level={self.level}, "
repr_str += f"prob={self.prob})"
return repr_str
class RandomPosterize:
"""Apply Color transformation to image. The bboxes, masks, and
segmentations are not modified.
Args:
level (int | float): Should be in range [0,_MAX_LEVEL].
prob (float): The probability for performing Color transformation.
"""
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _posterize_img(self, results, factor=1.0):
"""Apply Color transformation to image."""
for key in results.get("image_fields", ["image"]):
results[key] = TF.posterize(results[key], int(factor)) # .to(img.dtype)
def __call__(self, results):
"""Call function for Color transformation.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Colored results.
"""
if np.random.random() > self.prob:
return results
factor = (
((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else self.level
)
self._posterize_img(results, factor)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(level={self.level}, "
repr_str += f"prob={self.prob})"
return repr_str
class RandomEqualize:
"""Apply Equalize transformation to image. The bboxes, masks and
segmentations are not modified.
Args:
prob (float): The probability for performing Equalize transformation.
"""
def __init__(self, prob=0.5):
assert 0 <= prob <= 1.0, "The probability should be in range [0,1]."
self.prob = prob
def _imequalize(self, results):
"""Equalizes the histogram of one image."""
for key in results.get("image_fields", ["image"]):
results[key] = TF.equalize(results[key]) # .to(img.dtype)
def __call__(self, results):
"""Call function for Equalize transformation.
Args:
results (dict): Results dict from loading pipeline.
Returns:
dict: Results after the transformation.
"""
if np.random.random() > self.prob:
return results
self._imequalize(results)
return results
    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(prob={self.prob})"
        return repr_str
class RandomBrightness:
"""Apply Brightness transformation to image. The bboxes, masks and
segmentations are not modified.
Args:
level (int | float): Should be in range [0,_MAX_LEVEL].
prob (float): The probability for performing Brightness transformation.
"""
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_brightness_img(self, results, factor=1.0):
"""Adjust the brightness of image."""
for key in results.get("image_fields", ["image"]):
results[key] = TF.adjust_brightness(results[key], factor) # .to(img.dtype)
def __call__(self, results, level=None):
"""Call function for Brightness transformation.
Args:
results (dict): Results dict from loading pipeline.
Returns:
dict: Results after the transformation.
"""
if np.random.random() > self.prob:
return results
factor = (
2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else 2**self.level
)
self._adjust_brightness_img(results, factor)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(level={self.level}, "
repr_str += f"prob={self.prob})"
return repr_str
class RandomContrast:
"""Apply Contrast transformation to image. The bboxes, masks and
segmentations are not modified.
Args:
level (int | float): Should be in range [0,_MAX_LEVEL].
prob (float): The probability for performing Contrast transformation.
"""
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def _adjust_contrast_img(self, results, factor=1.0):
"""Adjust the image contrast."""
for key in results.get("image_fields", ["image"]):
results[key] = TF.adjust_contrast(results[key], factor) # .to(img.dtype)
def __call__(self, results, level=None):
"""Call function for Contrast transformation.
Args:
results (dict): Results dict from loading pipeline.
Returns:
dict: Results after the transformation.
"""
if np.random.random() > self.prob:
return results
factor = (
2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
if self.random
else 2**self.level
)
self._adjust_contrast_img(results, factor)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(level={self.level}, "
repr_str += f"prob={self.prob})"
return repr_str
class RandomGamma:
def __init__(self, level, prob=0.5):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
def __call__(self, results, level=None):
"""Call function for Contrast transformation.
Args:
results (dict): Results dict from loading pipeline.
Returns:
dict: Results after the transformation.
"""
if np.random.random() > self.prob:
return results
factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0]
for key in results.get("image_fields", ["image"]):
if "original" not in key:
results[key] = TF.adjust_gamma(results[key], 1 + factor)
return results
class RandomInvert:
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, results):
if np.random.random() > self.prob:
return results
for key in results.get("image_fields", ["image"]):
if "original" not in key:
results[key] = TF.invert(results[key]) # .to(img.dtype)
return results
class RandomAutoContrast:
def __init__(self, prob=0.5):
self.prob = prob
def _autocontrast_img(self, results):
for key in results.get("image_fields", ["image"]):
img = results[key]
results[key] = TF.autocontrast(img) # .to(img.dtype)
def __call__(self, results):
if np.random.random() > self.prob:
return results
self._autocontrast_img(results)
return results
    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(prob={self.prob})"
        return repr_str
class Dilation:
def __init__(self, origin, kernel, border_value=-1.0, iterations=1) -> None:
self.structured_element = torch.ones(size=kernel)
self.origin = origin
self.border_value = border_value
self.iterations = iterations
def dilate(self, image):
image_pad = F.pad(
image,
[
self.origin[0],
self.structured_element.shape[0] - self.origin[0] - 1,
self.origin[1],
self.structured_element.shape[1] - self.origin[1] - 1,
],
mode="constant",
value=self.border_value,
)
if image_pad.ndim < 4:
image_pad = image_pad.unsqueeze(0)
        # Unfold the image to operate on neighborhoods
        image_unfold = F.unfold(image_pad, kernel_size=self.structured_element.shape)
        # A standard greyscale dilation takes the max over each neighborhood
        # (the min for erosion). For depth we want the closest point
        # (perspective occlusion), hence the min; but zeros mean "unknown",
        # so they are replaced with a large value before taking the min.
        mask = image_unfold < 1e-3  # near-zero threshold, more robust than == 0
        # Replace the zero elements with a large value, so they don't affect the minimum
        image_unfold = image_unfold.masked_fill(mask, 1000.0)
# Calculate the minimum along the neighborhood axis
dilate_image = torch.min(image_unfold, dim=1).values
# Fill the masked values with 0 to propagate zero if all pixels are zero
dilate_image[mask.all(dim=1)] = 0
return torch.reshape(dilate_image, image.shape)
def __call__(self, results):
for key in results.get("gt_fields", []):
gt = results[key]
for _ in range(self.iterations):
gt[gt < 1e-4] = self.dilate(gt)[gt < 1e-4]
results[key] = gt
return results
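# Minimal usage sketch (illustrative arguments): a depth-aware "dilation" that
# fills holes (near-zero pixels) in the ground truth with the minimum, i.e.
# closest, valid depth in a 3x3 neighborhood.
# dilation = Dilation(origin=(1, 1), kernel=(3, 3), iterations=1)
# results = dilation(results)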
class RandomShear(object):
def __init__(
self,
level,
prob=0.5,
direction="horizontal",
):
self.random = not isinstance(level, (float, int))
self.level = level
self.prob = prob
self.direction = direction
def _shear_img(self, results, magnitude):
for key in results.get("image_fields", ["image"]):
img_sheared = TF.affine(
results[key],
angle=0.0,
translate=[0.0, 0.0],
scale=1.0,
shear=magnitude,
interpolation=TF.InterpolationMode.BILINEAR,
fill=0.0,
)
results[key] = img_sheared
def _shear_masks(self, results, magnitude):
for key in results.get("mask_fields", []):
mask_sheared = TF.affine(
results[key],
angle=0.0,
translate=[0.0, 0.0],
scale=1.0,
shear=magnitude,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=0.0,
)
results[key] = mask_sheared
def _shear_gt(
self,
results,
magnitude,
):
for key in results.get("gt_fields", []):
mask_sheared = TF.affine(
results[key],
angle=0.0,
translate=[0.0, 0.0],
scale=1.0,
shear=magnitude,
interpolation=TF.InterpolationMode.NEAREST_EXACT,
fill=0.0,
)
results[key] = mask_sheared
def __call__(self, results):
if np.random.random() > self.prob:
return results
        magnitude = (
            ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
            if self.random
            else float(np.random.choice([-1, 1])) * self.level
        )
if self.direction == "horizontal":
magnitude = [magnitude, 0.0]
else:
magnitude = [0.0, magnitude]
self._shear_img(results, magnitude)
self._shear_masks(results, magnitude)
self._shear_gt(results, magnitude)
return results
    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob}, "
        repr_str += f"direction={self.direction})"
        return repr_str
class RandomTranslate(object):
def __init__(
self,
range,
prob=0.5,
direction="horizontal",
):
self.range = range
self.prob = prob
self.direction = direction
def _translate_img(self, results, magnitude):
for key in results.get("image_fields", ["image"]):
img_sheared = TF.affine(
results[key],
angle=0.0,
translate=magnitude,
scale=1.0,
shear=[0.0, 0.0],
interpolation=TF.InterpolationMode.BILINEAR,
fill=(123.68, 116.28, 103.53),
)
results[key] = img_sheared
    def _translate_mask(self, results, magnitude):
        for key in results.get("mask_fields", []):
            mask_translated = TF.affine(
                results[key],
                angle=0.0,
                translate=magnitude,
                scale=1.0,
                shear=[0.0, 0.0],
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=0.0,
            )
            results[key] = mask_translated
    def _translate_gt(
        self,
        results,
        magnitude,
    ):
        for key in results.get("gt_fields", []):
            gt_translated = TF.affine(
                results[key],
                angle=0.0,
                translate=magnitude,
                scale=1.0,
                shear=[0.0, 0.0],
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=0.0,
            )
            results[key] = gt_translated
def __call__(self, results):
if np.random.random() > self.prob:
return results
magnitude = (self.range[1] - self.range[0]) * np.random.rand() + self.range[0]
        if self.direction == "horizontal":
            magnitude = [magnitude * results["image"].shape[-1], 0]
        else:
            magnitude = [0, magnitude * results["image"].shape[-2]]
self._translate_img(results, magnitude)
self._translate_mask(results, magnitude)
self._translate_gt(results, magnitude)
results["K"][..., 0, 2] = results["K"][..., 0, 2] + magnitude[0]
results["K"][..., 1, 2] = results["K"][..., 1, 2] + magnitude[1]
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f"(range={self.range}, "
repr_str += f"prob={self.prob}, "
repr_str += f"direction={self.direction}, "
return repr_str
class RandomCut(object):
def __init__(self, prob=0.5, direction="all"):
self.direction = direction
self.prob = prob
    def _cut_img(self, results, coord, dim):
        for key in results.get("image_fields", ["image"]):
            img_rolled = torch.roll(
                results[key], int(coord * results[key].shape[dim]), dims=dim
            )
            results[key] = img_rolled
    def _cut_mask(self, results, coord, dim):
        for key in results.get("mask_fields", []):
            mask_rolled = torch.roll(
                results[key], int(coord * results[key].shape[dim]), dims=dim
            )
            results[key] = mask_rolled
    def _cut_gt(self, results, coord, dim):
        for key in results.get("gt_fields", []):
            gt_rolled = torch.roll(
                results[key], int(coord * results[key].shape[dim]), dims=dim
            )
            results[key] = gt_rolled
def __call__(self, results):
if np.random.random() > self.prob:
return results
coord = 0.8 * random.random() + 0.1
if self.direction == "horizontal":
dim = -1
elif self.direction == "vertical":
dim = -2
else:
dim = -1 if random.random() < 0.5 else -2
self._cut_img(results, coord, dim)
self._cut_mask(results, coord, dim)
self._cut_gt(results, coord, dim)
return results
class DownsamplerGT(object):
def __init__(self, downsample_factor: int, min_depth: float = 0.01):
assert downsample_factor == round(
downsample_factor, 0
), f"Downsample factor needs to be an integer, got {downsample_factor}"
self.downsample_factor = downsample_factor
self.min_depth = min_depth
def _downsample_gt(self, results):
for key in deepcopy(results.get("gt_fields", [])):
gt = results[key]
N, H, W = gt.shape
gt = gt.view(
N,
H // self.downsample_factor,
self.downsample_factor,
W // self.downsample_factor,
self.downsample_factor,
1,
)
gt = gt.permute(0, 1, 3, 5, 2, 4)
gt = gt.view(-1, self.downsample_factor * self.downsample_factor)
gt_tmp = torch.where(gt == 0.0, 1e5 * torch.ones_like(gt), gt)
gt = torch.min(gt_tmp, dim=-1).values
gt = gt.view(N, H // self.downsample_factor, W // self.downsample_factor)
gt = torch.where(gt > 1000, torch.zeros_like(gt), gt)
results[f"{key}_downsample"] = gt
results["gt_fields"].append(f"{key}_downsample")
results["downsampled"] = True
return results
def __call__(self, results):
results = self._downsample_gt(results)
return results
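# Minimal usage sketch (hypothetical shapes): a min-pool that treats zeros as
# invalid, so the closest observed depth wins within each block.
# >>> ds = DownsamplerGT(downsample_factor=4)
# >>> results = {"gt_fields": ["depth"], "depth": torch.rand(1, 64, 64)}
# >>> ds(results)["depth_downsample"].shape
# torch.Size([1, 16, 16])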
class RandomColorJitter:
def __init__(self, level, prob=0.9):
self.level = level
self.prob = prob
self.list_transform = [
self._adjust_brightness_img,
# self._adjust_sharpness_img,
self._adjust_contrast_img,
self._adjust_saturation_img,
self._adjust_color_img,
]
def _adjust_contrast_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_contrast(img, factor)
def _adjust_sharpness_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_sharpness(img, factor)
def _adjust_brightness_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_brightness(img, factor)
def _adjust_saturation_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_saturation(img, factor / 2.0)
def _adjust_color_img(self, results, factor=1.0):
for key in results.get("image_fields", ["image"]):
if "original" not in key:
img = results[key]
results[key] = TF.adjust_hue(img, (factor - 1.0) / 4.0)
def __call__(self, results):
random.shuffle(self.list_transform)
for op in self.list_transform:
if np.random.random() < self.prob:
factor = 1.0 + (
(self.level[1] - self.level[0]) * np.random.random() + self.level[0]
)
op(results, factor)
return results
class RandomGrayscale:
def __init__(self, prob=0.1, num_output_channels=3):
super().__init__()
self.prob = prob
self.num_output_channels = num_output_channels
def __call__(self, results):
if np.random.random() > self.prob:
return results
for key in results.get("image_fields", ["image"]):
if "original" not in key:
results[key] = TF.rgb_to_grayscale(
results[key], num_output_channels=self.num_output_channels
)
return results
class ContextCrop(Resize):
def __init__(
self,
image_shape,
keep_original=False,
test_min_ctx=1.0,
train_ctx_range=[0.5, 1.5],
shape_constraints={},
):
super().__init__(image_shape=image_shape, keep_original=keep_original)
self.test_min_ctx = test_min_ctx
self.train_ctx_range = train_ctx_range
self.shape_mult = shape_constraints["shape_mult"]
self.sample = shape_constraints["sample"]
self.ratio_bounds = shape_constraints["ratio_bounds"]
pixels_min = shape_constraints["pixels_min"] / (
self.shape_mult * self.shape_mult
)
pixels_max = shape_constraints["pixels_max"] / (
self.shape_mult * self.shape_mult
)
self.pixels_bounds = (pixels_min, pixels_max)
self.keepGT = int(os.environ.get("keepGT", 0))
self.ctx = None
def _transform_img(self, results, shapes):
for key in results.get("image_fields", ["image"]):
img = self.crop(results[key], **shapes)
img = TF.resize(
img,
results["resized_shape"],
interpolation=TF.InterpolationMode.BICUBIC,
antialias=True,
)
results[key] = img
def _transform_masks(self, results, shapes):
for key in results.get("mask_fields", []):
mask = self.crop(results[key].float(), **shapes).byte()
if "flow" in key: # take pad/crop into flow resize
mask = TF.resize(
mask,
results["resized_shape"],
interpolation=TF.InterpolationMode.NEAREST_EXACT,
antialias=False,
)
else:
mask = masked_nearest_interpolation(
mask, mask > 0, results["resized_shape"]
)
results[key] = mask
def _transform_gt(self, results, shapes):
for key in results.get("gt_fields", []):
gt = self.crop(results[key], **shapes)
if not self.keepGT:
if "flow" in key: # take pad/crop into flow resize
gt = self._rescale_flow(gt, results)
gt = TF.resize(
gt,
results["resized_shape"],
interpolation=TF.InterpolationMode.NEAREST_EXACT,
antialias=False,
)
else:
gt = masked_nearest_interpolation(
gt, gt > 0, results["resized_shape"]
)
results[key] = gt
def _rescale_flow(self, gt, results):
h_new, w_new = gt.shape[-2:]
h_old, w_old = results["image_ori_shape"]
gt[:, 0] = gt[:, 0] * (w_old - 1) / (w_new - 1)
gt[:, 1] = gt[:, 1] * (h_old - 1) / (h_new - 1)
return gt
@staticmethod
def crop(img, height, width, top, left) -> torch.Tensor:
h, w = img.shape[-2:]
right = left + width
bottom = top + height
padding_ltrb = [
max(-left + min(0, right), 0),
max(-top + min(0, bottom), 0),
max(right - max(w, left), 0),
max(bottom - max(h, top), 0),
]
image_cropped = img[..., max(top, 0) : bottom, max(left, 0) : right]
return TF.pad(image_cropped, padding_ltrb)
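    # Note: unlike TF.crop, this handles virtual crops larger than the image:
    # negative top/left or out-of-range bottom/right are turned into zero
    # padding, e.g. crop(img, height=H + 8, width=W, top=-4, left=0) pads
    # 4 rows at the top and 4 at the bottom (illustrative values).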
def test_closest_shape(self, image_shape):
h, w = image_shape
input_ratio = w / h
if self.sample:
input_pixels = int(ceil(h / self.shape_mult * w / self.shape_mult))
pixels = max(
min(input_pixels, self.pixels_bounds[1]), self.pixels_bounds[0]
)
ratio = min(max(input_ratio, self.ratio_bounds[0]), self.ratio_bounds[1])
h = round((pixels / ratio) ** 0.5)
w = h * ratio
self.image_shape[0] = int(h) * self.shape_mult
self.image_shape[1] = int(w) * self.shape_mult
def _get_crop_shapes(self, image_shape, ctx=None):
h, w = image_shape
input_ratio = w / h
if self.keep_original:
self.test_closest_shape(image_shape)
ctx = 1.0
elif ctx is None:
ctx = float(
torch.empty(1)
.uniform_(self.train_ctx_range[0], self.train_ctx_range[1])
.item()
)
output_ratio = self.image_shape[1] / self.image_shape[0]
        if output_ratio <= input_ratio:  # e.g. output 4:3, input wider (KITTI-like)
if (
ctx >= 1
): # fully in -> use just max_length with sqrt(ctx), here max is width
new_w = w * ctx**0.5
            # partial overshoot: the crop sticks out in one dimension only;
            # in_width will stick out before in_height does
# new_h > old_h via area -> new_h ** 2 * ratio_new = old_h ** 2 * ratio_old * ctx
elif output_ratio / input_ratio * ctx > 1:
new_w = w * ctx
else: # fully contained -> use area
new_w = w * (ctx * output_ratio / input_ratio) ** 0.5
new_h = new_w / output_ratio
else:
if ctx >= 1:
new_h = h * ctx**0.5
elif input_ratio / output_ratio * ctx > 1:
new_h = h * ctx
else:
new_h = h * (ctx * input_ratio / output_ratio) ** 0.5
new_w = new_h * output_ratio
return (int(ceil(new_h - 0.5)), int(ceil(new_w - 0.5))), ctx
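    # Worked example of the branch above (illustrative numbers): for a
    # 370x1226 KITTI-like input, output shape 480x640 (output ratio 4:3 is
    # below the input ratio ~3.3) and ctx = 1.2 >= 1 gives
    # new_w = 1226 * 1.2**0.5 ~ 1343 and new_h = 1343 / (4/3) ~ 1007, i.e. the
    # virtual crop exceeds the image and the borders are padded.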
# def sample_view(self, results):
# original_K = results["K"]
# original_image = results["image"]
# original_depth = results["depth"]
# original_validity_mask = results["validity_mask"].float()
# # sample angles and translation
# # sample translation:
# # 10 max of z
# x = np.random.normal(0, 0.05 / 2) * original_depth.max()
# y = np.random.normal(0, 0.05)
# z = np.random.normal(0, 0.05) * original_depth.max()
# fov = 2 * np.arctan(original_image.shape[-2] / 2 / results["K"][0, 0, 0])
# phi = np.random.normal(0, fov / 10)
# theta = np.random.normal(0, fov / 10)
# psi = np.random.normal(0, np.pi / 60)
# translation = torch.tensor([x, y, z]).unsqueeze(0)
# angles = torch.tensor([phi, theta, psi])
# angles = euler_to_rotation_matrix(angles)
# translation = translation @ angles # translation before rotation
# cam2w = torch.eye(4).unsqueeze(0)
# cam2w[..., :3, :3] = angles
# cam2w[..., :3, 3] = translation
# cam2cam = torch.inverse(cam2w)
# image_warped, depth_warped = forward_warping(original_image, original_depth, original_K, original_K, cam2cam=cam2cam)
# depth_warped[depth_warped > 0] = depth_warped[depth_warped > 0] - z
# validity_mask_warped = image_warped.sum(dim=1, keepdim=True) > 0.0
# results["K"] = results["K"].repeat(2, 1, 1)
# results["cam2w"] = torch.cat([torch.eye(4).unsqueeze(0), cam2w])
# results["image"] = torch.cat([original_image, image_warped])
# results["depth"] = torch.cat([original_depth, depth_warped])
# results["validity_mask"] = torch.cat([original_validity_mask, validity_mask_warped], dim=0)
# # results["cam2w"] = torch.cat([torch.eye(4).unsqueeze(0), torch.eye(4).unsqueeze(0)])
# # results["image"] = torch.cat([original_image, original_image])
# # results["depth"] = torch.cat([original_depth, original_depth])
# # results["validity_mask"] = torch.cat([original_validity_mask, original_validity_mask], dim=0)
# return results
def __call__(self, results):
h, w = results["image"].shape[-2:]
results["image_ori_shape"] = (h, w)
results["camera_fields"].add("camera_original")
results["camera_original"] = results["camera"].clone()
results.get("mask_fields", set()).add("validity_mask")
if "validity_mask" not in results:
results["validity_mask"] = torch.ones(
(results["image"].shape[0], 1, h, w),
dtype=torch.uint8,
device=results["image"].device,
)
n_iter = 1 if self.keep_original or not self.sample else 100
min_valid_area = 0.5
max_hfov, max_vfov = results["camera"].max_fov[0] # it is a 1-dim list
ctx = None
for ii in range(n_iter):
(height, width), ctx = self._get_crop_shapes((h, w), ctx=self.ctx or ctx)
margin_h = h - height
margin_w = w - width
# keep it centered in y direction
top = margin_h // 2
left = margin_w // 2
if not self.keep_original:
left = left + np.random.randint(
-self.shape_mult // 2, self.shape_mult // 2 + 1
)
top = top + np.random.randint(
-self.shape_mult // 2, self.shape_mult // 2 + 1
)
right = left + width
bottom = top + height
x_zoom = self.image_shape[0] / height
paddings = [
max(-left + min(0, right), 0),
max(bottom - max(h, top), 0),
max(right - max(w, left), 0),
max(-top + min(0, bottom), 0),
]
valid_area = (
h
* w
/ (h + paddings[1] + paddings[3])
/ (w + paddings[0] + paddings[2])
)
new_hfov, new_vfov = results["camera_original"].get_new_fov(
new_shape=(height, width), original_shape=(h, w)
)[0]
if (
valid_area >= min_valid_area
and new_hfov < max_hfov
and new_vfov < max_vfov
):
break
            # not enough valid area or FoV too wide: retry with less ctx (more zoom)
            ctx = ctx * 0.96
# save ctx for next iteration of sequences?
self.ctx = ctx
results["resized_shape"] = self.image_shape
results["paddings"] = paddings # left ,top ,right, bottom
results["image_rescale"] = x_zoom
results["scale_factor"] = results.get("scale_factor", 1.0) * x_zoom
results["camera"] = results["camera"].crop(
left, top, right=w - right, bottom=h - bottom
)
results["camera"] = results["camera"].resize(x_zoom)
# print("XAM", results["camera"].params.squeeze(), results["camera"][0].params.squeeze(), results["camera_original"].params.squeeze(), results["camera_original"][0].params.squeeze())
shapes = dict(height=height, width=width, top=top, left=left)
self._transform_img(results, shapes)
if not self.keep_original:
self._transform_gt(results, shapes)
self._transform_masks(results, shapes)
else:
            # only validity_mask (the rgb masks follow the rgb transform)  # FIXME
mask = results["validity_mask"].float()
mask = self.crop(mask, **shapes).byte()
mask = TF.resize(
mask,
results["resized_shape"],
interpolation=TF.InterpolationMode.NEAREST,
)
results["validity_mask"] = mask
# keep original images before photo-augment
results["image_original"] = results["image"].clone()
results["image_fields"].add(
*[
field.replace("image", "image_original")
for field in results["image_fields"]
]
)
# repeat for batch resized shape and paddings
results["paddings"] = [results["paddings"]] * results["image"].shape[0]
results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[
0
]
return results
class RandomFiller:
def __init__(self, test_mode, *args, **kwargs):
super().__init__()
self.test_mode = test_mode
def _transform(self, results):
def fill_noise(size, device):
return torch.normal(0, 2.0, size=size, device=device)
def fill_black(size, device):
return -4 * torch.ones(size, device=device, dtype=torch.float32)
def fill_white(size, device):
return 4 * torch.ones(size, device=device, dtype=torch.float32)
def fill_zero(size, device):
return torch.zeros(size, device=device, dtype=torch.float32)
B, C = results["image"].shape[:2]
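        # If images were repeated/batched after the mask was built, repeat the
        # mask to match; mismatch == 1 makes the repeat a no-op.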
mismatch = B // results["validity_mask"].shape[0]
if mismatch:
results["validity_mask"] = results["validity_mask"].repeat(
mismatch, 1, 1, 1
)
validity_mask = results["validity_mask"].repeat(1, C, 1, 1).bool()
filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero])
if self.test_mode:
filler_fn = fill_zero
for key in results.get("image_fields", ["image"]):
results[key][~validity_mask] = filler_fn(
size=results[key][~validity_mask].shape, device=results[key].device
)
def __call__(self, results):
# generate mask for filler
if "validity_mask" not in results:
paddings = results.get("padding_size", [0] * 4)
height, width = results["image"].shape[-2:]
results.get("mask_fields", set()).add("validity_mask")
results["validity_mask"] = torch.zeros_like(results["image"][:, :1])
results["validity_mask"][
...,
paddings[1] : height - paddings[3],
paddings[0] : width - paddings[2],
] = 1.0
self._transform(results)
return results
class GaussianBlur:
def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9):
super().__init__()
self.kernel_size = kernel_size
self.sigma = sigma
self.prob = prob
self.padding = kernel_size // 2
def apply(self, x, kernel):
# Pad the input tensor
x = F.pad(
x, (self.padding, self.padding, self.padding, self.padding), mode="reflect"
)
# Apply the convolution with the Gaussian kernel
return F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1))
def _create_kernel(self, sigma):
# Create a 1D Gaussian kernel
kernel_1d = torch.exp(
-torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2)
)
kernel_1d = kernel_1d / kernel_1d.sum()
# Expand the kernel to 2D and match size of the input
kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1)
kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand(
3, 1, -1, -1
)
return kernel_2d
def __call__(self, results):
if np.random.random() > self.prob:
return results
sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0]
kernel = self._create_kernel(sigma)
for key in results.get("image_fields", ["image"]):
if "original" not in key:
results[key] = self.apply(results[key], kernel)
return results
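# Minimal usage sketch (hypothetical values): the kernel is the outer product
# of a 1D Gaussian, expanded to one kernel per RGB channel and applied as a
# grouped (depthwise) convolution.
# >>> blur = GaussianBlur(kernel_size=5, sigma=(0.1, 2.0), prob=1.0)
# >>> out = blur({"image_fields": {"image"}, "image": torch.rand(1, 3, 32, 32)})
# >>> out["image"].shape
# torch.Size([1, 3, 32, 32])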
class MotionBlur:
    def __init__(self, kernel_size=9, angles=(-180, 180), prob=0.1):
        super().__init__()
        # kernel_size must be an odd int: it is used for arange and padding below
        self.kernel_size = kernel_size
        self.angles = angles
        self.prob = prob
        self.padding = kernel_size // 2
    def _create_kernel(self, angle):
        # Generate a 2D grid of coordinates
        grid = torch.meshgrid(
            torch.arange(self.kernel_size),
            torch.arange(self.kernel_size),
            indexing="ij",
        )
        grid = torch.stack(grid).float()  # Shape: (2, kernel_size, kernel_size)
        # Calculate relative coordinates from the center
        center = (self.kernel_size - 1) / 2.0
        x_offset = grid[1] - center
        y_offset = grid[0] - center
        # Compute motion blur kernel: weight decays with distance from the blur
        # line through the center; negative weights are clamped away and the
        # kernel is normalized to sum to 1
        angle = torch.as_tensor(angle, dtype=torch.float32)
        cos_theta = torch.cos(angle * torch.pi / 180.0)
        sin_theta = torch.sin(angle * torch.pi / 180.0)
        kernel = (
            1.0 - torch.abs(x_offset * cos_theta + y_offset * sin_theta)
        ).clamp(min=0.0)
        kernel = kernel / kernel.sum().clamp(min=1e-8)
        # Expand kernel dimensions to match input image channels
        kernel = kernel.unsqueeze(0).unsqueeze(0).expand(3, 1, -1, -1)
        return kernel
    def apply(self, image, kernel):
        x = F.pad(
            image,
            (self.padding, self.padding, self.padding, self.padding),
            mode="reflect",
        )
        # Apply convolution with the motion blur kernel
        blurred_image = F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1))
        return blurred_image
def __call__(self, results):
if np.random.random() > self.prob:
return results
angle = np.random.uniform(self.angles[0], self.angles[1])
kernel = self._create_kernel(angle)
for key in results.get("image_fields", ["image"]):
if "original" in key:
continue
results[key] = self.apply(results[key], kernel)
return results
class JPEGCompression:
def __init__(self, level=(10, 70), prob=0.1):
super().__init__()
self.level = level
self.prob = prob
def __call__(self, results):
if np.random.random() > self.prob:
return results
        level = int(np.random.uniform(self.level[0], self.level[1]))  # TF.jpeg expects an int quality
for key in results.get("image_fields", ["image"]):
if "original" in key:
continue
results[key] = TF.jpeg(results[key], level)
return results
class Compose:
def __init__(self, transforms):
self.transforms = deepcopy(transforms)
def __call__(self, results):
for t in self.transforms:
results = t(results)
return results
    def __setattr__(self, name: str, value) -> None:
        # propagate attribute assignments to the wrapped transforms, but not
        # the transforms list itself (set during __init__)
        super().__setattr__(name, value)
        if name == "transforms":
            return
        for t in self.transforms:
            setattr(t, name, value)
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += f"\n {t}"
format_string += "\n)"
return format_string
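# Minimal usage sketch (hypothetical pipeline): transforms run in order, and
# attributes set on the Compose object are propagated to every transform via
# __setattr__ above (e.g. a shared test_mode flag).
# pipeline = Compose([RandomColorJitter(level=(-0.4, 0.4)), GaussianBlur(5)])
# results = pipeline(results)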
class DummyCrop(Resize):
def __init__(
self,
*args,
**kwargs,
):
# dummy image shape, not really used
super().__init__(image_shape=(512, 512))
def __call__(self, results):
h, w = results["image"].shape[-2:]
results["image_ori_shape"] = (h, w)
results["camera_fields"].add("camera_original")
results["camera_original"] = results["camera"].clone()
results.get("mask_fields", set()).add("validity_mask")
if "validity_mask" not in results:
results["validity_mask"] = torch.ones(
(results["image"].shape[0], 1, h, w),
dtype=torch.uint8,
device=results["image"].device,
)
self.ctx = 1.0
results["resized_shape"] = self.image_shape
results["paddings"] = [0, 0, 0, 0]
results["image_rescale"] = 1.0
results["scale_factor"] = results.get("scale_factor", 1.0) * 1.0
results["camera"] = results["camera"].crop(0, 0, right=w, bottom=h)
results["camera"] = results["camera"].resize(1)
# keep original images before photo-augment
results["image_original"] = results["image"].clone()
results["image_fields"].add(
*[
field.replace("image", "image_original")
for field in results["image_fields"]
]
)
# repeat for batch resized shape and paddings
results["paddings"] = [results["paddings"]] * results["image"].shape[0]
results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[
0
]
return results