import os
import random
from copy import deepcopy
from math import ceil, exp, log, log2, log10, tanh
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.v2.functional as TF

from unik3d.utils.geometric import downsample


def euler_to_rotation_matrix(angles):
    """
    Convert Euler angles to a 3x3 rotation matrix.

    Args:
        angles (torch.Tensor): Euler angles [roll, pitch, yaw].

    Returns:
        torch.Tensor: 3x3 rotation matrix.
    """
    phi, theta, psi = angles
    cos_phi, sin_phi = torch.cos(phi), torch.sin(phi)
    cos_theta, sin_theta = torch.cos(theta), torch.sin(theta)
    cos_psi, sin_psi = torch.cos(psi), torch.sin(psi)
    # Per-axis rotation matrices
    Rx = torch.tensor([[1, 0, 0], [0, cos_phi, -sin_phi], [0, sin_phi, cos_phi]])
    Ry = torch.tensor(
        [[cos_theta, 0, sin_theta], [0, 1, 0], [-sin_theta, 0, cos_theta]]
    )
    Rz = torch.tensor([[cos_psi, -sin_psi, 0], [sin_psi, cos_psi, 0], [0, 0, 1]])
    return Rz @ Ry @ Rx
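

# Minimal usage sketch (illustrative, not part of the pipeline): a proper
# rotation matrix is orthonormal with determinant 1.
def _example_euler_to_rotation_matrix():
    angles = torch.tensor([0.1, -0.2, 0.3])  # roll, pitch, yaw in radians
    R = euler_to_rotation_matrix(angles)
    assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5)
    assert torch.isclose(torch.det(R), torch.tensor(1.0), atol=1e-5)
    return R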


def compute_grid(H, W):
    meshgrid = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="xy")
    id_coords = torch.stack(meshgrid, dim=0).to(torch.float32)
    id_coords = id_coords.reshape(2, -1)
    id_coords = torch.cat(
        [id_coords, torch.ones(1, id_coords.shape[-1])], dim=0
    )  # (3, H*W) homogeneous pixel coordinates
    id_coords = id_coords.unsqueeze(0)
    return id_coords
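

# Quick shape check (illustrative): for a 2x3 image the grid holds one
# homogeneous (x, y, 1) column per pixel.
def _example_compute_grid():
    grid = compute_grid(2, 3)  # H=2, W=3
    assert grid.shape == (1, 3, 6)  # (batch, xy1, H*W)
    return grid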


def lexsort(keys):
    # Stable lexicographic argsort: keys[0] is the primary key, later keys
    # break ties (the loop applies the least significant key first and
    # composes the permutations so indices stay relative to the input).
    sorted_indices = torch.arange(keys[0].size(0), device=keys[0].device)
    for key in reversed(keys):
        sorted_indices = sorted_indices[key[sorted_indices].argsort(stable=True)]
    return sorted_indices
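

# Sketch of the intended semantics: rows sorted by k1, ties broken by k2.
def _example_lexsort():
    k1 = torch.tensor([1, 0, 1, 0])
    k2 = torch.tensor([3, 1, 0, 2])
    order = lexsort([k1, k2])
    assert torch.equal(k1[order], torch.tensor([0, 0, 1, 1]))
    assert torch.equal(k2[order], torch.tensor([1, 2, 0, 3]))
    return order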


def masked_bilinear_interpolation(input, mask, target_size):
    B, C, H, W = input.shape
    target_H, target_W = target_size
    mask = mask.float()

    # Generate a grid of coordinates in the target space
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(0, H - 1, target_H),
        torch.linspace(0, W - 1, target_W),
        indexing="ij",
    )
    grid_y = grid_y.to(input.device)
    grid_x = grid_x.to(input.device)

    # Calculate the floor and ceil of the grid coordinates to get the bounding box
    x0 = torch.floor(grid_x).long().clamp(0, W - 1)
    x1 = (x0 + 1).clamp(0, W - 1)
    y0 = torch.floor(grid_y).long().clamp(0, H - 1)
    y1 = (y0 + 1).clamp(0, H - 1)

    # Gather depth values at the four corners
    Ia = input[..., y0, x0]
    Ib = input[..., y1, x0]
    Ic = input[..., y0, x1]
    Id = input[..., y1, x1]

    # Gather corresponding mask values
    ma = mask[..., y0, x0]
    mb = mask[..., y1, x0]
    mc = mask[..., y0, x1]
    md = mask[..., y1, x1]

    # Calculate the areas (weights) for bilinear interpolation
    wa = (x1.float() - grid_x) * (y1.float() - grid_y)
    wb = (x1.float() - grid_x) * (grid_y - y0.float())
    wc = (grid_x - x0.float()) * (y1.float() - grid_y)
    wd = (grid_x - x0.float()) * (grid_y - y0.float())
    wa = wa.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
    wb = wb.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
    wc = wc.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
    wd = wd.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)

    # Only consider valid points for interpolation
    weights_sum = (wa * ma) + (wb * mb) + (wc * mc) + (wd * md)
    weights_sum = torch.clamp(weights_sum, min=1e-5)

    # Perform the interpolation
    interpolated_depth = (
        wa * Ia * ma + wb * Ib * mb + wc * Ic * mc + wd * Id * md
    ) / weights_sum
    return interpolated_depth, (ma + mb + mc + md) > 0
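

# Minimal sketch: upsample a 2x2 map with one invalid corner; the masked
# weights renormalize around the hole, and the target corner whose whole
# 2x2 support is invalid is reported as invalid.
def _example_masked_bilinear():
    depth = torch.tensor([[1.0, 2.0], [3.0, 0.0]]).view(1, 1, 2, 2)
    up, valid = masked_bilinear_interpolation(depth, depth > 0, (4, 4))
    assert up.shape == (1, 1, 4, 4)
    assert valid[..., 0, 0] and not valid[..., -1, -1]
    return up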


def masked_nearest_interpolation(input, mask, target_size):
    B, C, H, W = input.shape
    target_H, target_W = target_size
    mask = mask.float()

    # Generate a grid of coordinates in the target space
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(0, H - 1, target_H),
        torch.linspace(0, W - 1, target_W),
        indexing="ij",
    )
    grid_y = grid_y.to(input.device)
    grid_x = grid_x.to(input.device)

    # Calculate the floor and ceil of the grid coordinates to get the bounding box
    x0 = torch.floor(grid_x).long().clamp(0, W - 1)
    x1 = (x0 + 1).clamp(0, W - 1)
    y0 = torch.floor(grid_y).long().clamp(0, H - 1)
    y1 = (y0 + 1).clamp(0, H - 1)

    # Gather depth values at the four corners
    Ia = input[..., y0, x0]
    Ib = input[..., y1, x0]
    Ic = input[..., y0, x1]
    Id = input[..., y1, x1]

    # Gather corresponding mask values
    ma = mask[..., y0, x0]
    mb = mask[..., y1, x0]
    mc = mask[..., y0, x1]
    md = mask[..., y1, x1]

    # Squared distances from the sample point (grid_x, grid_y) to each corner
    dist_a = (grid_x - x0.float()) ** 2 + (grid_y - y0.float()) ** 2  # Top-left
    dist_b = (grid_x - x0.float()) ** 2 + (grid_y - y1.float()) ** 2  # Bottom-left
    dist_c = (grid_x - x1.float()) ** 2 + (grid_y - y0.float()) ** 2  # Top-right
    dist_d = (grid_x - x1.float()) ** 2 + (grid_y - y1.float()) ** 2  # Bottom-right

    # Stack the neighbors, their masks, and distances
    stacked_values = torch.stack(
        [Ia, Ib, Ic, Id], dim=-1
    )  # Shape: (B, C, target_H, target_W, 4)
    stacked_masks = torch.stack(
        [ma, mb, mc, md], dim=-1
    )  # Shape: (B, 1, target_H, target_W, 4)
    stacked_distances = torch.stack(
        [dist_a, dist_b, dist_c, dist_d], dim=-1
    )  # Shape: (target_H, target_W, 4)
    stacked_distances = (
        stacked_distances.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1, 1)
    )  # Shape: (B, 1, target_H, target_W, 4)

    # Set distances to infinity for invalid neighbors so they are never chosen
    stacked_distances[stacked_masks == 0] = float("inf")

    # Find the index of the nearest valid neighbor (smallest distance)
    nearest_indices = stacked_distances.argmin(
        dim=-1, keepdim=True
    )  # Shape: (B, 1, target_H, target_W, 1)

    # Select the corresponding depth value using the nearest valid neighbor index
    interpolated_depth = torch.gather(
        stacked_values, dim=-1, index=nearest_indices.repeat(1, C, 1, 1, 1)
    ).squeeze(-1)

    # Set depth to zero where no valid neighbors were found
    interpolated_depth = interpolated_depth * stacked_masks.sum(dim=-1).clip(
        min=0.0, max=1.0
    )
    return interpolated_depth
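

# Sketch: nearest-valid upsampling never averages across the hole; pixels
# whose whole 2x2 support is invalid stay zero.
def _example_masked_nearest():
    depth = torch.tensor([[1.0, 2.0], [3.0, 0.0]]).view(1, 1, 2, 2)
    up = masked_nearest_interpolation(depth, depth > 0, (4, 4))
    assert up[..., 0, 0] == 1.0
    assert up[..., -1, -1] == 0.0  # no valid neighbor in its support
    return up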


def masked_nxn_interpolation(input, mask, target_size, N=2):
    B, C, H, W = input.shape
    target_H, target_W = target_size

    # Generate a grid of coordinates in the target space
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(0, H - 1, target_H),
        torch.linspace(0, W - 1, target_W),
        indexing="ij",
    )
    grid_y = grid_y.to(input.device)
    grid_x = grid_x.to(input.device)

    # Calculate the top-left corner of the NxN grid
    half_N = (N - 1) // 2
    y0 = torch.floor(grid_y - half_N).long().clamp(0, H - 1)
    x0 = torch.floor(grid_x - half_N).long().clamp(0, W - 1)

    # Gather NxN neighborhoods
    input_patches = []
    mask_patches = []
    weights = []
    for i in range(N):
        for j in range(N):
            yi = (y0 + i).clamp(0, H - 1)
            xi = (x0 + j).clamp(0, W - 1)
            # Gather depth and mask values
            input_patches.append(input[..., yi, xi])
            mask_patches.append(mask[..., yi, xi])
            # Compute bilinear weights
            weight_y = 1 - torch.abs(grid_y - yi.float()) / N
            weight_x = 1 - torch.abs(grid_x - xi.float()) / N
            weight = (
                (weight_y * weight_x)
                .reshape(1, 1, target_H, target_W)
                .repeat(B, C, 1, 1)
            )
            weights.append(weight)

    input_patches = torch.stack(input_patches)
    mask_patches = torch.stack(mask_patches)
    weights = torch.stack(weights)

    # Calculate weighted sum and normalize by the sum of weights
    weighted_sum = (input_patches * mask_patches * weights).sum(dim=0)
    weights_sum = (mask_patches * weights).sum(dim=0)
    interpolated_tensor = weighted_sum / torch.clamp(weights_sum, min=1e-8)

    if N != 2:
        interpolated_tensor_2x2, mask_sum_2x2 = masked_bilinear_interpolation(
            input, mask, target_size
        )
        interpolated_tensor = torch.where(
            mask_sum_2x2, interpolated_tensor_2x2, interpolated_tensor
        )
    return interpolated_tensor


class PanoCrop:
    def __init__(self, crop_v=0.15):
        self.crop_v = crop_v

    def _crop_data(self, results, crop_size):
        offset_w, offset_h = crop_size
        left, top, right, bottom = offset_w[0], offset_h[0], offset_w[1], offset_h[1]
        H, W = results["image"].shape[-2:]
        for key in results.get("image_fields", ["image"]):
            img = results[key][..., top : H - bottom, left : W - right]
            results[key] = img
            results["image_shape"] = tuple(img.shape)
        for key in results.get("gt_fields", []):
            results[key] = results[key][..., top : H - bottom, left : W - right]
        for key in results.get("mask_fields", []):
            results[key] = results[key][..., top : H - bottom, left : W - right]
        results["camera"] = results["camera"].crop(left, top, right, bottom)
        return results

    def __call__(self, results):
        H, W = results["image"].shape[-2:]
        crop_w = (0, 0)
        crop_h = (int(H * self.crop_v), int(H * self.crop_v))
        results = self._crop_data(results, (crop_w, crop_h))
        return results
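

# Sketch (hypothetical minimal results dict; the real pipeline passes a
# camera object whose .crop() handles the intrinsics): PanoCrop trims
# crop_v * H rows from both the top and the bottom of a panorama.
class _DummyCamera:
    def crop(self, left, top, right, bottom):
        return self


def _example_pano_crop():
    results = {
        "image": torch.zeros(1, 3, 100, 200),
        "image_fields": ["image"],
        "camera": _DummyCamera(),
    }
    out = PanoCrop(crop_v=0.15)(results)
    assert out["image"].shape[-2:] == (70, 200)  # 15 rows cut top and bottom
    return out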


class PanoRoll:
    def __init__(self, test_mode, roll=[-0.5, 0.5]):
        self.roll = roll
        self.test_mode = test_mode

    def __call__(self, results):
        if self.test_mode:
            return results
        W = results["image"].shape[-1]
        roll = random.randint(int(W * self.roll[0]), int(W * self.roll[1]))
        for key in results.get("image_fields", ["image"]):
            img = results[key]
            img = torch.roll(img, roll, dims=-1)
            results[key] = img
        for key in results.get("gt_fields", []):
            results[key] = torch.roll(results[key], roll, dims=-1)
        for key in results.get("mask_fields", []):
            results[key] = torch.roll(results[key], roll, dims=-1)
        return results


class RandomFlip:
    def __init__(self, direction="horizontal", prob=0.5, consistent=False, **kwargs):
        self.flip_ratio = prob
        valid_directions = ["horizontal", "vertical", "diagonal"]
        if isinstance(direction, str):
            assert direction in valid_directions
        elif isinstance(direction, list):
            assert set(direction).issubset(set(valid_directions))
        else:
            raise ValueError("direction must be either str or list of str")
        self.direction = direction
        self.consistent = consistent

    def __call__(self, results):
        if "flip" not in results:
            # None means non-flip
            if isinstance(self.direction, list):
                direction_list = self.direction + [None]
            else:
                direction_list = [self.direction, None]
            if isinstance(self.flip_ratio, list):
                non_flip_ratio = 1 - sum(self.flip_ratio)
                flip_ratio_list = self.flip_ratio + [non_flip_ratio]
            else:
                non_flip_ratio = 1 - self.flip_ratio
                # exclude non-flip
                single_ratio = self.flip_ratio / (len(direction_list) - 1)
                flip_ratio_list = [single_ratio] * (len(direction_list) - 1) + [
                    non_flip_ratio
                ]
            cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
            results["flip"] = cur_dir is not None
        if "flip_direction" not in results:
            results["flip_direction"] = cur_dir
        if results["flip"]:
            # flip image
            if results["flip_direction"] != "vertical":
                for key in results.get("image_fields", ["image"]):
                    results[key] = TF.hflip(results[key])
                for key in results.get("mask_fields", []):
                    results[key] = TF.hflip(results[key])
                for key in results.get("gt_fields", []):
                    results[key] = TF.hflip(results[key])
                    if "flow" in key:  # flip u direction
                        results[key][:, 0] = -results[key][:, 0]
                H, W = results["image"].shape[-2:]
                results["camera"] = results["camera"].flip(
                    H=H, W=W, direction="horizontal"
                )
                # Mirror the camera pose along the x axis
                flip_transform = torch.tensor(
                    [[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
                    dtype=torch.float32,
                ).unsqueeze(0)
                repeats = (results["cam2w"].shape[0],) + (1,) * (
                    results["cam2w"].ndim - 1
                )
                results["cam2w"] = flip_transform.repeat(*repeats) @ results["cam2w"]
            if results["flip_direction"] != "horizontal":
                for key in results.get("image_fields", ["image"]):
                    results[key] = TF.vflip(results[key])
                for key in results.get("mask_fields", []):
                    results[key] = TF.vflip(results[key])
                for key in results.get("gt_fields", []):
                    results[key] = TF.vflip(results[key])
                results["K"][..., 1, 2] = (
                    results["image"].shape[-2] - results["K"][..., 1, 2]
                )
        results["flip"] = [results["flip"]] * len(results["image"])
        return results


class Crop:
    def __init__(
        self,
        crop_size,
        crop_type="absolute",
        crop_offset=(0, 0),
    ):
        if crop_type not in [
            "relative_range",
            "relative",
            "absolute",
            "absolute_range",
        ]:
            raise ValueError(f"Invalid crop_type {crop_type}.")
        if crop_type in ["absolute", "absolute_range"]:
            assert crop_size[0] > 0 and crop_size[1] > 0
            assert isinstance(crop_size[0], int) and isinstance(crop_size[1], int)
        else:
            assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
        self.crop_size = crop_size
        self.crop_type = crop_type
        self.offset_h, self.offset_w = (
            crop_offset[: len(crop_offset) // 2],
            crop_offset[len(crop_offset) // 2 :],
        )

    def _get_crop_size(self, image_shape):
        h, w = image_shape
        if self.crop_type == "absolute":
            return (min(self.crop_size[0], h), min(self.crop_size[1], w))
        elif self.crop_type == "absolute_range":
            assert self.crop_size[0] <= self.crop_size[1]
            crop_h = np.random.randint(
                min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1
            )
            crop_w = np.random.randint(
                min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1
            )
            return crop_h, crop_w
        elif self.crop_type == "relative":
            crop_h, crop_w = self.crop_size
            return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
        elif self.crop_type == "relative_range":
            crop_size = np.asarray(self.crop_size, dtype=np.float32)
            crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
            return int(h * crop_h + 0.5), int(w * crop_w + 0.5)

    def _crop_data(self, results, crop_size):
        assert crop_size[0] > 0 and crop_size[1] > 0
        for key in results.get("image_fields", ["image"]):
            img = results[key]
            img = TF.crop(
                img, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
            )
            results[key] = img
            results["image_shape"] = tuple(img.shape)
        for key in results.get("gt_fields", []):
            gt = results[key]
            results[key] = TF.crop(
                gt, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
            )
        # crop masks
        for key in results.get("mask_fields", []):
            mask = results[key]
            results[key] = TF.crop(
                mask, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
            )
        # shift the principal point by the crop offsets
        results["K"][..., 0, 2] = results["K"][..., 0, 2] - self.offset_w[0]
        results["K"][..., 1, 2] = results["K"][..., 1, 2] - self.offset_h[0]
        return results

    def __call__(self, results):
        image_shape = results["image"].shape[-2:]
        crop_size = self._get_crop_size(image_shape)
        results = self._crop_data(results, crop_size)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(crop_size={self.crop_size}, "
        repr_str += f"crop_type={self.crop_type})"
        return repr_str
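

# Sketch of the intrinsics bookkeeping (hypothetical minimal results dict):
# cropping at (top, left) shifts the principal point by the same offsets.
def _example_crop_updates_K():
    K = torch.eye(3).view(1, 3, 3)
    K[..., 0, 2], K[..., 1, 2] = 100.0, 50.0  # cx, cy
    results = {
        "image": torch.zeros(1, 3, 100, 200),
        "image_fields": ["image"],
        "K": K,
    }
    out = Crop(crop_size=(64, 128), crop_offset=(10, 20))(results)
    assert out["image"].shape[-2:] == (64, 128)
    assert out["K"][..., 0, 2] == 80.0 and out["K"][..., 1, 2] == 40.0
    return out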


class KittiCrop:
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def _crop_data(self, results, crop_size):
        """Crop images, GT maps and masks to the KITTI benchmark crop.

        The crop is anchored at the bottom of the image and centered
        horizontally.

        Args:
            results (dict): Result dict from loading pipeline.
            crop_size (tuple): Expected absolute size after cropping, (h, w).

        Returns:
            dict: Cropped results; the 'image_shape' key is updated
                according to the crop size.
        """
        assert crop_size[0] > 0 and crop_size[1] > 0
        for key in results.get("image_fields", ["image"]):
            img = results[key]
            h, w = img.shape[-2:]
            offset_h, offset_w = int(h - self.crop_size[0]), int(
                (w - self.crop_size[1]) / 2
            )
            # crop the image
            img = TF.crop(img, offset_h, offset_w, crop_size[0], crop_size[1])
            results[key] = img
            results["image_shape"] = tuple(img.shape)
        for key in results.get("gt_fields", []):
            gt = results[key]
            results[key] = TF.crop(gt, offset_h, offset_w, crop_size[0], crop_size[1])
        # crop masks
        for key in results.get("mask_fields", []):
            mask = results[key]
            results[key] = TF.crop(mask, offset_h, offset_w, crop_size[0], crop_size[1])
        results["camera"] = results["camera"].crop(offset_w, offset_h)
        return results

    def __call__(self, results):
        results = self._crop_data(results, self.crop_size)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(crop_size={self.crop_size})"
        return repr_str


class RandomMasking:
    def __init__(
        self,
        mask_ratio,
        mask_patch=16,
        prob=0.5,
        warmup_steps=50000,
        sampling="random",
        curriculum=False,
    ):
        self.mask_patch = mask_patch
        self.prob = prob
        self.mask_ratio = mask_ratio
        self.warmup_steps = max(1, warmup_steps)
        self.hard_bound = 1
        self.idx = 0
        self.curriculum = curriculum
        self.sampling = sampling
        self.low_bound = 0.0
        self.up_bound = 0.0

    def __call__(self, results):
        B, _, H, W = results["image"].shape
        device = results["image"].device
        down_size = H // self.mask_patch, W // self.mask_patch
        if np.random.random() > self.prob:  # fill with dummy
            return self._nop(results, down_size, device)

        validity_mask = results["validity_mask"].float().reshape(B, -1, H, W)
        validity_mask = F.interpolate(validity_mask, size=down_size).bool()
        validity_mask = validity_mask.reshape(B, 1, *down_size)
        is_random = self.is_warmup or results.get("guidance") is None
        if not is_random:
            guidance = F.interpolate(results["guidance"], size=(H, W), mode="bilinear")
            results["guidance"] = -F.max_pool2d(
                -guidance, kernel_size=self.mask_patch, stride=self.mask_patch
            )
        if is_random and self.sampling == "inverse":
            sampling = self.inverse_sampling
        elif is_random and self.sampling == "random":
            sampling = self.random_sampling
        else:
            sampling = self.guided_sampling
        mask_ratio = np.random.uniform(self.low_bound, self.up_bound)
        for key in results.get("image_fields", ["image"]):
            mask = sampling(results, mask_ratio, down_size, validity_mask, device)
            results[key + "_mask"] = mask
        return results

    def _nop(self, results, down_size, device):
        B = results["image"].shape[0]
        for key in results.get("image_fields", ["image"]):
            mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device)
            results[key + "_mask"] = mask_blocks
        return results

    def random_sampling(self, results, mask_ratio, down_size, validity_mask, device):
        B = results["image"].shape[0]
        prob_blocks = torch.rand(size=(B, 1, *down_size), device=device)
        mask_blocks = torch.logical_and(prob_blocks < mask_ratio, validity_mask)
        return mask_blocks

    def inverse_sampling(self, results, mask_ratio, down_size, validity_mask, device):
        def area_sample(depth, fx, fy):
            dtype = depth.dtype
            B = depth.shape[0]
            H, W = down_size
            depth = downsample(depth, depth.shape[-2] // H)
            depth[depth > 200] = 50  # treat sky as if it were at 50 meters
            pixel_area3d = depth / torch.sqrt(fx * fy)
            # Set invalid as -1 (no division problem), then clip to 0.0
            pixel_area3d[depth == 0.0] = -1
            prob_density = (1 / pixel_area3d).clamp(min=0.0).square()
            prob_density = prob_density / prob_density.sum(
                dim=(-1, -2), keepdim=True
            ).clamp(min=1e-5)
            # Sample locations based on prob_density
            prob_density_flat = prob_density.view(B, -1)
            # Count the valid locations; of those, mask a mask_ratio fraction
            valid_locations = (prob_density_flat > 0).to(dtype).sum(dim=1)
            masks = []
            for i in range(B):
                num_samples = int(valid_locations[i] * mask_ratio)
                mask = torch.zeros_like(prob_density_flat[i])
                # Sample indices
                if num_samples > 0:
                    sampled_indices_flat = torch.multinomial(
                        prob_density_flat[i], num_samples, replacement=False
                    )
                    mask.scatter_(0, sampled_indices_flat, 1)
                masks.append(mask)
            return torch.stack(masks).bool().view(B, 1, H, W)

        def random_sample(validity_mask):
            prob_blocks = torch.rand(
                size=(validity_mask.shape[0], 1, *down_size), device=device
            )
            mask = torch.logical_and(prob_blocks < mask_ratio, validity_mask)
            return mask

        fx = results["K"][..., 0, 0].view(-1, 1, 1, 1) / self.mask_patch
        fy = results["K"][..., 1, 1].view(-1, 1, 1, 1) / self.mask_patch
        valid = ~results["ssi"] & ~results["si"] & results["valid_camera"]
        mask_blocks = torch.zeros_like(validity_mask)
        if valid.any():
            out = area_sample(results["depth"][valid], fx[valid], fy[valid])
            mask_blocks[valid] = out
        if (~valid).any():
            mask_blocks[~valid] = random_sample(validity_mask[~valid])
        return mask_blocks

    def guided_sampling(self, results, mask_ratio, down_size, validity_mask, device):
        # Mask the lowest (guidance-based) mask_ratio quantile of the patches
        # that lie within the validity mask
        B = results["image"].shape[0]
        guidance = results["guidance"]
        mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device)
        for b in range(B):
            low_bound = torch.quantile(
                guidance[b][validity_mask[b]], max(0.0, self.hard_bound - mask_ratio)
            )
            up_bound = torch.quantile(
                guidance[b][validity_mask[b]], min(1.0, self.hard_bound)
            )
            mask_blocks[b] = torch.logical_and(
                guidance[b] < up_bound, guidance[b] > low_bound
            )
        mask_blocks = torch.logical_and(mask_blocks, validity_mask)
        return mask_blocks

    def step(self):
        self.idx += 1
        # schedule hard_bound from 1.0 down to self.mask_ratio
        if self.curriculum:
            step = max(0, self.idx / self.warmup_steps / 2 - 0.5)
            self.hard_bound = 1 - (1 - self.mask_ratio) * tanh(step)
            self.up_bound = self.mask_ratio * tanh(step)
            self.low_bound = 0.1 * tanh(step)

    # A property, so the truthiness test `self.is_warmup or ...` in __call__
    # evaluates the comparison instead of an (always truthy) bound method.
    @property
    def is_warmup(self):
        return self.idx < self.warmup_steps
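

# Sketch of the curriculum schedule (illustrative numbers): the bounds stay
# at zero during warmup, then ramp toward mask_ratio via tanh.
def _example_masking_schedule():
    m = RandomMasking(mask_ratio=0.5, warmup_steps=10, curriculum=True)
    assert m.is_warmup  # idx=0 < warmup_steps, thanks to the property above
    for _ in range(40):
        m.step()
    assert not m.is_warmup
    assert 0.0 < m.up_bound <= m.mask_ratio
    return m.low_bound, m.up_bound, m.hard_bound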


class Resize:
    def __init__(self, image_scale=None, image_shape=None, keep_original=False):
        assert (image_scale is None) ^ (image_shape is None)
        if isinstance(image_scale, (float, int)):
            image_scale = (image_scale, image_scale)
        if isinstance(image_shape, (float, int)):
            image_shape = (int(image_shape), int(image_shape))
        self.image_scale = image_scale
        self.image_shape = image_shape
        self.keep_original = keep_original

    def _resize_img(self, results):
        for key in results.get("image_fields", ["image"]):
            img = TF.resize(
                results[key],
                results["resized_shape"],
                interpolation=TF.InterpolationMode.BILINEAR,
                antialias=True,
            )
            results[key] = img

    def _resize_masks(self, results):
        for key in results.get("mask_fields", []):
            mask = TF.resize(
                results[key],
                results["resized_shape"],
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                antialias=True,
            )
            results[key] = mask

    def _resize_gt(self, results):
        for key in results.get("gt_fields", []):
            gt = TF.resize(
                results[key],
                results["resized_shape"],
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                antialias=True,
            )
            results[key] = gt

    def __call__(self, results):
        h, w = results["image"].shape[-2:]
        results["K_original"] = results["K"].clone()
        if self.image_scale:
            image_shape = (
                int(h * self.image_scale[0] + 0.5),
                int(w * self.image_scale[1] + 0.5),
            )
            image_scale = self.image_scale
        elif self.image_shape:
            image_scale = (self.image_shape[0] / h, self.image_shape[1] / w)
            image_shape = self.image_shape
        else:
            raise ValueError(
                f"In {self.__class__.__name__}: image_scale or image_shape must be set"
            )
        results["resized_shape"] = tuple(image_shape)
        results["resize_factor"] = tuple(image_scale)
        # Rescale the intrinsics (principal point in pixel-center convention)
        results["K"][..., 0, 2] = (results["K"][..., 0, 2] - 0.5) * image_scale[1] + 0.5
        results["K"][..., 1, 2] = (results["K"][..., 1, 2] - 0.5) * image_scale[0] + 0.5
        results["K"][..., 0, 0] = results["K"][..., 0, 0] * image_scale[1]
        results["K"][..., 1, 1] = results["K"][..., 1, 1] * image_scale[0]
        self._resize_img(results)
        if not self.keep_original:
            self._resize_masks(results)
            self._resize_gt(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
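

# Worked check of the intrinsics update (hypothetical minimal dict):
# scaling by s multiplies the focals by s and maps cx to (cx - 0.5) * s + 0.5.
def _example_resize_updates_K():
    K = torch.tensor([[100.0, 0, 50.5], [0, 100.0, 30.5], [0, 0, 1]]).view(1, 3, 3)
    results = {
        "image": torch.zeros(1, 3, 60, 100),
        "image_fields": ["image"],
        "K": K.clone(),
    }
    out = Resize(image_scale=2.0)(results)
    assert out["image"].shape[-2:] == (120, 200)
    assert torch.isclose(out["K"][0, 0, 0], torch.tensor(200.0))
    assert torch.isclose(out["K"][0, 0, 2], torch.tensor(100.5))
    return out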


class Rotate:
    def __init__(
        self, angle, center=None, img_fill_val=(123.68, 116.28, 103.53), prob=0.5
    ):
        if isinstance(img_fill_val, (float, int)):
            img_fill_val = tuple([float(img_fill_val)] * 3)
        elif isinstance(img_fill_val, tuple):
            assert len(img_fill_val) == 3, (
                "img_fill_val as tuple must "
                f"have 3 elements. got {len(img_fill_val)}."
            )
            img_fill_val = tuple([float(val) for val in img_fill_val])
        else:
            raise ValueError("img_fill_val must be float or tuple with 3 elements.")
        assert np.all(
            [0 <= val <= 255 for val in img_fill_val]
        ), f"all elements of img_fill_val should be in range [0, 255], got {img_fill_val}."
        assert 0 <= prob <= 1.0, f"The probability should be in range [0, 1], got {prob}."
        self.center = center
        self.img_fill_val = img_fill_val
        self.prob = prob
        self.random = not isinstance(angle, (float, int))
        self.angle = angle

    def _rotate(self, results, angle, center=None, fill_val=0.0):
        for key in results.get("image_fields", ["image"]):
            img = results[key]
            img_rotated = TF.rotate(
                img,
                angle,
                center=center,
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=self.img_fill_val,
            )
            results[key] = img_rotated.to(img.dtype)
            results["image_shape"] = results[key].shape
        for key in results.get("mask_fields", []):
            results[key] = TF.rotate(
                results[key],
                angle,
                center=center,
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=fill_val,
            )
        for key in results.get("gt_fields", []):
            results[key] = TF.rotate(
                results[key],
                angle,
                center=center,
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=fill_val,
            )

    def __call__(self, results):
        """Rotate images, masks and GT maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results.
        """
        if np.random.random() > self.prob:
            return results
        angle = (
            (self.angle[1] - self.angle[0]) * np.random.rand() + self.angle[0]
            if self.random
            else float(np.random.choice([-1, 1])) * self.angle
        )
        self._rotate(results, angle, None, fill_val=0.0)
        results["rotation"] = angle
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(angle={self.angle}, "
        repr_str += f"center={self.center}, "
        repr_str += f"img_fill_val={self.img_fill_val}, "
        repr_str += f"prob={self.prob})"
        return repr_str


class RandomColor:
    """Apply a hue shift to the image. The GT maps and masks are not modified.

    Args:
        level (float | list): Hue shift factor passed to TF.adjust_hue,
            expected in [-0.5, 0.5]; a list is treated as a sampling range.
        prob (float): The probability of applying the transformation.
    """

    def __init__(self, level, prob=0.5):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob

    def _adjust_color_img(self, results, factor=1.0):
        """Apply the hue shift to the image."""
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.adjust_hue(results[key], factor)

    def __call__(self, results):
        """Call function for the hue transformation.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results.
        """
        if np.random.random() > self.prob:
            return results
        factor = (
            (self.level[1] - self.level[0]) * np.random.rand() + self.level[0]
            if self.random
            else self.level
        )
        self._adjust_color_img(results, factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob})"
        return repr_str


class RandomSaturation:
    """Adjust image saturation. The GT maps and masks are not modified.

    Args:
        level (int | float | list): log2 of the saturation factor (the
            applied factor is 2**level); a list is treated as a sampling range.
        prob (float): The probability of applying the transformation.
    """

    def __init__(self, level, prob=0.5):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob

    def _adjust_saturation_img(self, results, factor=1.0):
        """Adjust the saturation of the image."""
        for key in results.get("image_fields", ["image"]):
            # NOTE: assumes RGB channel order (torchvision convention)
            results[key] = TF.adjust_saturation(results[key], factor)

    def __call__(self, results):
        """Call function for the saturation transformation.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results.
        """
        if np.random.random() > self.prob:
            return results
        factor = (
            2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
            if self.random
            else 2**self.level
        )
        self._adjust_saturation_img(results, factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob})"
        return repr_str


class RandomSharpness:
    """Adjust image sharpness. The GT maps and masks are not modified.

    Args:
        level (int | float | list): log2 of the sharpness factor (the
            applied factor is 2**level); a list is treated as a sampling range.
        prob (float): The probability of applying the transformation.
    """

    def __init__(self, level, prob=0.5):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob

    def _adjust_sharpness_img(self, results, factor=1.0):
        """Adjust the sharpness of the image."""
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.adjust_sharpness(results[key], factor)

    def __call__(self, results):
        """Call function for the sharpness transformation.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results.
        """
        if np.random.random() > self.prob:
            return results
        factor = (
            2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
            if self.random
            else 2**self.level
        )
        self._adjust_sharpness_img(results, factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob})"
        return repr_str


class RandomSolarize:
    """Solarize the image: invert all pixel values above a threshold.
    The GT maps and masks are not modified.

    Args:
        level (int | float | list): Solarization threshold in [0, 255];
            a list is treated as a sampling range.
        prob (float): The probability of applying the transformation.
    """

    def __init__(self, level, prob=0.5):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob

    def _adjust_solarize_img(self, results, factor=255.0):
        """Solarize the image."""
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.solarize(results[key], factor)

    def __call__(self, results):
        """Call function for the solarize transformation.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results.
        """
        if np.random.random() > self.prob:
            return results
        factor = (
            (self.level[1] - self.level[0]) * np.random.rand() + self.level[0]
            if self.random
            else self.level
        )
        self._adjust_solarize_img(results, factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob})"
        return repr_str


class RandomPosterize:
    """Posterize the image: reduce the number of bits per color channel.
    The GT maps and masks are not modified.

    Args:
        level (int | float | list): Number of bits to keep per channel
            (cast to int); a list is treated as a sampling range.
        prob (float): The probability of applying the transformation.
    """

    def __init__(self, level, prob=0.5):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob

    def _posterize_img(self, results, factor=1.0):
        """Posterize the image."""
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.posterize(results[key], int(factor))

    def __call__(self, results):
        """Call function for the posterize transformation.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Transformed results.
        """
        if np.random.random() > self.prob:
            return results
        factor = (
            (self.level[1] - self.level[0]) * np.random.rand() + self.level[0]
            if self.random
            else self.level
        )
        self._posterize_img(results, factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob})"
        return repr_str


class RandomEqualize:
    """Apply Equalize transformation to image. The GT maps and masks are
    not modified.

    Args:
        prob (float): The probability for performing Equalize transformation.
    """

    def __init__(self, prob=0.5):
        assert 0 <= prob <= 1.0, "The probability should be in range [0, 1]."
        self.prob = prob

    def _imequalize(self, results):
        """Equalize the histogram of the image."""
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.equalize(results[key])

    def __call__(self, results):
        """Call function for Equalize transformation.

        Args:
            results (dict): Results dict from loading pipeline.

        Returns:
            dict: Results after the transformation.
        """
        if np.random.random() > self.prob:
            return results
        self._imequalize(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(prob={self.prob})"
        return repr_str


class RandomBrightness:
    """Adjust image brightness. The GT maps and masks are not modified.

    Args:
        level (int | float | list): log2 of the brightness factor (the
            applied factor is 2**level); a list is treated as a sampling range.
        prob (float): The probability of applying the transformation.
    """

    def __init__(self, level, prob=0.5):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob

    def _adjust_brightness_img(self, results, factor=1.0):
        """Adjust the brightness of the image."""
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.adjust_brightness(results[key], factor)

    def __call__(self, results, level=None):
        """Call function for the brightness transformation.

        Args:
            results (dict): Results dict from loading pipeline.

        Returns:
            dict: Results after the transformation.
        """
        if np.random.random() > self.prob:
            return results
        factor = (
            2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
            if self.random
            else 2**self.level
        )
        self._adjust_brightness_img(results, factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob})"
        return repr_str


class RandomContrast:
    """Adjust image contrast. The GT maps and masks are not modified.

    Args:
        level (int | float | list): log2 of the contrast factor (the
            applied factor is 2**level); a list is treated as a sampling range.
        prob (float): The probability of applying the transformation.
    """

    def __init__(self, level, prob=0.5):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob

    def _adjust_contrast_img(self, results, factor=1.0):
        """Adjust the image contrast."""
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.adjust_contrast(results[key], factor)

    def __call__(self, results, level=None):
        """Call function for the contrast transformation.

        Args:
            results (dict): Results dict from loading pipeline.

        Returns:
            dict: Results after the transformation.
        """
        if np.random.random() > self.prob:
            return results
        factor = (
            2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
            if self.random
            else 2**self.level
        )
        self._adjust_contrast_img(results, factor)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob})"
        return repr_str


class RandomGamma:
    def __init__(self, level, prob=0.5):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob

    def __call__(self, results, level=None):
        """Call function for the gamma transformation.

        Args:
            results (dict): Results dict from loading pipeline.

        Returns:
            dict: Results after the transformation.
        """
        if np.random.random() > self.prob:
            return results
        factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0]
        for key in results.get("image_fields", ["image"]):
            if "original" not in key:
                results[key] = TF.adjust_gamma(results[key], 1 + factor)
        return results


class RandomInvert:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, results):
        if np.random.random() > self.prob:
            return results
        for key in results.get("image_fields", ["image"]):
            if "original" not in key:
                results[key] = TF.invert(results[key])
        return results


class RandomAutoContrast:
    def __init__(self, prob=0.5):
        self.prob = prob

    def _autocontrast_img(self, results):
        for key in results.get("image_fields", ["image"]):
            img = results[key]
            results[key] = TF.autocontrast(img)

    def __call__(self, results):
        if np.random.random() > self.prob:
            return results
        self._autocontrast_img(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(prob={self.prob})"
        return repr_str


class Dilation:
    def __init__(self, origin, kernel, border_value=-1.0, iterations=1) -> None:
        self.structured_element = torch.ones(size=kernel)
        self.origin = origin
        self.border_value = border_value
        self.iterations = iterations

    def dilate(self, image):
        image_pad = F.pad(
            image,
            [
                self.origin[0],
                self.structured_element.shape[0] - self.origin[0] - 1,
                self.origin[1],
                self.structured_element.shape[1] - self.origin[1] - 1,
            ],
            mode="constant",
            value=self.border_value,
        )
        if image_pad.ndim < 4:
            image_pad = image_pad.unsqueeze(0)
        # Unfold the image to operate on neighborhoods
        image_unfold = F.unfold(image_pad, kernel_size=self.structured_element.shape)
        # For greyscale morphology one would add the (flattened) structuring
        # element here and take the max over the neighborhood; for erosion,
        # subtract and take the min. Since we work with depth, we want the
        # closest point (perspective), hence the min. But zeros mean
        # "unknown", so replace them with a large value first, then take the min.
        mask = image_unfold < 1e-3
        image_unfold = image_unfold.masked_fill(mask, 1000.0)
        # Minimum along the neighborhood axis
        dilate_image = torch.min(image_unfold, dim=1).values
        # Propagate zero where the whole neighborhood is unknown
        dilate_image[mask.all(dim=1)] = 0
        return torch.reshape(dilate_image, image.shape)

    def __call__(self, results):
        for key in results.get("gt_fields", []):
            gt = results[key]
            for _ in range(self.iterations):
                gt[gt < 1e-4] = self.dilate(gt)[gt < 1e-4]
            results[key] = gt
        return results
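

# Sketch: a 3x3 min-dilation fills an isolated hole in a sparse depth map
# with the minimum (closest-in-depth) valid neighbor.
def _example_dilation():
    depth = torch.tensor(
        [[2.0, 2.0, 2.0], [2.0, 0.0, 3.0], [3.0, 3.0, 3.0]]
    ).view(1, 1, 3, 3)
    d = Dilation(origin=(1, 1), kernel=(3, 3), border_value=0.0)
    out = d(dict(gt_fields=["depth"], depth=depth.clone()))["depth"]
    assert out[0, 0, 1, 1] == 2.0  # hole filled with the min valid neighbor
    return out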


class RandomShear(object):
    def __init__(
        self,
        level,
        prob=0.5,
        direction="horizontal",
    ):
        self.random = not isinstance(level, (float, int))
        self.level = level
        self.prob = prob
        self.direction = direction

    def _shear_img(self, results, magnitude):
        for key in results.get("image_fields", ["image"]):
            img_sheared = TF.affine(
                results[key],
                angle=0.0,
                translate=[0.0, 0.0],
                scale=1.0,
                shear=magnitude,
                interpolation=TF.InterpolationMode.BILINEAR,
                fill=0.0,
            )
            results[key] = img_sheared

    def _shear_masks(self, results, magnitude):
        for key in results.get("mask_fields", []):
            mask_sheared = TF.affine(
                results[key],
                angle=0.0,
                translate=[0.0, 0.0],
                scale=1.0,
                shear=magnitude,
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=0.0,
            )
            results[key] = mask_sheared

    def _shear_gt(
        self,
        results,
        magnitude,
    ):
        for key in results.get("gt_fields", []):
            gt_sheared = TF.affine(
                results[key],
                angle=0.0,
                translate=[0.0, 0.0],
                scale=1.0,
                shear=magnitude,
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=0.0,
            )
            results[key] = gt_sheared

    def __call__(self, results):
        if np.random.random() > self.prob:
            return results
        magnitude = (
            (self.level[1] - self.level[0]) * np.random.rand() + self.level[0]
            if self.random
            else float(np.random.choice([-1, 1])) * self.level
        )
        if self.direction == "horizontal":
            magnitude = [magnitude, 0.0]
        else:
            magnitude = [0.0, magnitude]
        self._shear_img(results, magnitude)
        self._shear_masks(results, magnitude)
        self._shear_gt(results, magnitude)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(level={self.level}, "
        repr_str += f"prob={self.prob}, "
        repr_str += f"direction={self.direction})"
        return repr_str


class RandomTranslate(object):
    def __init__(
        self,
        range,
        prob=0.5,
        direction="horizontal",
    ):
        self.range = range
        self.prob = prob
        self.direction = direction

    def _translate_img(self, results, magnitude):
        for key in results.get("image_fields", ["image"]):
            img_translated = TF.affine(
                results[key],
                angle=0.0,
                translate=magnitude,
                scale=1.0,
                shear=[0.0, 0.0],
                interpolation=TF.InterpolationMode.BILINEAR,
                fill=(123.68, 116.28, 103.53),
            )
            results[key] = img_translated

    def _translate_mask(self, results, magnitude):
        for key in results.get("mask_fields", []):
            mask_translated = TF.affine(
                results[key],
                angle=0.0,
                translate=magnitude,
                scale=1.0,
                shear=[0.0, 0.0],
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=0.0,
            )
            results[key] = mask_translated

    def _translate_gt(
        self,
        results,
        magnitude,
    ):
        for key in results.get("gt_fields", []):
            gt_translated = TF.affine(
                results[key],
                angle=0.0,
                translate=magnitude,
                scale=1.0,
                shear=[0.0, 0.0],
                interpolation=TF.InterpolationMode.NEAREST_EXACT,
                fill=0.0,
            )
            results[key] = gt_translated

    def __call__(self, results):
        if np.random.random() > self.prob:
            return results
        magnitude = (self.range[1] - self.range[0]) * np.random.rand() + self.range[0]
        # The sampled magnitude is a fraction of the image width/height
        if self.direction == "horizontal":
            magnitude = [magnitude * results["image"].shape[-1], 0]
        else:
            magnitude = [0, magnitude * results["image"].shape[-2]]
        self._translate_img(results, magnitude)
        self._translate_mask(results, magnitude)
        self._translate_gt(results, magnitude)
        # Shift the principal point along with the image content
        results["K"][..., 0, 2] = results["K"][..., 0, 2] + magnitude[0]
        results["K"][..., 1, 2] = results["K"][..., 1, 2] + magnitude[1]
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(range={self.range}, "
        repr_str += f"prob={self.prob}, "
        repr_str += f"direction={self.direction})"
        return repr_str


class RandomCut(object):
    def __init__(self, prob=0.5, direction="all"):
        self.direction = direction
        self.prob = prob

    def _cut_img(self, results, coord, dim):
        for key in results.get("image_fields", ["image"]):
            img_rolled = torch.roll(
                results[key], int(coord * results[key].shape[dim]), dims=dim
            )
            results[key] = img_rolled

    def _cut_mask(self, results, coord, dim):
        for key in results.get("mask_fields", []):
            mask_rolled = torch.roll(
                results[key], int(coord * results[key].shape[dim]), dims=dim
            )
            results[key] = mask_rolled

    def _cut_gt(self, results, coord, dim):
        for key in results.get("gt_fields", []):
            gt_rolled = torch.roll(
                results[key], int(coord * results[key].shape[dim]), dims=dim
            )
            results[key] = gt_rolled

    def __call__(self, results):
        if np.random.random() > self.prob:
            return results
        coord = 0.8 * random.random() + 0.1
        if self.direction == "horizontal":
            dim = -1
        elif self.direction == "vertical":
            dim = -2
        else:
            dim = -1 if random.random() < 0.5 else -2
        self._cut_img(results, coord, dim)
        self._cut_mask(results, coord, dim)
        self._cut_gt(results, coord, dim)
        return results


class DownsamplerGT(object):
    def __init__(self, downsample_factor: int, min_depth: float = 0.01):
        assert downsample_factor == round(
            downsample_factor, 0
        ), f"Downsample factor needs to be an integer, got {downsample_factor}"
        self.downsample_factor = downsample_factor
        self.min_depth = min_depth

    def _downsample_gt(self, results):
        # Min-pool over factor x factor blocks, ignoring zeros (invalid pixels)
        for key in deepcopy(results.get("gt_fields", [])):
            gt = results[key]
            N, H, W = gt.shape
            gt = gt.view(
                N,
                H // self.downsample_factor,
                self.downsample_factor,
                W // self.downsample_factor,
                self.downsample_factor,
                1,
            )
            gt = gt.permute(0, 1, 3, 5, 2, 4)
            gt = gt.view(-1, self.downsample_factor * self.downsample_factor)
            gt_tmp = torch.where(gt == 0.0, 1e5 * torch.ones_like(gt), gt)
            gt = torch.min(gt_tmp, dim=-1).values
            gt = gt.view(N, H // self.downsample_factor, W // self.downsample_factor)
            gt = torch.where(gt > 1000, torch.zeros_like(gt), gt)
            results[f"{key}_downsample"] = gt
            results["gt_fields"].append(f"{key}_downsample")
        results["downsampled"] = True
        return results

    def __call__(self, results):
        results = self._downsample_gt(results)
        return results
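

# Sketch: DownsamplerGT min-pools a sparse GT map, ignoring zeros, and
# registers the result as an extra gt field.
def _example_downsampler_gt():
    gt = torch.tensor([[0.0, 5.0], [3.0, 4.0]]).view(1, 2, 2)
    results = dict(gt_fields=["depth"], depth=gt)
    out = DownsamplerGT(downsample_factor=2)(results)
    assert out["depth_downsample"].shape == (1, 1, 1)
    assert out["depth_downsample"][0, 0, 0] == 3.0  # min over nonzero entries
    return out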


class RandomColorJitter:
    def __init__(self, level, prob=0.9):
        self.level = level
        self.prob = prob
        self.list_transform = [
            self._adjust_brightness_img,
            # self._adjust_sharpness_img,
            self._adjust_contrast_img,
            self._adjust_saturation_img,
            self._adjust_color_img,
        ]

    def _adjust_contrast_img(self, results, factor=1.0):
        for key in results.get("image_fields", ["image"]):
            if "original" not in key:
                img = results[key]
                results[key] = TF.adjust_contrast(img, factor)

    def _adjust_sharpness_img(self, results, factor=1.0):
        for key in results.get("image_fields", ["image"]):
            if "original" not in key:
                img = results[key]
                results[key] = TF.adjust_sharpness(img, factor)

    def _adjust_brightness_img(self, results, factor=1.0):
        for key in results.get("image_fields", ["image"]):
            if "original" not in key:
                img = results[key]
                results[key] = TF.adjust_brightness(img, factor)

    def _adjust_saturation_img(self, results, factor=1.0):
        for key in results.get("image_fields", ["image"]):
            if "original" not in key:
                img = results[key]
                results[key] = TF.adjust_saturation(img, factor / 2.0)

    def _adjust_color_img(self, results, factor=1.0):
        for key in results.get("image_fields", ["image"]):
            if "original" not in key:
                img = results[key]
                results[key] = TF.adjust_hue(img, (factor - 1.0) / 4.0)

    def __call__(self, results):
        random.shuffle(self.list_transform)
        for op in self.list_transform:
            if np.random.random() < self.prob:
                factor = 1.0 + (
                    (self.level[1] - self.level[0]) * np.random.random() + self.level[0]
                )
                op(results, factor)
        return results


class RandomGrayscale:
    def __init__(self, prob=0.1, num_output_channels=3):
        super().__init__()
        self.prob = prob
        self.num_output_channels = num_output_channels

    def __call__(self, results):
        if np.random.random() > self.prob:
            return results
        for key in results.get("image_fields", ["image"]):
            if "original" not in key:
                results[key] = TF.rgb_to_grayscale(
                    results[key], num_output_channels=self.num_output_channels
                )
        return results


class ContextCrop(Resize):
    def __init__(
        self,
        image_shape,
        keep_original=False,
        test_min_ctx=1.0,
        train_ctx_range=[0.5, 1.5],
        shape_constraints={},
    ):
        super().__init__(image_shape=image_shape, keep_original=keep_original)
        self.test_min_ctx = test_min_ctx
        self.train_ctx_range = train_ctx_range
        self.shape_mult = shape_constraints["shape_mult"]
        self.sample = shape_constraints["sample"]
        self.ratio_bounds = shape_constraints["ratio_bounds"]
        pixels_min = shape_constraints["pixels_min"] / (
            self.shape_mult * self.shape_mult
        )
        pixels_max = shape_constraints["pixels_max"] / (
            self.shape_mult * self.shape_mult
        )
        self.pixels_bounds = (pixels_min, pixels_max)
        self.keepGT = int(os.environ.get("keepGT", 0))
        self.ctx = None

    def _transform_img(self, results, shapes):
        for key in results.get("image_fields", ["image"]):
            img = self.crop(results[key], **shapes)
            img = TF.resize(
                img,
                results["resized_shape"],
                interpolation=TF.InterpolationMode.BICUBIC,
                antialias=True,
            )
            results[key] = img

    def _transform_masks(self, results, shapes):
        for key in results.get("mask_fields", []):
            mask = self.crop(results[key].float(), **shapes).byte()
            if "flow" in key:  # take pad/crop into account for the flow resize
                mask = TF.resize(
                    mask,
                    results["resized_shape"],
                    interpolation=TF.InterpolationMode.NEAREST_EXACT,
                    antialias=False,
                )
            else:
                mask = masked_nearest_interpolation(
                    mask, mask > 0, results["resized_shape"]
                )
            results[key] = mask

    def _transform_gt(self, results, shapes):
        for key in results.get("gt_fields", []):
            gt = self.crop(results[key], **shapes)
            if not self.keepGT:
                if "flow" in key:  # take pad/crop into account for the flow resize
                    gt = self._rescale_flow(gt, results)
                    gt = TF.resize(
                        gt,
                        results["resized_shape"],
                        interpolation=TF.InterpolationMode.NEAREST_EXACT,
                        antialias=False,
                    )
                else:
                    gt = masked_nearest_interpolation(
                        gt, gt > 0, results["resized_shape"]
                    )
            results[key] = gt

    def _rescale_flow(self, gt, results):
        h_new, w_new = gt.shape[-2:]
        h_old, w_old = results["image_ori_shape"]
        gt[:, 0] = gt[:, 0] * (w_old - 1) / (w_new - 1)
        gt[:, 1] = gt[:, 1] * (h_old - 1) / (h_new - 1)
        return gt

    def crop(self, img, height, width, top, left) -> torch.Tensor:
        h, w = img.shape[-2:]
        right = left + width
        bottom = top + height
        # Padding needed where the crop box extends beyond the image borders
        padding_ltrb = [
            max(-left + min(0, right), 0),
            max(-top + min(0, bottom), 0),
            max(right - max(w, left), 0),
            max(bottom - max(h, top), 0),
        ]
        image_cropped = img[..., max(top, 0) : bottom, max(left, 0) : right]
        return TF.pad(image_cropped, padding_ltrb)

    def test_closest_shape(self, image_shape):
        h, w = image_shape
        input_ratio = w / h
        if self.sample:
            input_pixels = int(ceil(h / self.shape_mult * w / self.shape_mult))
            pixels = max(
                min(input_pixels, self.pixels_bounds[1]), self.pixels_bounds[0]
            )
            ratio = min(max(input_ratio, self.ratio_bounds[0]), self.ratio_bounds[1])
            h = round((pixels / ratio) ** 0.5)
            w = h * ratio
            self.image_shape[0] = int(h) * self.shape_mult
            self.image_shape[1] = int(w) * self.shape_mult

    def _get_crop_shapes(self, image_shape, ctx=None):
        h, w = image_shape
        input_ratio = w / h
        if self.keep_original:
            self.test_closest_shape(image_shape)
            ctx = 1.0
        elif ctx is None:
            ctx = float(
                torch.empty(1)
                .uniform_(self.train_ctx_range[0], self.train_ctx_range[1])
                .item()
            )
        output_ratio = self.image_shape[1] / self.image_shape[0]
        if output_ratio <= input_ratio:  # output like 4:3, input wider (e.g. KITTI)
            if (
                ctx >= 1
            ):  # fully in -> scale the longer side (here width) by sqrt(ctx)
                new_w = w * ctx**0.5
            # sticks out a bit in one dimension only: the input width
            # overshoots before the input height does
            # new_h > old_h via area -> new_h ** 2 * ratio_new = old_h ** 2 * ratio_old * ctx
            elif output_ratio / input_ratio * ctx > 1:
                new_w = w * ctx
            else:  # fully contained -> use area
                new_w = w * (ctx * output_ratio / input_ratio) ** 0.5
            new_h = new_w / output_ratio
        else:
            if ctx >= 1:
                new_h = h * ctx**0.5
            elif input_ratio / output_ratio * ctx > 1:
                new_h = h * ctx
            else:
                new_h = h * (ctx * input_ratio / output_ratio) ** 0.5
            new_w = new_h * output_ratio
        return (int(ceil(new_h - 0.5)), int(ceil(new_w - 0.5))), ctx
# def sample_view(self, results): | |
# original_K = results["K"] | |
# original_image = results["image"] | |
# original_depth = results["depth"] | |
# original_validity_mask = results["validity_mask"].float() | |
# # sample angles and translation | |
# # sample translation: | |
# # 10 max of z | |
# x = np.random.normal(0, 0.05 / 2) * original_depth.max() | |
# y = np.random.normal(0, 0.05) | |
# z = np.random.normal(0, 0.05) * original_depth.max() | |
# fov = 2 * np.arctan(original_image.shape[-2] / 2 / results["K"][0, 0, 0]) | |
# phi = np.random.normal(0, fov / 10) | |
# theta = np.random.normal(0, fov / 10) | |
# psi = np.random.normal(0, np.pi / 60) | |
# translation = torch.tensor([x, y, z]).unsqueeze(0) | |
# angles = torch.tensor([phi, theta, psi]) | |
# angles = euler_to_rotation_matrix(angles) | |
# translation = translation @ angles # translation before rotation | |
# cam2w = torch.eye(4).unsqueeze(0) | |
# cam2w[..., :3, :3] = angles | |
# cam2w[..., :3, 3] = translation | |
# cam2cam = torch.inverse(cam2w) | |
# image_warped, depth_warped = forward_warping(original_image, original_depth, original_K, original_K, cam2cam=cam2cam) | |
# depth_warped[depth_warped > 0] = depth_warped[depth_warped > 0] - z | |
# validity_mask_warped = image_warped.sum(dim=1, keepdim=True) > 0.0 | |
# results["K"] = results["K"].repeat(2, 1, 1) | |
# results["cam2w"] = torch.cat([torch.eye(4).unsqueeze(0), cam2w]) | |
# results["image"] = torch.cat([original_image, image_warped]) | |
# results["depth"] = torch.cat([original_depth, depth_warped]) | |
# results["validity_mask"] = torch.cat([original_validity_mask, validity_mask_warped], dim=0) | |
# # results["cam2w"] = torch.cat([torch.eye(4).unsqueeze(0), torch.eye(4).unsqueeze(0)]) | |
# # results["image"] = torch.cat([original_image, original_image]) | |
# # results["depth"] = torch.cat([original_depth, original_depth]) | |
# # results["validity_mask"] = torch.cat([original_validity_mask, original_validity_mask], dim=0) | |
# return results | |
def __call__(self, results): | |
h, w = results["image"].shape[-2:] | |
results["image_ori_shape"] = (h, w) | |
results["camera_fields"].add("camera_original") | |
results["camera_original"] = results["camera"].clone() | |
results.get("mask_fields", set()).add("validity_mask") | |
if "validity_mask" not in results: | |
results["validity_mask"] = torch.ones( | |
(results["image"].shape[0], 1, h, w), | |
dtype=torch.uint8, | |
device=results["image"].device, | |
) | |
n_iter = 1 if self.keep_original or not self.sample else 100 | |
min_valid_area = 0.5 | |
max_hfov, max_vfov = results["camera"].max_fov[0]  # max_fov is batched, take the single entry
ctx = None | |
for ii in range(n_iter): | |
(height, width), ctx = self._get_crop_shapes((h, w), ctx=self.ctx or ctx) | |
margin_h = h - height | |
margin_w = w - width | |
# center the crop (random jitter is added below when sampling)
top = margin_h // 2
left = margin_w // 2
if not self.keep_original: | |
left = left + np.random.randint( | |
-self.shape_mult // 2, self.shape_mult // 2 + 1 | |
) | |
top = top + np.random.randint( | |
-self.shape_mult // 2, self.shape_mult // 2 + 1 | |
) | |
right = left + width | |
bottom = top + height | |
x_zoom = self.image_shape[0] / height | |
paddings = [
max(-left + min(0, right), 0),  # overshoot on the left
max(bottom - max(h, top), 0),  # overshoot at the bottom
max(right - max(w, left), 0),  # overshoot on the right
max(-top + min(0, bottom), 0),  # overshoot at the top
]
valid_area = ( | |
h | |
* w | |
/ (h + paddings[1] + paddings[3]) | |
/ (w + paddings[0] + paddings[2]) | |
) | |
new_hfov, new_vfov = results["camera_original"].get_new_fov( | |
new_shape=(height, width), original_shape=(h, w) | |
)[0] | |
if ( | |
valid_area >= min_valid_area | |
and new_hfov < max_hfov | |
and new_vfov < max_vfov | |
): | |
break | |
ctx = ctx * 0.96  # not enough valid area: retry with less context (more zoom)
# cache ctx so subsequent frames of a sequence reuse the same crop context
self.ctx = ctx
results["resized_shape"] = self.image_shape | |
results["paddings"] = paddings # left ,top ,right, bottom | |
results["image_rescale"] = x_zoom | |
results["scale_factor"] = results.get("scale_factor", 1.0) * x_zoom | |
results["camera"] = results["camera"].crop( | |
left, top, right=w - right, bottom=h - bottom | |
) | |
results["camera"] = results["camera"].resize(x_zoom) | |
# print("XAM", results["camera"].params.squeeze(), results["camera"][0].params.squeeze(), results["camera_original"].params.squeeze(), results["camera_original"][0].params.squeeze()) | |
shapes = dict(height=height, width=width, top=top, left=left) | |
self._transform_img(results, shapes) | |
if not self.keep_original: | |
self._transform_gt(results, shapes) | |
self._transform_masks(results, shapes) | |
else: | |
# only transform the validity_mask here (the rgb masks follow the rgb transform) # FIXME
mask = results["validity_mask"].float() | |
mask = self.crop(mask, **shapes).byte() | |
mask = TF.resize( | |
mask, | |
results["resized_shape"], | |
interpolation=TF.InterpolationMode.NEAREST, | |
) | |
results["validity_mask"] = mask | |
# keep original images before photo-augment | |
results["image_original"] = results["image"].clone() | |
results["image_fields"].add( | |
*[ | |
field.replace("image", "image_original") | |
for field in results["image_fields"] | |
] | |
) | |
# repeat for batch resized shape and paddings | |
results["paddings"] = [results["paddings"]] * results["image"].shape[0] | |
results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[ | |
0 | |
] | |
return results | |
class RandomFiller: | |
def __init__(self, test_mode, *args, **kwargs): | |
super().__init__() | |
self.test_mode = test_mode | |
def _transform(self, results): | |
def fill_noise(size, device): | |
return torch.normal(0, 2.0, size=size, device=device) | |
def fill_black(size, device): | |
return -4 * torch.ones(size, device=device, dtype=torch.float32) | |
def fill_white(size, device): | |
return 4 * torch.ones(size, device=device, dtype=torch.float32) | |
def fill_zero(size, device): | |
return torch.zeros(size, device=device, dtype=torch.float32) | |
B, C = results["image"].shape[:2]
# if the image batch was expanded upstream, tile the validity mask to match
repeat_factor = B // results["validity_mask"].shape[0]
if repeat_factor > 1:
results["validity_mask"] = results["validity_mask"].repeat(repeat_factor, 1, 1, 1)
validity_mask = results["validity_mask"].repeat(1, C, 1, 1).bool() | |
filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero]) | |
if self.test_mode: | |
filler_fn = fill_zero | |
for key in results.get("image_fields", ["image"]): | |
results[key][~validity_mask] = filler_fn( | |
size=results[key][~validity_mask].shape, device=results[key].device | |
) | |
def __call__(self, results): | |
# generate mask for filler | |
if "validity_mask" not in results: | |
paddings = results.get("padding_size", [0] * 4)  # left, top, right, bottom
height, width = results["image"].shape[-2:] | |
results.get("mask_fields", set()).add("validity_mask") | |
results["validity_mask"] = torch.zeros_like(results["image"][:, :1]) | |
results["validity_mask"][ | |
..., | |
paddings[1] : height - paddings[3], | |
paddings[0] : width - paddings[2], | |
] = 1.0 | |
self._transform(results) | |
return results | |
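# Example usage (illustrative sketch, with a hypothetical padded sample):
# filler = RandomFiller(test_mode=False)
# sample = {
#     "image": torch.randn(1, 3, 64, 64),
#     "image_fields": ["image"],
#     "padding_size": [4, 4, 4, 4],  # left, top, right, bottom
# }
# sample = filler(sample)  # the 4-pixel border is filled with noise/black/white/zeros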
class GaussianBlur: | |
def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9): | |
super().__init__() | |
self.kernel_size = kernel_size | |
self.sigma = sigma | |
self.prob = prob | |
self.padding = kernel_size // 2  # assumes an odd kernel_size
def apply(self, x, kernel): | |
# Pad the input tensor | |
x = F.pad( | |
x, (self.padding, self.padding, self.padding, self.padding), mode="reflect" | |
) | |
# Apply the convolution with the Gaussian kernel | |
return F.conv2d(x, kernel.to(device=x.device, dtype=x.dtype), stride=1, padding=0, groups=x.size(1))
def _create_kernel(self, sigma): | |
# Create a 1D Gaussian kernel | |
kernel_1d = torch.exp( | |
-torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2) | |
) | |
kernel_1d = kernel_1d / kernel_1d.sum() | |
# Expand the kernel to 2D and match size of the input | |
kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1) | |
kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand( | |
3, 1, -1, -1 | |
) | |
return kernel_2d | |
def __call__(self, results): | |
if np.random.random() > self.prob: | |
return results | |
sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0] | |
kernel = self._create_kernel(sigma) | |
for key in results.get("image_fields", ["image"]): | |
if "original" not in key: | |
results[key] = self.apply(results[key], kernel) | |
return results | |
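# Example usage (illustrative sketch):
# blur = GaussianBlur(kernel_size=5, sigma=(0.1, 2.0), prob=1.0)
# sample = {"image": torch.rand(1, 3, 64, 64), "image_fields": {"image"}}
# sample = blur(sample)  # image convolved with a Gaussian of random sigma in [0.1, 2.0]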
class MotionBlur: | |
def __init__(self, kernel_size=9, angles=(-180, 180), prob=0.1):  # kernel_size must be an odd int, not a tuple: it is used for arange and padding below
super().__init__() | |
self.kernel_size = kernel_size | |
self.angles = angles | |
self.prob = prob | |
self.padding = kernel_size // 2 | |
def _create_kernel(self, angle):
# Generate a 2D grid of coordinates
grid = torch.meshgrid(
torch.arange(self.kernel_size), torch.arange(self.kernel_size), indexing="ij"
)
grid = torch.stack(grid).float()  # Shape: (2, kernel_size, kernel_size)
# Calculate relative coordinates from the center
center = (self.kernel_size - 1) / 2.0
x_offset = grid[1] - center
y_offset = grid[0] - center
# Compute motion blur kernel (torch.cos/sin need a tensor, not a float)
theta = torch.tensor(angle * torch.pi / 180.0)
cos_theta, sin_theta = torch.cos(theta), torch.sin(theta)
kernel = (1.0 / self.kernel_size) * (
1.0 - torch.abs(x_offset * cos_theta + y_offset * sin_theta)
)
# Clamp negative weights and renormalize so the kernel integrates to 1
kernel = torch.clamp(kernel, min=0.0)
kernel = kernel / kernel.sum()
# Expand kernel dimensions to match input image channels
kernel = kernel.unsqueeze(0).unsqueeze(0).expand(3, 1, -1, -1)
return kernel
def apply(self, image, kernel):
x = F.pad(
image, (self.padding, self.padding, self.padding, self.padding), mode="reflect"
)
# Apply convolution with the motion blur kernel
return F.conv2d(x, kernel.to(device=x.device, dtype=x.dtype), stride=1, padding=0, groups=x.size(1))
def __call__(self, results): | |
if np.random.random() > self.prob: | |
return results | |
angle = np.random.uniform(self.angles[0], self.angles[1]) | |
kernel = self._create_kernel(angle) | |
for key in results.get("image_fields", ["image"]): | |
if "original" in key: | |
continue | |
results[key] = self.apply(results[key], kernel) | |
return results | |
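# Example usage (illustrative sketch):
# mb = MotionBlur(kernel_size=9, angles=(-180, 180), prob=1.0)
# sample = {"image": torch.rand(1, 3, 64, 64), "image_fields": {"image"}}
# sample = mb(sample)  # image smeared along one random direction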
class JPEGCompression: | |
def __init__(self, level=(10, 70), prob=0.1): | |
super().__init__() | |
self.level = level | |
self.prob = prob | |
def __call__(self, results): | |
if np.random.random() > self.prob: | |
return results | |
level = int(np.random.uniform(self.level[0], self.level[1]))  # TF.jpeg expects an integer quality
for key in results.get("image_fields", ["image"]): | |
if "original" in key: | |
continue | |
results[key] = TF.jpeg(results[key], level) | |
return results | |
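# Example usage (illustrative sketch; TF.jpeg operates on uint8 images):
# jpeg = JPEGCompression(level=(10, 70), prob=1.0)
# sample = {
#     "image": torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8),
#     "image_fields": {"image"},
# }
# sample = jpeg(sample)  # image round-tripped through JPEG at a random quality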
class Compose: | |
def __init__(self, transforms): | |
self.transforms = deepcopy(transforms) | |
def __call__(self, results): | |
for t in self.transforms: | |
results = t(results) | |
return results | |
def __setattr__(self, name: str, value) -> None:
super().__setattr__(name, value)
if name == "transforms":
return  # do not broadcast the container attribute itself onto the children
for t in self.transforms:
setattr(t, name, value)
def __repr__(self): | |
format_string = self.__class__.__name__ + "(" | |
for t in self.transforms: | |
format_string += f"\n {t}" | |
format_string += "\n)" | |
return format_string | |
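# Example usage (illustrative sketch): __setattr__ broadcasts attributes set on
# the container to every child transform (except the transforms list itself).
# pipeline = Compose([GaussianBlur(kernel_size=5), JPEGCompression(prob=0.5)])
# pipeline.test_mode = True  # propagated to both transforms
# sample = pipeline(sample)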
class DummyCrop(Resize): | |
def __init__( | |
self, | |
*args, | |
**kwargs, | |
): | |
# dummy image shape, not really used | |
super().__init__(image_shape=(512, 512)) | |
def __call__(self, results): | |
h, w = results["image"].shape[-2:] | |
results["image_ori_shape"] = (h, w) | |
results["camera_fields"].add("camera_original") | |
results["camera_original"] = results["camera"].clone() | |
results.get("mask_fields", set()).add("validity_mask") | |
if "validity_mask" not in results: | |
results["validity_mask"] = torch.ones( | |
(results["image"].shape[0], 1, h, w), | |
dtype=torch.uint8, | |
device=results["image"].device, | |
) | |
self.ctx = 1.0 | |
results["resized_shape"] = self.image_shape | |
results["paddings"] = [0, 0, 0, 0] | |
results["image_rescale"] = 1.0 | |
results["scale_factor"] = results.get("scale_factor", 1.0) * 1.0 | |
results["camera"] = results["camera"].crop(0, 0, right=w, bottom=h) | |
results["camera"] = results["camera"].resize(1) | |
# keep original images before photo-augment | |
results["image_original"] = results["image"].clone() | |
results["image_fields"].add( | |
*[ | |
field.replace("image", "image_original") | |
for field in results["image_fields"] | |
] | |
) | |
# repeat for batch resized shape and paddings | |
results["paddings"] = [results["paddings"]] * results["image"].shape[0] | |
results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[ | |
0 | |
] | |
return results | |