# Sharp-It/src/utils/infer_util.py
# YiftachEde's picture
# add src
# a1d8bef
import os
from einops import rearrange
import imageio
import rembg
import torch
import numpy as np
import PIL.Image
from PIL import Image
from typing import Any
import torch
import random
from kiui.op import safe_normalize
import torch.nn.functional as F
from kiui.cam import orbit_camera
def get_rays(pose, h, w, fovy, opengl=True):
    """Compute per-pixel ray origins and unit directions for a pinhole camera.

    Rays pass through the pixel centres of an ``h`` x ``w`` image. The focal
    length is derived from ``h`` and ``fovy`` (degrees), so ``fovy`` is the
    vertical field of view. Camera-space directions are rotated by
    ``pose[:3, :3]`` and every ray starts at ``pose[:3, 3]``, i.e. ``pose`` is
    used as a camera-to-world transform.

    Args:
        pose: tensor holding (at least) a 3x4 camera-to-world transform. Its
            device is used for every intermediate tensor. The directions are
            built in torch's default float dtype, so ``pose`` must share that
            dtype for the matmul (float32 unless the default was changed).
        h: image height in pixels.
        w: image width in pixels.
        fovy: vertical field of view in degrees.
        opengl: if True, use the OpenGL convention (image y flipped, camera
            looking down -z); otherwise y is kept and the camera looks down +z.

    Returns:
        ``(rays_o, rays_d)``, each of shape (h, w, 3). ``rays_d`` is
        normalized with kiui's ``safe_normalize``.
    """
    device = pose.device
    sign = -1.0 if opengl else 1.0
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # indexing="xy" gives (h, w) grids: px varies along columns, py along rows.
    px, py = torch.meshgrid(
        torch.arange(w, device=device),
        torch.arange(h, device=device),
        indexing="xy",
    )
    px = px.flatten()
    py = py.flatten()

    # +0.5 shifts from the pixel corner to the pixel centre.
    u = (px - w * 0.5 + 0.5) / focal
    v = (py - h * 0.5 + 0.5) / focal * sign
    # Append the z component: -1 (OpenGL) or +1.
    dirs_cam = F.pad(torch.stack([u, v], dim=-1), (0, 1), value=sign)  # [hw, 3]

    directions = dirs_cam @ pose[:3, :3].t()  # [hw, 3]
    origins = pose[:3, 3].unsqueeze(0).expand_as(directions)  # [hw, 3]
    return origins.view(h, w, 3), safe_normalize(directions).view(h, w, 3)
def prepare_default_rays(device, azimuths=0, B=1):
    """Build Plucker ray embeddings for six fixed orbit views, tiled in a 3x2 grid.

    For every azimuth offset, six cameras are created with kiui's
    ``orbit_camera``. The positional arguments are
    (20, 30), (-10, 90), (20, 150), (-10, 210), (20, 270) and (-10, 330), with
    the offset added to the second value. Presumably these are
    (elevation, azimuth) in degrees; confirm against ``kiui.cam.orbit_camera``.
    The cameras sit at radius ``0.5 / tan(fov / 2)`` with ``fov = 30`` degrees.
    Rays are computed on a 40x40 grid (``320 // 8``) and encoded as 6-channel
    Plucker coordinates ``(o x d, d)``.

    Args:
        device: device for the returned tensor.
        azimuths: an iterable of azimuth offsets (numbers or one-element
            tensors), or a single scalar offset (number or 0-d tensor).
        B: number of batch entries to build when ``azimuths`` is a scalar (the
            scalar is repeated ``B`` times). Ignored when ``azimuths`` is an
            iterable.

    Returns:
        Tensor of shape (num_azimuths, 6, 3 * 40, 2 * 40). The six views of
        each batch entry are laid out in row-major order over a 3 (rows) x 2
        (cols) grid.
    """
    if isinstance(azimuths, (int, float)) or getattr(azimuths, "ndim", None) == 0:
        # A scalar offset cannot be iterated; broadcast it to B entries.
        azimuths = [azimuths] * B
    azimuth_values = [a.item() if isinstance(a, torch.Tensor) else a for a in azimuths]

    fov = 30
    radius = 0.5 / np.tan(np.radians(fov / 2))
    base_views = [(20, 30), (-10, 90), (20, 150), (-10, 210), (20, 270), (-10, 330)]
    cam_poses = [
        [orbit_camera(elev, azim + offset, radius=radius) for elev, azim in base_views]
        for offset in azimuth_values
    ]
    cam_poses = torch.from_numpy(np.stack(cam_poses)).reshape(-1, 4, 4).float().to(device)

    res = 320 // 8
    rays_embeddings = []
    for pose in cam_poses:
        rays_o, rays_d = get_rays(pose, res, res, fov)  # [h, w, 3] each
        rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
        rays_embeddings.append(rays_plucker)

    # get_rays builds its tensors on pose.device, so everything is already on `device`.
    rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()  # [V, 6, h, w]
    return rearrange(
        rays_embeddings,
        "(B x y) c h w -> B c (x h) (y w)",
        B=len(azimuth_values),
        x=3,
        y=2,
    )  # (B, C, 3H, 2W)
