import numpy as np
import roma
import torch
import torch.nn.functional as F


def rt_to_mat4(
    R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
) -> torch.Tensor:
    """
    Args:
        R (torch.Tensor): (..., 3, 3).
        t (torch.Tensor): (..., 3).
        s (torch.Tensor): (...,).

    Returns:
        torch.Tensor: (..., 4, 4)
    """
    mat34 = torch.cat([R, t[..., None]], dim=-1)
    if s is None:
        bottom = (
            mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
            .reshape((1,) * (mat34.dim() - 2) + (1, 4))
            .expand(mat34.shape[:-2] + (1, 4))
        )
    else:
        bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
    mat4 = torch.cat([mat34, bottom], dim=-2)
    return mat4


def get_avg_w2c(w2cs: torch.Tensor):
    c2ws = torch.linalg.inv(w2cs)
    # 1. Compute the center
    center = c2ws[:, :3, -1].mean(0)
    # 2. Compute the z axis
    z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1)
    # 3. Compute axis y' (no need to normalize as it's not the final output)
    y_ = c2ws[:, :3, 1].mean(0)  # (3)
    # 4. Compute the x axis
    x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1)  # (3)
    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
    y = torch.cross(z, x, dim=-1)  # (3)
    avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center)
    avg_w2c = torch.linalg.inv(avg_c2w)
    return avg_w2c


# def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
#     """Calculate the intersection point of multiple camera rays as the lookat point.
#
#     Use the center of camera positions as a reference point for the lookat,
#     then move forward along the average view direction by a certain distance.
#     """
#     # Calculate the center of camera positions
#     center = origins.mean(dim=0)
#     # Calculate average view direction
#     mean_dir = F.normalize(viewdirs.mean(dim=0), dim=-1)
#     # Calculate average distance to the center point
#     avg_dist = torch.norm(origins - center, dim=-1).mean()
#     # Move forward along the average view direction
#     lookat = center + mean_dir * avg_dist
#     return lookat


def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
    """Triangulate a set of rays to find a single lookat point.

    Args:
        origins (torch.Tensor): A (N, 3) array of ray origins.
        viewdirs (torch.Tensor): A (N, 3) array of ray view directions.

    Returns:
        torch.Tensor: A (3,) lookat point.
    """
    viewdirs = F.normalize(viewdirs, dim=-1)
    eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
    # Calculate projection matrix I - rr^T
    I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
    # Compute sum of projections
    sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
    # Solve for the intersection point using least squares
    lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
    # Check NaNs.
    assert not torch.any(torch.isnan(lookat))
    return lookat
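

# Illustrative sketch (not part of the original module): a quick sanity check
# that `get_lookat` recovers the convergence point of a ray bundle. The camera
# layout below is invented for the example; any rays that (roughly) intersect
# at one point would do.
def _demo_get_lookat() -> torch.Tensor:
    target = torch.tensor([0.5, -0.2, 1.0])
    # Eight origins on a circle of radius 3 in the z=0 plane, all aimed at
    # `target`.
    angles = torch.linspace(0.0, 2.0 * torch.pi, 9)[:-1]
    origins = torch.stack(
        [3.0 * torch.cos(angles), 3.0 * torch.sin(angles), torch.zeros_like(angles)],
        dim=-1,
    )
    viewdirs = F.normalize(target - origins, dim=-1)
    lookat = get_lookat(origins, viewdirs)
    # All rays pass exactly through `target`, so the least-squares solution
    # should match it up to float32 precision.
    assert torch.allclose(lookat, target, atol=1e-4)
    return lookat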


def get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor):
    """
    Args:
        positions: (N, 3) tensor of camera positions
        lookat: (3,) tensor of lookat point
        up: (3,) tensor of up vector

    Returns:
        w2cs: (N, 4, 4) tensor of world-to-camera matrices
    """
    forward_vectors = F.normalize(lookat - positions, dim=-1)
    right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1)
    down_vectors = F.normalize(
        torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
    )
    Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
    w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
    return w2cs


def get_arc_w2cs(
    ref_w2c: torch.Tensor,
    lookat: torch.Tensor,
    up: torch.Tensor,
    num_frames: int,
    degree: float,
    **_,
) -> torch.Tensor:
    ref_position = torch.linalg.inv(ref_w2c)[:3, 3]
    # Sinusoidal sweep over ±(degree / 2), converted from degrees to radians.
    thetas = (
        torch.sin(
            torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[
                :-1
            ]
        )
        * (degree / 2.0)
        / 180.0
        * torch.pi
    )
    # Rotate the reference offset about `up`, then re-center on the lookat
    # point (without adding `lookat` back, the cameras would orbit the world
    # origin instead of the lookat point).
    positions = (
        torch.einsum(
            "nij,j->ni",
            roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
            ref_position - lookat,
        )
        + lookat
    )
    return get_lookat_w2cs(positions, lookat, up)


def get_lemniscate_w2cs(
    ref_w2c: torch.Tensor,
    lookat: torch.Tensor,
    up: torch.Tensor,
    num_frames: int,
    degree: float,
    **_,
) -> torch.Tensor:
    ref_c2w = torch.linalg.inv(ref_w2c)
    # Half-width of the figure-eight, sized so it subtends `degree` degrees
    # as seen from the lookat point.
    a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
    # Lemniscate curve in camera space. Starting at the origin.
    thetas = (
        torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
        + torch.pi / 2
    )
    positions = torch.stack(
        [
            a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
            a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
            torch.zeros(num_frames, device=ref_w2c.device),
        ],
        dim=-1,
    )
    # Transform to world space.
    positions = torch.einsum(
        "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
    )
    return get_lookat_w2cs(positions, lookat, up)
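

# Illustrative sketch (not in the original source): wiring the helpers above
# together to synthesize an arc trajectory from a set of calibrated cameras.
# `train_w2cs` is a placeholder name, and the frame count / sweep angle are
# arbitrary example values.
def _demo_arc_trajectory(train_w2cs: torch.Tensor) -> torch.Tensor:
    ref_w2c = get_avg_w2c(train_w2cs)
    train_c2ws = torch.linalg.inv(train_w2cs)
    # Triangulate the point the training cameras collectively look at, using
    # camera centers as ray origins and camera z-axes as view directions.
    lookat = get_lookat(train_c2ws[:, :3, 3], train_c2ws[:, :3, 2])
    # World up, taken here as the negated mean camera y-axis (camera y points
    # down in this module's convention); adjust for your world convention.
    up = -F.normalize(train_c2ws[:, :3, 1].mean(0), dim=-1)
    return get_arc_w2cs(ref_w2c, lookat, up, num_frames=60, degree=30.0)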


def get_spiral_w2cs(
    ref_w2c: torch.Tensor,
    lookat: torch.Tensor,
    up: torch.Tensor,
    num_frames: int,
    rads: float | torch.Tensor,
    zrate: float,
    rots: int,
    **_,
) -> torch.Tensor:
    ref_c2w = torch.linalg.inv(ref_w2c)
    thetas = torch.linspace(
        0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device
    )[:-1]
    # Spiral curve in camera space (offset by `rads` along camera x at the
    # first frame).
    if isinstance(rads, torch.Tensor):
        rads = rads.reshape(-1, 3).to(ref_w2c.device)
    positions = (
        torch.stack(
            [
                torch.cos(thetas),
                -torch.sin(thetas),
                -torch.sin(thetas * zrate),
            ],
            dim=-1,
        )
        * rads
    )
    # Transform to world space.
    positions = torch.einsum(
        "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
    )
    return get_lookat_w2cs(positions, lookat, up)


def get_wander_w2cs(ref_w2c, focal_length, num_frames, max_disp, **_):
    """Sway the camera on a small ellipse around the reference pose.

    Dividing the pixel-space disparity budget `max_disp` by the focal length
    converts it into a camera-space translation amplitude.
    """
    device = ref_w2c.device
    c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy())
    max_trans = max_disp / focal_length
    ref_pose = np.concatenate(
        [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0
    )
    output_poses = []
    for i in range(num_frames):
        x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
        y_trans = 0.0
        z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0
        # Rigid offset of the camera in its own frame.
        i_pose = np.concatenate(
            [
                np.concatenate(
                    [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]],
                    axis=1,
                ),
                np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],
            ],
            axis=0,
        )
        i_pose = np.linalg.inv(i_pose)
        render_pose = np.dot(ref_pose, i_pose)
        output_poses.append(render_pose)
    output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device)
    w2cs = torch.linalg.inv(output_poses)
    return w2cs
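

# Illustrative sketch (not in the original source): driving the two remaining
# trajectory generators from one reference camera. All numeric arguments are
# arbitrary example values, not defaults taken from this module.
def _demo_spiral_and_wander(
    ref_w2c: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor
) -> dict[str, torch.Tensor]:
    # Two full revolutions, drifting along camera z at half the angular rate.
    spiral = get_spiral_w2cs(
        ref_w2c, lookat, up, num_frames=120, rads=0.5, zrate=0.5, rots=2
    )
    # A 60-frame sway with a 24-pixel disparity budget at f = 500 px.
    wander = get_wander_w2cs(ref_w2c, focal_length=500.0, num_frames=60, max_disp=24.0)
    return {"spiral": spiral, "wander": wander}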