"""Camera helpers: Gaussian-splatting style cameras, epipolar masks and ray maps."""

import math

import numpy as np
import torch
from kiui.cam import orbit_camera
from kornia.core import Tensor, concatenate
from torch import nn


# gaussian splatting utils.graphics_utils
def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
    """Build a 4x4 world-to-view matrix, optionally shifting/scaling the camera center.

    Args:
        R: 3x3 NumPy array; its transpose is written into the rotation block.
        t: Length-3 translation written into the last column.
        translate: Offset added to the camera center before scaling.
        scale: Factor applied to the shifted camera center.

    Returns:
        The resulting 4x4 matrix converted with ``np.float32``.
    """
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    # Adjust the camera center in the inverse (camera-to-world) matrix,
    # then invert back to obtain the world-to-view matrix.
    C2W = np.linalg.inv(Rt)
    C2W[:3, 3] = (C2W[:3, 3] + translate) * scale
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)


def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Return a 4x4 perspective projection matrix as a ``torch`` tensor.

    Args:
        znear: Near clipping distance.
        zfar: Far clipping distance.
        fovX: Horizontal field of view in radians (passed to ``math.tan``).
        fovY: Vertical field of view in radians (passed to ``math.tan``).

    Returns:
        A ``torch.zeros(4, 4)``-based tensor filled with the projection terms.
    """
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P


def fov2focal(fov, pixels):
    """Convert a field of view (radians) over ``pixels`` into a focal length."""
    return pixels / (2 * math.tan(fov / 2))


def focal2fov(focal, pixels):
    """Convert a focal length over ``pixels`` into a field of view (radians)."""
    return 2 * math.atan(pixels / (2 * focal))


# gaussian splatting scene.camera
class Camera(nn.Module):
    """Camera holding the view/projection transforms derived from a pose and FoV.

    Attributes set in ``__init__``:
        R, T, FoVx, FoVy, trans, scale: the constructor arguments, stored as given.
        znear, zfar: fixed clipping distances (0.01 and 100.0).
        world_view_transform: transposed output of ``getWorld2View2``.
        projection_matrix: transposed output of ``getProjectionMatrix``.
        full_proj_transform: ``world_view_transform @ projection_matrix``.
        camera_center: first three entries of the last row of the inverse of
            ``world_view_transform``.
    """

    def __init__(self, R, T, FoVx, FoVy,
                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0):
        super().__init__()

        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        # Both matrices are stored transposed, so they compose as
        # world_view_transform @ projection_matrix below.
        self.world_view_transform = torch.tensor(
            getWorld2View2(R, T, trans, scale)).transpose(0, 1)
        self.projection_matrix = getProjectionMatrix(
            znear=self.znear, zfar=self.zfar,
            fovX=self.FoVx, fovY=self.FoVy).transpose(0, 1)
        self.full_proj_transform = (
            self.world_view_transform.unsqueeze(0).bmm(
                self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]


# gaussian splatting utils.camera_utils
def loadCam(c2w, fovx, image_height=512, image_width=512):
    """Create a :class:`Camera` from a camera-to-world matrix and horizontal FoV.

    Args:
        c2w: 4x4 NumPy camera-to-world matrix (inverted with ``np.linalg.inv``).
        fovx: Horizontal field of view in radians.
        image_height: Image height used to derive the vertical FoV.
        image_width: Image width used to derive the focal length from ``fovx``.

    Returns:
        A ``Camera`` built from the inverted pose and the two fields of view.
    """
    # load_camera
    w2c = np.linalg.inv(c2w)
    R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
    T = w2c[:3, 3]

    # Same focal length on both axes; the vertical FoV follows from the height.
    fovy = focal2fov(fov2focal(fovx, image_width), image_height)
    return Camera(R=R, T=T, FoVx=fovx, FoVy=fovy)


# epipolar calculation related
@torch.no_grad()
def fundamental_from_projections(P1: Tensor, P2: Tensor) -> Tensor:
    r"""Get the Fundamental matrix from Projection matrices.

    Args:
        P1: The projection matrix from first camera with shape :math:`(*, 3, 4)`.
        P2: The projection matrix from second camera with shape :math:`(*, 3, 4)`.

    Returns:
         The fundamental matrix with shape :math:`(*, 3, 3)`.

    Raises:
        AssertionError: If either input is not ``(*, 3, 4)`` or the leading
            (batch) shapes differ.
    """
    if not (len(P1.shape) >= 2 and P1.shape[-2:] == (3, 4)):
        raise AssertionError(P1.shape)
    if not (len(P2.shape) >= 2 and P2.shape[-2:] == (3, 4)):
        raise AssertionError(P2.shape)
    if P1.shape[:-2] != P2.shape[:-2]:
        raise AssertionError

    def vstack(x: Tensor, y: Tensor) -> Tensor:
        return concatenate([x, y], dim=-2)

    # Two-row selections of each projection matrix; every entry of F is the
    # determinant of one 4x4 stack of a P1 selection over a P2 selection.
    X1 = P1[..., 1:, :]
    X2 = vstack(P1[..., 2:3, :], P1[..., 0:1, :])
    X3 = P1[..., :2, :]

    Y1 = P2[..., 1:, :]
    Y2 = vstack(P2[..., 2:3, :], P2[..., 0:1, :])
    Y3 = P2[..., :2, :]

    X1Y1, X2Y1, X3Y1 = vstack(X1, Y1), vstack(X2, Y1), vstack(X3, Y1)
    X1Y2, X2Y2, X3Y2 = vstack(X1, Y2), vstack(X2, Y2), vstack(X3, Y2)
    X1Y3, X2Y3, X3Y3 = vstack(X1, Y3), vstack(X2, Y3), vstack(X3, Y3)

    F_vec = torch.cat(
        [
            X1Y1.det().reshape(-1, 1),
            X2Y1.det().reshape(-1, 1),
            X3Y1.det().reshape(-1, 1),
            X1Y2.det().reshape(-1, 1),
            X2Y2.det().reshape(-1, 1),
            X3Y2.det().reshape(-1, 1),
            X1Y3.det().reshape(-1, 1),
            X2Y3.det().reshape(-1, 1),
            X3Y3.det().reshape(-1, 1),
        ],
        dim=1,
    )

    return F_vec.view(*P1.shape[:-2], 3, 3)


def get_fundamental_matrix_with_H(cam1, cam2, current_H, current_W):
    """Compute the fundamental matrix between two cameras in pixel coordinates.

    Args:
        cam1: Object exposing a 4x4 ``full_proj_transform`` tensor (e.g. ``Camera``).
        cam2: Second camera, same requirements as ``cam1``.
        current_H: Height of the pixel grid.
        current_W: Width of the pixel grid.

    Returns:
        The 3x3 float32 tensor returned by ``fundamental_from_projections``.
    """
    # Maps NDC coordinates in [-1, 1] to pixel coordinates in [0, W] x [0, H].
    ndc_to_pixel = torch.tensor(
        [[current_W / 2, 0, current_W / 2],
         [0, current_H / 2, current_H / 2],
         [0, 0, 1]],
        dtype=torch.float32,
    )

    # Keep the x, y and w columns of the (transposed) full projection and
    # transpose back, giving a 3x4 projection into NDC.
    proj_1 = cam1.full_proj_transform[:, [0, 1, 3]].T.float()
    proj_2 = cam2.full_proj_transform[:, [0, 1, 3]].T.float()

    return fundamental_from_projections(ndc_to_pixel @ proj_1,
                                        ndc_to_pixel @ proj_2)


def point_to_line_dist(points, lines):
    """
    Calculate the distance from points to lines in 2D.
    points: Nx3 homogeneous points
    lines: Mx3 line coefficients (a, b, c)

    return distance: MxN, where entry [m, n] is the distance from point n to line m
    """
    numerator = torch.abs(lines @ points.T)
    denominator = torch.linalg.norm(lines[:, :2], dim=1, keepdim=True)
    return numerator / denominator


def compute_epipolar_constrains(cam1, cam2, current_H=64, current_W=64):
    """Build a boolean mask of pixel pairs that violate the epipolar constraint.

    For every pixel of a ``current_H`` x ``current_W`` grid, a line is obtained
    by multiplying its homogeneous coordinates with the fundamental matrix of
    ``(cam1, cam2)``; the mask marks grid pixels farther than one pixel from
    that line.

    Args:
        cam1: First camera (see ``get_fundamental_matrix_with_H``).
        cam2: Second camera.
        current_H: Height of the pixel grid.
        current_W: Width of the pixel grid.

    Returns:
        Bool tensor of shape ``(H*W, H*W)``; row ``m`` belongs to the line of
        pixel ``m``, column ``n`` to grid pixel ``n``. ``False`` means the
        pixel lies within one pixel of the line.
    """
    # A leading batch dimension of size 1 is kept so the product below stays a
    # batched matmul.
    fundamental = get_fundamental_matrix_with_H(
        cam1, cam2, current_H, current_W).unsqueeze(0)

    # Integer pixel coordinates, created as float32 so they can be stacked
    # with the homogeneous ones and multiplied by the float32 matrix.
    xs = torch.arange(current_W, dtype=torch.float32)
    ys = torch.arange(current_H, dtype=torch.float32)
    xs, ys = torch.meshgrid(xs, ys, indexing='xy')
    xs = xs.reshape(-1)
    ys = ys.reshape(-1)
    pixels = torch.stack([xs, ys, torch.ones_like(xs)], dim=1)  # [H*W, 3]

    # epipolar_line: seq_len, 3
    lines = (pixels.unsqueeze(0) @ fundamental).view(-1, 3)

    distance = point_to_line_dist(pixels, lines)
    return distance > 1  # sequence_length x sequence_length


def compute_camera_distance(cams, key_cams):
    """Return pairwise distances between the centers of two camera lists.

    Args:
        cams: Cameras exposing a ``camera_center`` tensor.
        key_cams: Cameras exposing a ``camera_center`` tensor.

    Returns:
        ``torch.cdist`` result of shape ``(len(cams), len(key_cams))``.
    """
    cam_centers = torch.stack([cam.camera_center for cam in cams])
    key_cam_centers = torch.stack([cam.camera_center for cam in key_cams])
    return torch.cdist(cam_centers, key_cam_centers)


def get_intri(target_im=None, h=None, w=None, normalize=False):
    """Return a 3x3 intrinsics matrix for an ``h`` x ``w`` image.

    Args:
        target_im: Optional array; when given, ``h, w`` are read from its
            first two dimensions.
        h: Image height (required when ``target_im`` is ``None``).
        w: Image width (required when ``target_im`` is ``None``).
        normalize: If true, divide the first two rows by ``h``.

    Returns:
        3x3 NumPy array ``[[f, 0, w/2], [0, f, h/2], [0, 0, 1]]`` where ``f``
        is the reference focal length scaled by ``h / 1024``.
    """
    if target_im is None:
        assert (h is not None and w is not None)
    else:
        h, w = target_im.shape[:2]

    # Reference focal length, defined for a 1024-pixel image.
    fx = fy = 1422.222
    res_raw = 1024
    f_x = f_y = fx * h / res_raw

    K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)

    if normalize:  # center is [0.5, 0.5], eg3d renderer tradition
        K[:2] /= h

    return K


def normalize_camera(c, c_frame0):
    """Re-express camera poses relative to a canonical first frame.

    Args:
        c: NumPy array of shape ``(B, >=16)``; the first 16 columns of each
            row are a flattened 4x4 pose.
        c_frame0: NumPy array with a single row in the same layout, used as
            the canonical pose.

    Returns:
        Array shaped like ``c`` with the first 16 columns replaced by the
        transformed poses; remaining columns are copied unchanged.
    """
    B = c.shape[0]
    camera_poses = c[:, :16].reshape(B, 4, 4)  # 3x4
    canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4)
    inverse_canonical_pose = np.linalg.inv(canonical_camera_poses)
    inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0)

    cam_radius = np.linalg.norm(
        c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3],
        axis=-1,
        keepdims=False)  # since g-buffer adopts dynamic radius here.

    # Identity pose with z-translation set to -radius.
    frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0)
    frame1_fixed_pos[:, 2, -1] = -cam_radius

    transform = frame1_fixed_pos @ inverse_canonical_pose

    new_camera_poses = np.repeat(
        transform, 1, axis=0
    ) @ camera_poses  # [v, 4, 4]. np.repeat() is th.repeat_interleave()

    c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
                       axis=-1)

    return c


def gen_rays(c2w, intrinsics, h, w):
    """Generate per-pixel ray origins and unit directions for one camera.

    Args:
        c2w: 4x4 camera-to-world ``torch`` tensor.
        intrinsics: Flattened 3x3 intrinsics; indices 0, 4, 2 and 5 are read
            as fx, fy, cx and cy, in the normalized 0-1 pixel range.
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        Tuple ``(origins, dirs)``, each of shape ``(h, w, 3)``.
    """
    # Generate rays through pixel centers
    yy, xx = torch.meshgrid(
        torch.arange(h, dtype=torch.float32) + 0.5,
        torch.arange(w, dtype=torch.float32) + 0.5,
        indexing='ij')

    # normalize to 0-1 pixel range
    yy = yy / h
    xx = xx / w

    cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[0], intrinsics[4]

    xx = (xx - cx) / fx
    yy = (yy - cy) / fy
    zz = torch.ones_like(xx)
    dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
    dirs /= torch.norm(dirs, dim=-1, keepdim=True)
    dirs = dirs.reshape(-1, 3, 1)
    del xx, yy, zz

    # Rotate the directions into world space; all rays share the camera origin.
    dirs = (c2w[None, :3, :3] @ dirs)[..., 0]

    origins = c2w[None, :3, 3].expand(h * w, -1).contiguous()
    origins = origins.view(h, w, 3)
    dirs = dirs.view(h, w, 3)

    return origins, dirs


def get_c2ws(elevations, amuziths, camera_radius=1.5):
    """Build flattened camera-to-world matrices for orbiting viewpoints.

    Args:
        elevations: Iterable of elevation values passed to ``orbit_camera``.
        amuziths: Iterable of azimuth values passed to ``orbit_camera``
            (paired with ``elevations``).
        camera_radius: Orbit radius passed to ``orbit_camera``.

    Returns:
        NumPy array of shape ``(V, 16)``, one flattened 4x4 pose per view.
    """
    c2ws = np.stack([
        orbit_camera(elevation, amuzith, radius=camera_radius)
        for elevation, amuzith in zip(elevations, amuziths)
    ], axis=0)

    # change kiui opengl camera system to our camera system
    c2ws[:, :3, 1:3] *= -1
    c2ws[:, [0, 1, 2], :] = c2ws[:, [2, 0, 1], :]
    c2ws = c2ws.reshape(-1, 16)

    return c2ws


def get_camera_poses(c2ws, fov, h, w, intrinsics=None):
    """Compute Plucker ray maps, epipolar masks and neighbor camera distances.

    Each view is paired with its previous and next view (cyclically).

    Args:
        c2ws: NumPy array of shape ``(V, 16)`` with flattened 4x4
            camera-to-world matrices (e.g. from ``get_c2ws``).
        fov: Horizontal field of view in radians, forwarded to ``loadCam``.
        h: Image height in pixels.
        w: Image width in pixels.
        intrinsics: Optional flattened 3x3 normalized intrinsics for
            ``gen_rays``; defaults to ``get_intri(h=64, w=64, normalize=True)``.

    Returns:
        Tuple ``(rays_pluckers, epipolar_constrains, cam_distances)``:
            rays_pluckers: ``[V, 6, h, w]`` tensor of ``(o x d, d)`` per pixel.
            epipolar_constrains: ``[V, 2, P, P]`` bool tensor with
                ``P = (h // 16) * (w // 16)``.
            cam_distances: ``[V, 2]`` distances to the previous/next view.
    """
    if intrinsics is None:
        intrinsics = get_intri(h=64, w=64, normalize=True).reshape(9)

    # Express every pose relative to the first view.
    c2ws = normalize_camera(c2ws, c2ws[0:1])
    c2ws = torch.from_numpy(c2ws.reshape((-1, 4, 4))).float()

    gs_cams = []
    rays_pluckers = []
    for c2w in c2ws:
        gs_cams.append(loadCam(c2w.numpy(), fov, h, w))
        rays_o, rays_d = gen_rays(c2w, intrinsics, h, w)
        rays_plucker = torch.cat(
            [torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
        rays_pluckers.append(rays_plucker.permute(2, 0, 1))  # [6, h, w]

    n_views = len(gs_cams)
    epipolar_constrains = []
    cam_distances = []
    for i in range(n_views):
        kv_idxs = [(i - 1) % n_views, (i + 1) % n_views]

        # False means that the position is on the epipolar line
        cur_epipolar_constrains = [
            compute_epipolar_constrains(
                gs_cams[kv_idx], gs_cams[i],
                current_H=h // 16, current_W=w // 16)
            for kv_idx in kv_idxs
        ]
        epipolar_constrains.append(torch.stack(cur_epipolar_constrains, dim=0))

        cam_distances.append(compute_camera_distance(
            [gs_cams[i]],
            [gs_cams[kv_idxs[0]], gs_cams[kv_idxs[1]]]))  # 1, 2

    rays_pluckers = torch.stack(rays_pluckers)  # [v, 6, h, w]
    cam_distances = torch.cat(cam_distances, dim=0)  # [v, 2]
    # [v, 2, P, P] with P = (h // 16) * (w // 16), e.g. 1024 for h = w = 512
    epipolar_constrains = torch.stack(epipolar_constrains, dim=0)

    return rays_pluckers, epipolar_constrains, cam_distances