# Feat2GS / utils/graphics_utils.py
# Imported from upstream commit 123719b ("init") by faneggg.
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import math
import numpy as np
from typing import NamedTuple
import torch.nn.functional as F
from torch import Tensor
class BasicPointCloud(NamedTuple):
    """Immutable bundle of the per-point arrays that describe a point cloud.

    All four fields hold NumPy arrays. Their exact shapes are decided by the
    code that builds the cloud (not visible here); presumably one row per
    point in every field -- TODO confirm against callers.
    """

    # NOTE: the annotations were previously ``np.array``, which is a factory
    # function, not a type; ``np.ndarray`` is the actual array type.
    points: np.ndarray  # point positions (shape set by caller)
    colors: np.ndarray  # per-point colors (shape/range set by caller)
    normals: np.ndarray  # per-point normals (shape set by caller)
    features: np.ndarray  # per-point feature vectors (width set by caller)
def geom_transform_points(points, transf_matrix):
    """Apply a 4x4 homogeneous transform to a set of 3D points.

    The points are lifted to homogeneous coordinates (w = 1), multiplied as
    row vectors (``points_h @ transf_matrix``), and projected back by dividing
    by the resulting w component. A tiny epsilon (1e-7) is added to w to
    avoid division by zero.

    Args:
        points: tensor of shape [P, 3].
        transf_matrix: 4x4 tensor applied on the right of the row vectors.

    Returns:
        Tensor of shape [P, 3] with the transformed, w-normalised points.
    """
    num_points, _ = points.shape  # also enforces a 2-D input
    homogeneous = torch.cat(
        [points, points.new_ones((num_points, 1))], dim=1
    )  # [P, 4]
    transformed = homogeneous @ transf_matrix[None]  # [1, P, 4]
    w = transformed[..., 3:] + 0.0000001
    xyz = transformed[..., :3] / w
    return xyz.squeeze(dim=0)
def getWorld2View(R, t):
    """Assemble the 4x4 matrix ``[[R^T, t], [0, 0, 0, 1]]`` as float32.

    Args:
        R: 3x3 NumPy array; its transpose fills the upper-left block.
        t: length-3 vector written into the last column.

    Returns:
        A 4x4 ``np.float32`` array.
    """
    view = np.eye(4)  # float64 scratch; bottom row is already [0, 0, 0, 1]
    view[:3, :3] = R.T
    view[:3, 3] = t
    return view.astype(np.float32)
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    """Build a 4x4 world-to-view matrix after shifting and scaling the camera centre.

    The matrix ``[[R^T, t], [0, 0, 0, 1]]`` is inverted, and its translation
    column (the camera centre) is replaced with ``(centre + translate) * scale``.
    The result is inverted back.

    Args:
        R: 3x3 NumPy array; its transpose fills the upper-left block.
        t: length-3 translation vector.
        translate: offset added to the camera centre before scaling. The
            default array is shared across calls but is never mutated here.
        scale: multiplier applied to the shifted camera centre.

    Returns:
        A 4x4 ``np.float32`` array.
    """
    world_to_cam = np.eye(4)
    world_to_cam[:3, :3] = R.T
    world_to_cam[:3, 3] = t

    cam_to_world = np.linalg.inv(world_to_cam)
    # Right-hand side is a fresh array, so the in-place write below is safe.
    cam_to_world[:3, 3] = (cam_to_world[:3, 3] + translate) * scale

    return np.float32(np.linalg.inv(cam_to_world))
def getWorld2View2_torch(R, t, translate=torch.tensor([0.0, 0.0, 0.0]), scale=1.0):
    """Torch version of the shifted/scaled world-to-view matrix builder.

    Builds ``[[R^T, t], [0, 0, 0, 1]]`` and inverts it. The camera centre
    (the translation column of the inverse) is replaced with
    ``(centre + translate) * scale``, and the result is inverted back.

    Args:
        R: 3x3 tensor; its transpose fills the upper-left block.
        t: length-3 tensor written into the last column.
        translate: offset added to the camera centre before scaling. May be a
            tensor, NumPy array or sequence; it is converted to float32 on the
            device of the output matrix.
        scale: multiplier applied to the shifted camera centre.

    Returns:
        A 4x4 float32 tensor, allocated on torch's default device.
    """
    Rt = torch.zeros((4, 4), dtype=torch.float32)
    # torch.as_tensor instead of torch.tensor: torch.tensor(<tensor>) emits a
    # copy-construct UserWarning on every call (the default argument is a
    # tensor). Pinning the device keeps the addition below on one device.
    translate = torch.as_tensor(translate, dtype=torch.float32, device=Rt.device)

    Rt[:3, :3] = R.t()  # Transpose of R
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    # Invert to camera-to-world, move/scale the camera centre, invert back.
    C2W = torch.linalg.inv(Rt)
    cam_center = (C2W[:3, 3] + translate) * scale  # new tensor, default untouched
    C2W[:3, 3] = cam_center
    return torch.linalg.inv(C2W)
