import os
import random
from typing import Any

import imageio
import numpy as np
import PIL.Image
import rembg
import torch
import torch.nn.functional as F
from einops import rearrange
from kiui.cam import orbit_camera
from kiui.op import safe_normalize
from PIL import Image


def get_rays(pose, h, w, fovy, opengl=True):
    """Build per-pixel ray origins and unit directions for one camera.

    Args:
        pose: 4x4 tensor. Its upper-left 3x3 block rotates the camera-space
            directions and ``pose[:3, 3]`` is used as the origin of every ray
            (presumably a camera-to-world matrix -- confirm against callers).
        h: image height in pixels.
        w: image width in pixels.
        fovy: vertical field of view in degrees.
        opengl: when True the y and z components of the camera-space
            directions are negated; when False they are left positive.

    Returns:
        ``(rays_o, rays_d)``, each of shape ``[h, w, 3]``. ``rays_d`` is
        normalized with ``safe_normalize``.
    """
    x, y = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    x = x.flatten()
    y = y.flatten()

    cx = w * 0.5
    cy = h * 0.5

    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # The +0.5 samples pixel centers; the third component is padded in as a
    # constant -1 (OpenGL) or +1.
    camera_dirs = F.pad(
        torch.stack(
            [
                (x - cx + 0.5) / focal,
                (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
            ],
            dim=-1,
        ),
        (0, 1),
        value=(-1.0 if opengl else 1.0),
    )  # [hw, 3]

    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]

    rays_o = rays_o.view(h, w, 3)
    rays_d = safe_normalize(rays_d).view(h, w, 3)

    return rays_o, rays_d