class RandomCutout(object):
    """Randomly zero out one rectangular patch in each image of a batch.

    Each image independently receives a cutout with probability
    ``cutout_prob``. The patch is one third of the image's height and width
    (integer division). It is kept at least ``height // 10`` rows and
    ``width // 10`` columns away from every border.
    """

    def __init__(self, height, width, cutout_prob=0.4):
        """Store the nominal image size and the cutout probability.

        Args:
            height: nominal image height. ``mask_height`` records the matching
                patch height (a third of it).
            width: nominal image width. ``mask_width`` records the matching
                patch width (a third of it).
            cutout_prob: probability that any given image is cut out.
        """
        self.height = height
        self.width = width
        self.cutout_prob = cutout_prob
        self.mask_height = height // 3  # Cutout size 3 times smaller
        self.mask_width = width // 3

    def __call__(self, img, is_train=True):
        """Apply the random cutout to ``img`` of shape (B, C, H, W).

        The patch size is computed from the actual H and W of ``img``. This
        equals ``mask_height``/``mask_width`` when the image matches the
        configured size, and inputs of other sizes are handled too.

        Args:
            img: batch of images, shape (B, C, H, W).
            is_train: when False, ``img`` is returned unchanged (same object).

        Returns:
            ``img`` multiplied by a {0, 1} mask. The mask zeroes the patch in
            every channel of the images selected for cutout.
        """
        if not is_train:
            return img
        batch, _, height, width = img.shape
        device = img.device
        # Derive the patch from the real image size. Using the constructor's
        # size would make randint's range empty for smaller inputs.
        mask_height = height // 3
        mask_width = width // 3
        margin_h = height // 10
        margin_w = width // 10

        # One uniform draw per image decides whether that image is cut out.
        random_seeds = torch.rand(batch, device=device)
        # Masks start as ones (no cutout effect).
        masks = torch.ones((batch, height, width), device=device, dtype=img.dtype)
        top = torch.randint(
            low=margin_h, high=height - mask_height - margin_h, size=(batch,), device=device
        )
        left = torch.randint(
            low=margin_w, high=width - mask_width - margin_w, size=(batch,), device=device
        )

        hh, ww = torch.meshgrid(
            torch.arange(0, height, device=device),
            torch.arange(0, width, device=device),
            indexing="ij",
        )
        # Broadcast per-image corners against the (H, W) coordinate grids.
        top = top[:, None, None]
        left = left[:, None, None]
        cutout_masks = (
            (hh >= top) & (hh < top + mask_height) & (ww >= left) & (ww < left + mask_width)
        )
        masks[cutout_masks & (random_seeds[:, None, None] < self.cutout_prob)] = 0
        return img * masks[:, None]