def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Build a 4x4 perspective projection matrix for a symmetric frustum.

    The near-plane half-extents are ``tan(fov / 2) * znear``. Row 3 copies z
    into w. With the row-2 terms ``zfar / (zfar - znear)`` and
    ``-zfar * znear / (zfar - znear)``, a view-space depth of ``znear`` maps
    to 0 and ``zfar`` maps to 1 after the perspective divide.

    Args:
        znear, zfar: near and far clipping distances.
        fovX, fovY: horizontal and vertical field of view, in radians.

    Returns:
        A 4x4 tensor of torch's default dtype; all other entries are zero.
    """
    top = math.tan(fovY / 2) * znear
    right = math.tan(fovX / 2) * znear
    bottom = -top
    left = -right
    z_sign = 1.0
    depth_span = zfar - znear

    # (row, col) -> value for every entry the matrix sets explicitly.
    entries = {
        (0, 0): 2.0 * znear / (right - left),
        (1, 1): 2.0 * znear / (top - bottom),
        # Off-centre terms; both evaluate to 0 for this symmetric frustum.
        (0, 2): (right + left) / (right - left),
        (1, 2): (top + bottom) / (top - bottom),
        (3, 2): z_sign,
        (2, 2): z_sign * zfar / depth_span,
        (2, 3): -(zfar * znear) / depth_span,
    }
    proj = torch.zeros(4, 4)
    for (row, col), value in entries.items():
        proj[row, col] = value
    return proj
def fov2focal(fov, pixels):
    """Convert a field of view (radians) spanning ``pixels`` into a focal length in pixels.

    Pinhole relation: ``focal = pixels / (2 * tan(fov / 2))``.
    """
    half_fov_tan = math.tan(fov / 2)
    return pixels / (2 * half_fov_tan)
def focal2fov(focal, pixels):
    """Convert a focal length in pixels into the field of view (radians) spanning ``pixels``.

    Inverse of the pinhole relation: ``fov = 2 * atan(pixels / (2 * focal))``.
    """
    half_extent_over_focal = pixels / (2 * focal)
    return 2 * math.atan(half_extent_over_focal)
def resize_render(view, size=None):
    """Reshape a camera/view, in place, to a square image while keeping its focal lengths.

    The target side length is ``size``, or ``max(image_width, image_height)``
    when ``size`` is None. ``view.original_image`` is replaced by an all-zero
    ``(3, S, S)`` tensor on the same device as the old image. The focal
    lengths implied by the old FoV/size are kept, so ``FoVx``/``FoVy`` are
    recomputed for the new side length.

    Args:
        view: object exposing ``original_image``, ``image_width``,
            ``image_height``, ``FoVx`` and ``FoVy``; it is mutated.
        size: optional square side length in pixels.

    Returns:
        The same ``view`` object.
    """
    side = max(view.image_width, view.image_height) if size is None else size
    view.original_image = torch.zeros(
        (3, side, side), device=view.original_image.device
    )

    # Focal lengths must be taken from the *old* FoV and image size.
    fx = fov2focal(view.FoVx, view.image_width)
    fy = fov2focal(view.FoVy, view.image_height)

    view.image_width = view.image_height = side
    view.FoVx = focal2fov(fx, side)
    view.FoVy = focal2fov(fy, side)
    return view
def make_video_divisble(
video: torch.Tensor | np.ndarray, block_size=16
) -> torch.Tensor | np.ndarray:
H, W = video.shape[1:3]
H_new = H - H % block_size
W_new = W - W % block_size
return video[:, :H_new, :W_new]
def depth_to_points(
    depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
) -> Tensor:
    """Back-project depth maps into world-space 3D points.

    Each pixel (x, y) is sampled at its centre and turned into a camera-space
    direction ``((x - cx + 0.5) / fx, (y - cy + 0.5) / fy, 1)``. That
    direction is multiplied by the upper-left 3x3 block of ``camtoworlds``,
    and the point is placed at ``origin + depth * direction``, where
    ``origin`` is the last column of ``camtoworlds``.

    Args:
        depths: Depth maps [..., H, W, 1].
        camtoworlds: Camera-to-world transformation matrices [..., 4, 4].
        Ks: Camera intrinsics [..., 3, 3]; fx, fy, cx, cy are read from
            entries (0, 0), (1, 1), (0, 2) and (1, 2).
        z_depth: If True, the direction is used as-is (its camera-space z is
            1), so ``depths`` acts as z-depth. If False, the direction is
            normalised first, so ``depths`` is the distance along the ray.

    Returns:
        points: 3D points in the world coordinate system [..., H, W, 3].

    Raises:
        AssertionError: on malformed shapes or mismatched batch dimensions
            (these checks are skipped under ``python -O``).
    """
    assert depths.shape[-1] == 1, f"Invalid depth shape: {depths.shape}"
    assert camtoworlds.shape[-2:] == (4, 4), f"Invalid viewmats shape: {camtoworlds.shape}"
    assert Ks.shape[-2:] == (3, 3), f"Invalid Ks shape: {Ks.shape}"
    assert (
        depths.shape[:-3] == camtoworlds.shape[:-2] == Ks.shape[:-2]
    ), f"Shape mismatch! depths: {depths.shape}, viewmats: {camtoworlds.shape}, Ks: {Ks.shape}"

    height, width = depths.shape[-3:-1]
    device = depths.device
    # With indexing="xy" both grids are [H, W]: px holds column indices,
    # py holds row indices.
    px, py = torch.meshgrid(
        torch.arange(width, device=device),
        torch.arange(height, device=device),
        indexing="xy",
    )

    # Intrinsics with two trailing singleton axes so they broadcast over [H, W].
    fx = Ks[..., 0, 0][..., None, None]
    fy = Ks[..., 1, 1][..., None, None]
    cx = Ks[..., 0, 2][..., None, None]
    cy = Ks[..., 1, 2][..., None, None]

    u = (px - cx + 0.5) / fx  # [..., H, W]
    v = (py - cy + 0.5) / fy  # [..., H, W]
    # Camera-space directions with z fixed to 1: [..., H, W, 3].
    camera_dirs = torch.stack([u, v, torch.ones_like(u)], dim=-1)

    rotation = camtoworlds[..., :3, :3]
    directions = torch.einsum("...ij,...hwj->...hwi", rotation, camera_dirs)
    if not z_depth:
        directions = F.normalize(directions, dim=-1)

    origins = camtoworlds[..., :3, -1][..., None, None, :]  # [..., 1, 1, 3]
    return origins + depths * directions
def depth_to_normal(
    depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
) -> Tensor:
    """Convert depth maps to surface normals via central differences.

    The depth maps are back-projected with ``depth_to_points``. At every
    interior pixel, the normal is the normalised cross product of the
    central difference along the image rows (H axis) with the central
    difference along the image columns (W axis). The one-pixel image border,
    where a central difference is undefined, is set to zero.

    Args:
        depths: Depth maps [..., H, W, 1]
        camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
        Ks: Camera intrinsics [..., 3, 3]
        z_depth: Whether the depth is in z-depth (True) or ray depth (False)

    Returns:
        normals: Surface normals in the world coordinate system [..., H, W, 3]
            (always the same spatial size as ``depths``, even when H or W < 3).
    """
    points = depth_to_points(depths, camtoworlds, Ks, z_depth=z_depth)  # [..., H, W, 3]

    # Central differences at interior pixels, each [..., H-2, W-2, 3].
    d_rows = points[..., 2:, 1:-1, :] - points[..., :-2, 1:-1, :]
    d_cols = points[..., 1:-1, 2:, :] - points[..., 1:-1, :-2, :]
    interior = F.normalize(torch.cross(d_rows, d_cols, dim=-1), dim=-1)

    # Write the interior into a zero buffer instead of padding it. Padding by
    # one pixel per side yields 2 rows/cols when H or W is 1, whereas this
    # always matches the input's spatial shape. Same dtype/device as points.
    normals = torch.zeros_like(points)
    normals[..., 1:-1, 1:-1, :] = interior
    return normals