import os
import random
from typing import Any

import numpy as np
import PIL.Image
from PIL import Image
import imageio
import rembg
import torch
import torch.nn.functional as F
from einops import rearrange
from kiui.cam import orbit_camera
from kiui.op import safe_normalize

def get_rays(pose, h, w, fovy, opengl=True):
    """Compute per-pixel ray origins and directions for a pinhole camera.

    pose: [4, 4] camera-to-world matrix. Returns rays_o, rays_d of shape [h, w, 3].
    """
    x, y = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    x = x.flatten()
    y = y.flatten()

    # Principal point at the image center; focal length from the vertical FoV.
    cx = w * 0.5
    cy = h * 0.5
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # Ray directions in camera space; the OpenGL convention flips y and looks down -z.
    camera_dirs = F.pad(
        torch.stack(
            [
                (x - cx + 0.5) / focal,
                (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
            ],
            dim=-1,
        ),
        (0, 1),
        value=(-1.0 if opengl else 1.0),
    )  # [hw, 3]

    # Rotate directions into world space; origins are the camera center.
    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]

    rays_o = rays_o.view(h, w, 3)
    rays_d = safe_normalize(rays_d).view(h, w, 3)

    return rays_o, rays_d
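
# A minimal usage sketch of get_rays (the identity pose and 40x40 resolution
# are illustrative assumptions, not values used elsewhere in this module):
def _demo_get_rays():
    pose = torch.eye(4)  # camera-to-world; camera at the origin looking down -z
    rays_o, rays_d = get_rays(pose, h=40, w=40, fovy=30)
    # Origins repeat the camera center; directions are unit vectors.
    assert rays_o.shape == (40, 40, 3) and rays_d.shape == (40, 40, 3)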

def prepare_default_rays(device, azimuths=(0,)):
    """Build Plücker-ray embeddings for six canonical orbit views per azimuth."""
    fov = 30
    # Camera radius chosen so a unit-size object fills the 30-degree FoV.
    radius = 0.5 / np.tan(np.radians(fov / 2))

    cam_poses = []
    for azimuth in azimuths:
        if isinstance(azimuth, torch.Tensor):
            azimuth = azimuth.item()
        # Six views alternating between +20 and -10 degrees elevation,
        # spaced 60 degrees apart in azimuth.
        cam_poses.append([
            orbit_camera(20, 30 + azimuth, radius=radius),
            orbit_camera(-10, 90 + azimuth, radius=radius),
            orbit_camera(20, 150 + azimuth, radius=radius),
            orbit_camera(-10, 210 + azimuth, radius=radius),
            orbit_camera(20, 270 + azimuth, radius=radius),
            orbit_camera(-10, 330 + azimuth, radius=radius),
        ])
    cam_poses = torch.from_numpy(np.stack(cam_poses)).reshape(-1, 4, 4).float().to(device)

    rays_embeddings = []
    for i in range(cam_poses.shape[0]):
        rays_o, rays_d = get_rays(cam_poses[i], 320 // 8, 320 // 8, fov)  # [h, w, 3]
        # Plücker coordinates: (o x d, d) per pixel.
        rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
        rays_embeddings.append(rays_plucker)

    rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device)  # [V, 6, h, w]
    # Tile the six views into a 3x2 grid per batch element.
    rays_embeddings = rearrange(rays_embeddings, '(B x y) c h w -> B c (x h) (y w)', B=len(azimuths), x=3, y=2)  # (B, C, 3H, 2W)

    return rays_embeddings
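
# Hypothetical shape check for prepare_default_rays: one azimuth yields a
# (1, 6, 120, 80) embedding, since each of the six 40x40 views (320 // 8 = 40)
# is tiled into a 3x2 grid. Requires kiui's orbit_camera, imported above.
def _demo_prepare_default_rays():
    rays = prepare_default_rays(torch.device("cpu"), azimuths=[0])
    assert rays.shape == (1, 6, 120, 80)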

class RandomCutout(object):
    """Randomly zero out one rectangular region per image in a batch.

    Each side of the cutout is one third of the corresponding image side.
    """
    def __init__(self, height, width, cutout_prob=0.4):
        self.height = height
        self.width = width
        self.cutout_prob = cutout_prob
        self.mask_height = height // 3
        self.mask_width = width // 3

    def __call__(self, img, is_train=True):
        if not is_train:
            return img
        batch, channels, height, width = img.shape
        # One random draw per image decides whether it gets a cutout.
        random_seeds = torch.rand(batch, device=img.device)
        # Masks start as ones (no cutout effect).
        masks = torch.ones((batch, height, width), device=img.device, dtype=img.dtype)
        # Random top-left corners, kept away from the borders by a 10% margin.
        top = torch.randint(low=height // 10, high=height - self.mask_height - height // 10, size=(batch,), device=img.device)
        left = torch.randint(low=width // 10, high=width - self.mask_width - width // 10, size=(batch,), device=img.device)
        # Coordinate grids over height and width.
        hh, ww = torch.meshgrid(torch.arange(0, height, device=img.device), torch.arange(0, width, device=img.device), indexing='ij')
        # Expand top and left for broadcasting against the grids.
        top = top[:, None, None]
        left = left[:, None, None]
        # Boolean mask marking each image's cutout rectangle.
        cutout_masks = (hh >= top) & (hh < top + self.mask_height) & (ww >= left) & (ww < left + self.mask_width)
        # Zero the rectangle only where the random draw falls below the probability.
        masks[cutout_masks & (random_seeds[:, None, None] < self.cutout_prob)] = 0
        return img * masks[:, None]
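
# Illustrative check for RandomCutout (the 64x64 size and batch of 2 are
# assumptions): with cutout_prob=1.0, every image gets one zeroed rectangle.
def _demo_random_cutout():
    cutout = RandomCutout(height=64, width=64, cutout_prob=1.0)
    imgs = torch.ones(2, 3, 64, 64)
    out = cutout(imgs, is_train=True)
    assert out.shape == imgs.shape and out.min().item() == 0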

def remove_background(image: PIL.Image.Image,
                      rembg_session: Any = None,
                      force: bool = False,
                      **rembg_kwargs,
                      ) -> PIL.Image.Image:
    # Skip removal if the image already has a non-trivial alpha channel,
    # unless force is set.
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
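
# Hedged sketch for remove_background: an opaque RGB input always triggers
# removal. The solid-white test image is a placeholder; rembg downloads its
# default model on first use.
def _demo_remove_background():
    img = PIL.Image.new("RGB", (64, 64), (255, 255, 255))
    return remove_background(img)  # RGBA output with background removed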

def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    image = np.array(image)
    assert image.shape[-1] == 4
    # Bounding box of all pixels with non-zero alpha.
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # Crop the foreground (the max indices are inclusive, hence the +1).
    fg = image[y1:y2 + 1, x1:x2 + 1]
    # Pad to square.
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    # Compute padding so the foreground occupies `ratio` of the output side.
    new_size = int(new_image.shape[0] / ratio)
    # Pad equally on both sides.
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    return PIL.Image.fromarray(new_image)
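
# Sketch under assumed inputs for resize_foreground: a centered 32x32 opaque
# square in a 64x64 RGBA image, padded so the foreground fills ~85% of the
# output side (ratio=0.85 is an arbitrary example value).
def _demo_resize_foreground():
    arr = np.zeros((64, 64, 4), dtype=np.uint8)
    arr[16:48, 16:48] = 255  # opaque white square
    return resize_foreground(PIL.Image.fromarray(arr), ratio=0.85)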

def images_to_video(
    images: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    # images: (N, C, H, W), values in [0, 1]
    video_dir = os.path.dirname(output_path)
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)
    frames = []
    for i in range(len(images)):
        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
        assert frame.min() >= 0 and frame.max() <= 255, \
            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
        frames.append(frame)
    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
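
# Usage sketch for images_to_video (the output filename is an assumption;
# .mp4 output requires imageio's ffmpeg backend): ten random frames at 30 fps.
def _demo_images_to_video():
    frames = torch.rand(10, 3, 64, 64)  # (N, C, H, W) in [0, 1]
    images_to_video(frames, "demo_video.mp4", fps=30)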

def save_video(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    # frames: (N, C, H, W), values in [0, 1]
    frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames]
    writer = imageio.get_writer(output_path, fps=fps)
    for frame in frames:
        writer.append_data(frame)
    writer.close()