def prepare_default_rays(device, azimuths=0, B=1):
    """Build Plucker ray embeddings for six fixed orbit views per azimuth.

    For every azimuth offset, six cameras are created with ``orbit_camera``
    (elevations alternating 20 / -10, azimuths 30, 90, ..., 330 plus the
    offset), rays are generated at 40x40 resolution, and the six views are
    tiled into a 3x2 grid.

    Args:
        device: torch device the result is placed on.
        azimuths: a single azimuth offset in degrees (number or 0-dim tensor)
            or an iterable of them (list, 1-D tensor, ...). Tensor elements
            are converted with ``.item()``.
        B: unused; kept only so existing callers keep working. The batch size
            is the number of azimuths.

    Returns:
        Tensor of shape ``[len(azimuths), 6, 3 * 40, 2 * 40]`` holding
        ``cross(rays_o, rays_d)`` in the first three channels and ``rays_d``
        in the last three.
    """
    fov = 30
    radius = 0.5 / np.tan(np.radians(fov / 2))
    resolution = 320 // 8

    # The default ``azimuths=0`` (or any bare scalar / 0-dim tensor) is not
    # iterable; wrap it so that a single offset yields a batch of one.
    if isinstance(azimuths, torch.Tensor):
        if azimuths.dim() == 0:
            azimuths = azimuths.reshape(1)
    elif np.isscalar(azimuths):
        azimuths = [azimuths]
    azimuths = list(azimuths)

    # (elevation, base azimuth) of the six views, in tiling order.
    views = ((20, 30), (-10, 90), (20, 150), (-10, 210), (20, 270), (-10, 330))

    cam_poses = []
    for azimuth in azimuths:
        if isinstance(azimuth, torch.Tensor):
            azimuth = azimuth.item()
        cam_poses.append(
            [
                orbit_camera(elevation, base + azimuth, radius=radius)
                for elevation, base in views
            ]
        )
    cam_poses = (
        torch.from_numpy(np.stack(cam_poses)).reshape(-1, 4, 4).float().to(device)
    )

    rays_embeddings = []
    for pose in cam_poses:
        rays_o, rays_d = get_rays(pose, resolution, resolution, fov)  # [h, w, 3]
        rays_plucker = torch.cat(
            [torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1
        )  # [h, w, 6]
        rays_embeddings.append(rays_plucker)

    rays_embeddings = (
        torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()
    )  # [V, 6, h, w]
    rays_embeddings = rearrange(
        rays_embeddings,
        "(B x y) c h w -> B c (x h) (y w)",
        B=len(azimuths),
        x=3,
        y=2,
    )  # (B, C, 3H, 2W)
    return rays_embeddings


class RandomCutout(object):
    """Randomly zero out one rectangle in each image of a batch.

    The rectangle is one third of the height/width given at construction
    time, and each image is cut independently with probability
    ``cutout_prob``.
    """

    def __init__(self, height, width, cutout_prob=0.4):
        self.height = height
        self.width = width
        self.cutout_prob = cutout_prob
        self.mask_height = height // 3  # Cutout size 3 times smaller
        self.mask_width = width // 3

    def __call__(self, img, is_train=True):
        """Apply the cutout to ``img`` of shape ``[batch, channels, H, W]``.

        Returns ``img`` untouched when ``is_train`` is False, otherwise a new
        tensor in which the selected rectangles are zero in every channel.
        """
        if not is_train:
            return img

        batch, channels, height, width = img.shape

        # One uniform draw per image decides whether that image gets a cutout.
        random_seeds = torch.rand(batch, device=img.device)

        # Masks start as ones (no cutout effect).
        masks = torch.ones((batch, height, width), device=img.device, dtype=img.dtype)

        # Top-left corner of each rectangle, kept a tenth of the image size
        # away from every border.
        top = torch.randint(
            low=height // 10,
            high=height - self.mask_height - height // 10,
            size=(batch,),
            device=img.device,
        )
        left = torch.randint(
            low=width // 10,
            high=width - self.mask_width - width // 10,
            size=(batch,),
            device=img.device,
        )

        # Row / column index grids, each of shape [H, W].
        hh, ww = torch.meshgrid(
            torch.arange(0, height, device=img.device),
            torch.arange(0, width, device=img.device),
            indexing="ij",
        )

        # [batch, 1, 1] so the comparisons broadcast against the [H, W] grids.
        top = top[:, None, None]
        left = left[:, None, None]

        cutout_masks = (
            (hh >= top)
            & (hh < top + self.mask_height)
            & (ww >= left)
            & (ww < left + self.mask_width)
        )

        # Only images whose draw falls below the probability are cut.
        masks[cutout_masks & (random_seeds[:, None, None] < self.cutout_prob)] = 0

        return img * masks[:, None]


def remove_background(
    image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    """Strip the background with ``rembg`` unless the image already has one removed.

    An RGBA image whose minimum alpha is below 255 is treated as already
    matted and returned unchanged, unless ``force`` is True.

    Args:
        image: input image.
        rembg_session: forwarded to ``rembg.remove`` as ``session``.
        force: run ``rembg`` even when the image already has transparency.
        **rembg_kwargs: forwarded to ``rembg.remove``.

    Returns:
        The processed (or original) image.
    """
    already_matted = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_matted:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image


def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    """Crop an RGBA image to its foreground and re-pad it to a square canvas.

    The foreground is the bounding box of all pixels with alpha > 0. It is
    centered on a transparent square and then padded further so that the
    foreground side is roughly ``ratio`` of the final side.

    Args:
        image: RGBA image (last axis of the array must have 4 channels).
        ratio: foreground side divided by output side; values above 1 would
            need negative padding and are rejected by ``np.pad``.

    Returns:
        The padded square image.

    Raises:
        ValueError: if no pixel has alpha > 0.
    """
    image = np.array(image)
    assert image.shape[-1] == 4

    rows, cols = np.where(image[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("resize_foreground: image has no pixel with alpha > 0")
    y1, y2, x1, x2 = rows.min(), rows.max(), cols.min(), cols.max()

    # y2 / x2 are inclusive indices, slice ends are exclusive: +1 keeps the
    # last foreground row and column inside the crop.
    fg = image[y1 : y2 + 1, x1 : x2 + 1]

    # Pad the shorter side so the crop becomes square.
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )

    # Pad all four sides so the foreground takes up `ratio` of the canvas.
    new_size = int(new_image.shape[0] / ratio)
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    return PIL.Image.fromarray(new_image)


def images_to_video(
    images: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    """Write a batch of images to a video file with ``imageio.mimwrite``.

    Args:
        images: tensor of shape ``(N, C, H, W)``. Values are multiplied by 255
            and cast to ``uint8``, so they are presumably expected in
            ``[0, 1]``; anything outside that range wraps around in the cast.
        output_path: destination file; its parent directory is created when
            the path contains one.
        fps: frames per second.
    """
    # os.makedirs("") raises, so only create a directory when the path has one.
    video_dir = os.path.dirname(output_path)
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)

    frames = []
    for image in images:
        frame = (image.detach().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
        # NOTE(review): after the uint8 cast this range check cannot fail; it
        # is kept only to preserve existing behavior.
        assert frame.min() >= 0 and frame.max() <= 255, \
            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
        frames.append(frame)
    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)


def save_video(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    """Write frames to a video file one by one through an imageio writer.

    Args:
        frames: tensor of shape ``(N, C, H, W)``; each frame is multiplied by
            255 and cast to ``uint8``.
        output_path: destination file.
        fps: frames per second.
    """
    # Convert everything first so that a bad frame fails before a file is opened.
    frames = [
        (frame.detach().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        for frame in frames
    ]
    writer = imageio.get_writer(output_path, fps=fps)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Always release the file handle / encoder, even if a write fails.
        writer.close()