import math

import numpy as np
import torch
from torch import nn

from kiui.cam import orbit_camera
from kornia.core import Tensor, concatenate

# gaussian splatting utils.graphics_utils
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    # Build a 4x4 world-to-view matrix from rotation R and translation t,
    # optionally re-centering and re-scaling the camera position.
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)


def getProjectionMatrix(znear, zfar, fovX, fovY):
    # Perspective projection matrix as used by the 3DGS rasterizer;
    # maps view-space depth in [znear, zfar] to [0, 1].
    tanHalfFovY = math.tan(fovY / 2)
    tanHalfFovX = math.tan(fovX / 2)

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)
    z_sign = 1.0
    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P


def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))


def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))
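

# A minimal round-trip sketch (illustrative, not part of the original
# pipeline): fov2focal and focal2fov are exact inverses for a fixed pixel
# count. `_demo_fov_focal_roundtrip` is a hypothetical helper name.
def _demo_fov_focal_roundtrip():
    fov = math.radians(49.1)  # assumed FoV in radians
    focal = fov2focal(fov, 512)  # focal length in pixels for a 512-px image
    assert abs(focal2fov(focal, 512) - fov) < 1e-9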


# gaussian splatting scene.camera
class Camera(nn.Module):
    def __init__(self, R, T, FoVx, FoVy,
                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0):
        super(Camera, self).__init__()
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.zfar = 100.0
        self.znear = 0.01
        self.trans = trans
        self.scale = scale

        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1)
        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0, 1)
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]


# gaussian splatting utils.camera_utils
def loadCam(c2w, fovx, image_height=512, image_width=512):
    # load_camera
    w2c = np.linalg.inv(c2w)
    R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
    T = w2c[:3, 3]
    fovy = focal2fov(fov2focal(fovx, image_width), image_height)
    return Camera(R=R, T=T, FoVx=fovx, FoVy=fovy)
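

# Hypothetical sanity check (not in the original file): a camera loaded from
# a camera-to-world pose should report that pose's position back as its
# camera_center. Values here are illustrative.
def _demo_load_cam():
    c2w = np.eye(4, dtype=np.float32)
    c2w[:3, 3] = [0.0, 0.0, 2.0]  # camera 2 units along +z
    cam = loadCam(c2w, fovx=math.radians(49.1))
    assert np.allclose(cam.camera_center.numpy(), c2w[:3, 3], atol=1e-5)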


# epipolar calculation related
def fundamental_from_projections(P1: Tensor, P2: Tensor) -> Tensor:
    r"""Get the Fundamental matrix from Projection matrices.

    Args:
        P1: The projection matrix from the first camera with shape :math:`(*, 3, 4)`.
        P2: The projection matrix from the second camera with shape :math:`(*, 3, 4)`.

    Returns:
        The fundamental matrix with shape :math:`(*, 3, 3)`.
    """
    if not (len(P1.shape) >= 2 and P1.shape[-2:] == (3, 4)):
        raise AssertionError(P1.shape)
    if not (len(P2.shape) >= 2 and P2.shape[-2:] == (3, 4)):
        raise AssertionError(P2.shape)
    if P1.shape[:-2] != P2.shape[:-2]:
        raise AssertionError

    def vstack(x: Tensor, y: Tensor) -> Tensor:
        return concatenate([x, y], dim=-2)

    # Each entry of F is the determinant of a 4x4 matrix built by stacking
    # two rows of P1 with two rows of P2 (the classical closed-form solution).
    X1 = P1[..., 1:, :]
    X2 = vstack(P1[..., 2:3, :], P1[..., 0:1, :])
    X3 = P1[..., :2, :]

    Y1 = P2[..., 1:, :]
    Y2 = vstack(P2[..., 2:3, :], P2[..., 0:1, :])
    Y3 = P2[..., :2, :]

    X1Y1, X2Y1, X3Y1 = vstack(X1, Y1), vstack(X2, Y1), vstack(X3, Y1)
    X1Y2, X2Y2, X3Y2 = vstack(X1, Y2), vstack(X2, Y2), vstack(X3, Y2)
    X1Y3, X2Y3, X3Y3 = vstack(X1, Y3), vstack(X2, Y3), vstack(X3, Y3)

    F_vec = torch.cat(
        [
            X1Y1.det().reshape(-1, 1),
            X2Y1.det().reshape(-1, 1),
            X3Y1.det().reshape(-1, 1),
            X1Y2.det().reshape(-1, 1),
            X2Y2.det().reshape(-1, 1),
            X3Y2.det().reshape(-1, 1),
            X1Y3.det().reshape(-1, 1),
            X2Y3.det().reshape(-1, 1),
            X3Y3.det().reshape(-1, 1),
        ],
        dim=1,
    )

    return F_vec.view(*P1.shape[:-2], 3, 3)
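

# Illustrative check (hypothetical helper, assumed camera values): a
# fundamental matrix built from two projection matrices must have rank 2,
# so its determinant should vanish up to float error.
def _demo_fundamental_rank():
    P1 = torch.eye(3, 4)  # reference camera [I | 0]
    R = torch.tensor([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]])
    t = torch.tensor([[0.5], [0.0], [1.0]])
    P2 = torch.cat([R, t], dim=1)  # second camera [R | t]
    F = fundamental_from_projections(P1, P2)
    assert abs(torch.det(F).item()) < 1e-4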


def get_fundamental_matrix_with_H(cam1, cam2, current_H, current_W):
    # Map NDC coordinates to pixel coordinates at the current resolution.
    NDC_2_pixel = torch.tensor([[current_W / 2, 0, current_W / 2],
                                [0, current_H / 2, current_H / 2],
                                [0, 0, 1]]).float()
    # Drop the depth row of the 4x4 full projection to get a 3x4 projection matrix.
    cam_1_transformation = cam1.full_proj_transform[:, [0, 1, 3]].T.float()
    cam_2_transformation = cam2.full_proj_transform[:, [0, 1, 3]].T.float()
    cam_1_pixel = (NDC_2_pixel @ cam_1_transformation).float()
    cam_2_pixel = (NDC_2_pixel @ cam_2_transformation).float()
    return fundamental_from_projections(cam_1_pixel, cam_2_pixel)


def point_to_line_dist(points, lines):
    """
    Calculate the distance from points to lines in 2D.
    points: Nx3 homogeneous points
    lines: Mx3 homogeneous lines (a, b, c) with ax + by + c = 0
    return distance: MxN
    """
    numerator = torch.abs(lines @ points.T)
    denominator = torch.linalg.norm(lines[:, :2], dim=1, keepdim=True)
    return numerator / denominator
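

# Quick illustrative check (hypothetical helper): the line x = 0 has
# homogeneous coefficients (1, 0, 0), and the point (3, 4) lies 3 px from it.
def _demo_point_to_line_dist():
    points = torch.tensor([[3.0, 4.0, 1.0]])  # N=1 homogeneous point
    lines = torch.tensor([[1.0, 0.0, 0.0]])   # M=1 line
    dist = point_to_line_dist(points, lines)  # shape (M, N)
    assert torch.allclose(dist, torch.tensor([[3.0]]))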


def compute_epipolar_constrains(cam1, cam2, current_H=64, current_W=64):
    n_frames = 1
    fundamental_matrix_1 = []
    fundamental_matrix_1.append(get_fundamental_matrix_with_H(cam1, cam2, current_H, current_W))
    fundamental_matrix_1 = torch.stack(fundamental_matrix_1, dim=0)

    # Homogeneous pixel coordinates for every position in the feature grid.
    x = torch.arange(current_W)
    y = torch.arange(current_H)
    x, y = torch.meshgrid(x, y, indexing='xy')
    x = x.reshape(-1)
    y = y.reshape(-1)
    heto_cam2 = torch.stack([x, y, torch.ones(size=(len(x),))], dim=1).view(-1, 3)
    heto_cam1 = torch.stack([x, y, torch.ones(size=(len(x),))], dim=1).view(-1, 3)

    # epipolar_line: n_frames x seq_len, 3
    line1 = (heto_cam2.unsqueeze(0).repeat(n_frames, 1, 1) @ fundamental_matrix_1).view(-1, 3)
    distance1 = point_to_line_dist(heto_cam1, line1)
    # True where a pixel lies more than 1 px from the epipolar line.
    idx1_epipolar = distance1 > 1  # sequence_length x sequence_length
    return idx1_epipolar


def compute_camera_distance(cams, key_cams):
    cam_centers = [cam.camera_center for cam in cams]
    key_cam_centers = [cam.camera_center for cam in key_cams]
    cam_centers = torch.stack(cam_centers)
    key_cam_centers = torch.stack(key_cam_centers)
    cam_distance = torch.cdist(cam_centers, key_cam_centers)
    return cam_distance


def get_intri(target_im=None, h=None, w=None, normalize=False):
    if target_im is None:
        assert (h is not None and w is not None)
    else:
        h, w = target_im.shape[:2]

    fx = fy = 1422.222
    res_raw = 1024
    f_x = f_y = fx * h / res_raw  # rescale focal length from the 1024-px raw resolution
    K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
    if normalize:  # center is [0.5, 0.5], eg3d renderer tradition
        K[:2] /= h
    return K
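

# A quick illustration (assumed values, hypothetical helper name): for a
# 512 x 512 target the focal length is rescaled from the 1024-px raw
# resolution, f = 1422.222 * 512 / 1024, with the principal point centered.
def _demo_get_intri():
    K = get_intri(h=512, w=512)
    assert np.isclose(K[0, 0], 1422.222 * 512 / 1024)
    assert K[0, 2] == 256 and K[1, 2] == 256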


def normalize_camera(c, c_frame0):
    B = c.shape[0]
    camera_poses = c[:, :16].reshape(B, 4, 4)  # 4x4 c2w poses
    canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4)
    inverse_canonical_pose = np.linalg.inv(canonical_camera_poses)
    inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0)

    cam_radius = np.linalg.norm(
        c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3],
        axis=-1,
        keepdims=False)  # since g-buffer adopts dynamic radius here.
    frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0)
    frame1_fixed_pos[:, 2, -1] = -cam_radius

    transform = frame1_fixed_pos @ inverse_canonical_pose
    new_camera_poses = np.repeat(
        transform, 1, axis=0
    ) @ camera_poses  # [v, 4, 4]. np.repeat() is th.repeat_interleave()

    c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
                       axis=-1)
    return c
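

# Hypothetical sanity check (illustrative values): normalizing a pose
# against itself should place the camera on the canonical axis at distance
# cam_radius, i.e. translation (0, 0, -radius) with identity rotation.
def _demo_normalize_camera():
    c2w = np.eye(4)
    c2w[:3, 3] = [0.9, 1.0, 0.8]
    c = c2w.reshape(1, 16)
    out = normalize_camera(c, c).reshape(4, 4)
    radius = np.linalg.norm(c2w[:3, 3])
    assert np.allclose(out[:3, 3], [0.0, 0.0, -radius])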


def gen_rays(c2w, intrinsics, h, w):
    # Generate rays
    yy, xx = torch.meshgrid(
        torch.arange(h, dtype=torch.float32) + 0.5,
        torch.arange(w, dtype=torch.float32) + 0.5,
        indexing='ij')
    # normalize to 0-1 pixel range
    yy = yy / h
    xx = xx / w

    cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[0], intrinsics[4]
    xx = (xx - cx) / fx
    yy = (yy - cy) / fy
    zz = torch.ones_like(xx)

    dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
    dirs /= torch.norm(dirs, dim=-1, keepdim=True)
    dirs = dirs.reshape(-1, 3, 1)
    del xx, yy, zz

    dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
    origins = c2w[None, :3, 3].expand(h * w, -1).contiguous()
    origins = origins.view(h, w, 3)
    dirs = dirs.view(h, w, 3)
    return origins, dirs
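

# Illustrative sketch (hypothetical helper): rays cast from an identity pose
# should originate at the camera position and come out unit-length, since
# the rotation applied after normalization preserves norms.
def _demo_gen_rays():
    c2w = torch.eye(4)
    K = get_intri(h=8, w=8, normalize=True).reshape(9)  # normalized intrinsics
    origins, dirs = gen_rays(c2w, K, h=8, w=8)
    assert origins.shape == (8, 8, 3) and dirs.shape == (8, 8, 3)
    assert torch.allclose(dirs.norm(dim=-1), torch.ones(8, 8), atol=1e-5)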


def get_c2ws(elevations, azimuths, camera_radius=1.5):
    c2ws = np.stack([
        orbit_camera(elevation, azimuth, radius=camera_radius)
        for elevation, azimuth in zip(elevations, azimuths)
    ], axis=0)
    # change kiui opengl camera system to our camera system
    c2ws[:, :3, 1:3] *= -1
    c2ws[:, [0, 1, 2], :] = c2ws[:, [2, 0, 1], :]
    c2ws = c2ws.reshape(-1, 16)
    return c2ws


def get_camera_poses(c2ws, fov, h, w, intrinsics=None):
    if intrinsics is None:
        # Normalized intrinsics are resolution-independent, so a 64-px grid suffices.
        intrinsics = get_intri(h=64, w=64, normalize=True).reshape(9)
    c2ws = normalize_camera(c2ws, c2ws[0:1])
    c2ws = c2ws.reshape((-1, 4, 4))
    c2ws = torch.from_numpy(c2ws).float()

    rays_pluckers = []
    gs_cams = []
    for c2w in c2ws:
        gs_cams.append(loadCam(c2w.numpy(), fov, h, w))
        rays_o, rays_d = gen_rays(c2w, intrinsics, h, w)
        rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
                                 dim=-1)  # [h, w, 6]
        rays_pluckers.append(rays_plucker.permute(2, 0, 1))  # [6, h, w]

    n_views = len(gs_cams)
    epipolar_constrains = []
    cam_distances = []
    for i in range(n_views):
        cur_epipolar_constrains = []
        kv_idxs = [(i - 1) % n_views, (i + 1) % n_views]  # previous and next view
        for kv_idx in kv_idxs:
            # False means that the position is on the epipolar line
            cam_epipolar_constrain = compute_epipolar_constrains(
                gs_cams[kv_idx], gs_cams[i], current_H=h // 16, current_W=w // 16)
            cur_epipolar_constrains.append(cam_epipolar_constrain)
        cam_distances.append(compute_camera_distance(
            [gs_cams[i]], [gs_cams[kv_idxs[0]], gs_cams[kv_idxs[1]]]))  # 1, 2
        epipolar_constrains.append(torch.stack(cur_epipolar_constrains, dim=0))

    rays_pluckers = torch.stack(rays_pluckers)  # [v, 6, h, w]
    cam_distances = torch.cat(cam_distances, dim=0)  # [v, 2]
    epipolar_constrains = torch.stack(epipolar_constrains, dim=0)  # [v, 2, (h/16)*(w/16), (h/16)*(w/16)]
    return rays_pluckers, epipolar_constrains, cam_distances
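

# Minimal end-to-end sketch, assuming `kiui` is installed for orbit_camera.
# Four orbit views at 90-degree azimuth steps; all values are illustrative,
# not the original training configuration.
if __name__ == "__main__":
    elevations = [0.0, 0.0, 0.0, 0.0]
    azimuths = [0.0, 90.0, 180.0, 270.0]
    c2ws = get_c2ws(elevations, azimuths, camera_radius=1.5)
    rays, epipolar, dists = get_camera_poses(c2ws, fov=math.radians(49.1), h=256, w=256)
    print(rays.shape, epipolar.shape, dists.shape)  # (4, 6, 256, 256), (4, 2, 256, 256), (4, 2)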