def remove_background(image: PIL.Image.Image,
                      rembg_session: Any = None,
                      force: bool = False,
                      **rembg_kwargs,
                      ) -> PIL.Image.Image:
    """Strip the background with ``rembg`` unless the image already has transparency.

    An RGBA image whose alpha channel has a minimum below 255 is considered
    already matted and is returned as-is, unless ``force`` is True.

    Args:
        image: input PIL image.
        rembg_session: optional session object forwarded to ``rembg.remove``.
        force: run background removal even if the image is already transparent.
        **rembg_kwargs: extra keyword arguments forwarded to ``rembg.remove``.

    Returns:
        The result of ``rembg.remove`` if removal ran, otherwise ``image``.
    """
    has_alpha_holes = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not has_alpha_holes:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    """Re-frame an RGBA image so its foreground fills ``ratio`` of a square canvas.

    The image is cropped to the bounding box of pixels with alpha > 0. The
    crop is centred on a transparent square whose side is the crop's longer
    edge. That square is then padded again, centred, to side
    ``int(size / ratio)``.

    Args:
        image: RGBA PIL image (any image whose array has 4 channels).
        ratio: fraction of the output side taken by the foreground. It must be
            in (0, 1]: ``ratio`` > 1 gives negative padding, which ``np.pad``
            rejects, and 0 divides by zero.

    Returns:
        A new square RGBA PIL image.

    Raises:
        AssertionError: if the image does not have 4 channels.
        ValueError: if no pixel has non-zero alpha.
    """
    image = np.array(image)
    assert image.shape[-1] == 4, f"expected an RGBA image, got array of shape {image.shape}"
    ys, xs = np.nonzero(image[..., 3] > 0)
    if ys.size == 0:
        raise ValueError("resize_foreground: image has no foreground (alpha is zero everywhere)")
    # min()/max() are inclusive indices, so slice one past the max to keep
    # the last foreground row/column.
    fg = image[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

    # Pad to a square whose side is the longer crop edge.
    size = max(fg.shape[0], fg.shape[1])
    new_image = _pad_centered(fg, size)

    # Pad again so the foreground occupies `ratio` of the final side.
    new_size = int(size / ratio)
    new_image = _pad_centered(new_image, new_size)
    return PIL.Image.fromarray(new_image)


def _pad_centered(arr, size):
    """Zero-pad the first two axes of ``arr`` to (size, size), keeping it centred.

    When the padding is odd, the extra row/column goes to the bottom/right.
    """
    ph0 = (size - arr.shape[0]) // 2
    pw0 = (size - arr.shape[1]) // 2
    ph1 = size - arr.shape[0] - ph0
    pw1 = size - arr.shape[1] - pw0
    return np.pad(
        arr,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )
def images_to_video(
    images: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    """Encode a batch of image tensors into a video with ``imageio.mimwrite``.

    Args:
        images: tensor of shape (N, C, H, W). Values are expected in [0, 1];
            they are clamped to that range, scaled by 255 and cast to uint8.
        output_path: destination file. Its parent directory is created if
            needed.
        fps: frames per second of the output video.
    """
    video_dir = os.path.dirname(output_path)
    if video_dir:
        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        os.makedirs(video_dir, exist_ok=True)
    # Clamp before scaling: NumPy's float -> uint8 cast does not saturate
    # out-of-range values, so they would otherwise turn into garbage pixels.
    frames = (
        images.detach().clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy() * 255
    ).astype(np.uint8)  # (N, H, W, C)
    imageio.mimwrite(output_path, frames, fps=fps, quality=10)
def save_video(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    """Stream image tensors into a video file one frame at a time.

    Args:
        frames: tensor of shape (N, C, H, W), or any iterable of (C, H, W)
            tensors. Values are expected in [0, 1]; they are clamped to that
            range, scaled by 255 and cast to uint8.
        output_path: destination file. Its parent directory is created if
            needed.
        fps: frames per second of the output video.
    """
    video_dir = os.path.dirname(output_path)
    if video_dir:
        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        os.makedirs(video_dir, exist_ok=True)
    # The context manager closes the writer even if a frame fails to convert.
    with imageio.get_writer(output_path, fps=fps) as writer:
        for frame in frames:
            # Clamp first: the uint8 cast does not saturate out-of-range floats.
            pixels = (frame.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
            writer.append_data(pixels)