Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,393 Bytes
3324de2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import numpy as np
from einops import rearrange
def sample_img_rays(x, img_fov=45):
    """
    Sample a unit-length viewing ray for every pixel of an image.

    The rays lie on the z=1 image plane before normalisation: the x component
    grows with the column index, the y component with the row index.

    Args:
        x: images (...,h,w); only the last two dims, dtype and device are read
        img_fov: assumed image fov in degrees, expected in (0,180);
                 scalar or tuple(h,w) i.e. (vertical, horizontal)
    Returns:
        img_rays (h,w,3) 3:<x,y,z>
    """
    h, w, dtype, device = *x.shape[-2:], x.dtype, x.device

    # Half field of view in radians: fov/2 * pi/180.
    half_fov_rad = torch.as_tensor(img_fov) * (torch.pi / 360)

    # Half-extent of the z=1 image plane along each axis. tan() is used
    # directly because the equivalent sqrt(sec^2 - 1) cancels catastrophically
    # for narrow fields of view.
    half_extent = half_fov_rad.tan().expand(2)  # [y,x]
    y_max, x_max = half_extent.tolist()

    y_coords = torch.linspace(-y_max, y_max, h, dtype=dtype, device=device)
    x_coords = torch.linspace(-x_max, x_max, w, dtype=dtype, device=device)
    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing='ij')

    xyz = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1)  # (h,w,<x,y,z>)
    return xyz / xyz.norm(dim=-1, keepdim=True)
def gen_rotation_matrix(angles):
    """
    Generate rotation matrices from per-axis rotation angles.

    The result is the composition Rz(angles[2]) @ Ry(angles[1]) @ Rx(angles[0]),
    i.e. the rotation about x is applied first, then y, then z.

    Args:
        angles: axis-wise rotations about (x,y,z) in degrees (...,3);
                tensor or any array-like accepted by torch.as_tensor
    Returns:
        rot_mat (...,3,3)
    """
    angles = torch.as_tensor(angles)
    batch_dims = angles.shape[:-1]

    # degrees -> radians, then split into the three per-axis angle tensors (...)
    ang_x, ang_y, ang_z = (angles * (torch.pi / 180)).unbind(-1)
    cx, sx = ang_x.cos(), ang_x.sin()
    cy, sy = ang_y.cos(), ang_y.sin()
    cz, sz = ang_z.cos(), ang_z.sin()

    # Row-major entries of Rz @ Ry @ Rx.
    entries = torch.stack([
        cy*cz, sx*sy*cz - cx*sz, cx*sy*cz + sx*sz,
        cy*sz, sx*sy*sz + cx*cz, cx*sy*sz - sx*cz,
        -sy,   sx*cy,            cx*cy,
    ], dim=-1)  # (...,9)
    return entries.reshape(*batch_dims, 3, 3)  # (...,9) -> (...,3,3)
def cart_2_spherical(pts):
    """
    Convert Cartesian to spherical coordinates.

    Azimuth is measured in the x-z plane from the +z axis towards +x;
    inclination is the elevation of the point out of the x-z plane towards +y.

    Args:
        pts: input pts (...,<x,y,z>)
    Returns:
        ret (...,<theta,phi,r>) (<azimuth,inclination,radius>) (radians),
        theta in [-pi,pi], phi in [-pi/2,pi/2]
    """
    x, y, z = pts.unbind(-1)
    r = pts.norm(dim=-1)

    # atan2 keeps the correct quadrant when x == 0 (a point straight behind
    # the origin has azimuth pi, not 0) and stays finite when x == z == 0,
    # where a sign(x)*arccos(z/sqrt(x^2+z^2)) formulation yields 0 or NaN.
    theta = torch.atan2(x, z)
    # Same angle as arcsin(y/r) for r > 0, but finite (0) for the zero vector.
    phi = torch.atan2(y, torch.hypot(x, z))

    return torch.stack([theta, phi, r], dim=-1)
def sample_pano_img(img, pts, h_fov_ratio=1, w_fov_ratio=1):
    """
    Sample points from a panoramic image with bilinear interpolation.

    Azimuth -pi maps to column 0 and inclination -pi/2 maps to row 0.

    Args:
        img: pano-image (...,c,h,w), typically c=3:<rgb>
        pts: spherical points to sample from img (...,sh,sw,3:<azimuth,inclination,radius>);
             leading dims are expected to match those of img (torch.gather requirement)
        *_fov_ratio: ratio of full fov for pano
    Returns:
        sampled_img (...,c,sh,sw)
    """
    h, w = img.shape[-2:]
    sh, sw = pts.shape[-3:-1]
    h_conv, w_conv = h / h_fov_ratio, w / w_fov_ratio

    flat_img = img.flatten(-2).transpose(-1, -2)  # (...,c,h,w) -> (...,n,c)
    flat_pts = pts.flatten(-3, -2)                # (...,sh,sw,3) -> (...,m,3)
    n_chan = flat_img.shape[-1]

    # convert (pts) radians to fractional pixel indices
    # inclination (-pi/2,+pi/2): clamped rather than wrapped, so the exact
    # bottom pole stays on the last row instead of jumping to row 0
    h_inds = ((flat_pts[..., 1] + torch.pi / 2) / torch.pi).clamp(0, 1) * h_conv
    # azimuth (-pi,+pi): periodic, so wrap
    w_inds = (((flat_pts[..., 0] + torch.pi) / (2 * torch.pi)) % 1) * w_conv

    # neighbouring pixel indices for bilinear interpolation
    h_l = h_inds.to(torch.long).clamp(0, h - 1)
    w_l = w_inds.to(torch.long).clamp(0, w - 1)
    h_r = (h_l + 1).clamp(0, h - 1)
    if w_fov_ratio == 1:
        # a full 360-degree panorama is continuous across the left/right seam
        w_r = (w_l + 1) % w
    else:
        w_r = (w_l + 1).clamp(0, w - 1)

    # interpolation weights
    h_p_r, w_p_r = h_inds - h_l, w_inds - w_l
    h_p_l, w_p_l = 1 - h_p_r, 1 - w_p_r

    # linearised indices / weights of the four neighbours: (4,...,m)
    inds = torch.stack([
        w * h_l + w_l, w * h_l + w_r,
        w * h_r + w_l, w * h_r + w_r,
    ])
    weights = torch.stack([
        h_p_l * w_p_l, h_p_l * w_p_r,
        h_p_r * w_p_l, h_p_r * w_p_r,
    ])

    # gather the four neighbours of every point and blend them
    img_extract = flat_img.expand(4, *flat_img.shape).gather(
        -2, inds[..., None].expand(*inds.shape, n_chan))  # (4,...,m,c)
    sampled = (weights[..., None] * img_extract).sum(0)   # (4,...,m,c) -> (...,m,c)

    # (...,m,c) -> (...,c,sh,sw)
    return sampled.transpose(-1, -2).reshape(*sampled.shape[:-2], n_chan, sh, sw)
def sample_perspective_img(pano_img, output_shape, fov=None, rot=None):
    """
    Sample a perspective image from a panoramic image.

    Args:
        pano_img: pano-image numpy.array (h,w,3:<rgb>); converted to uint8
        output_shape: output image dimensions tuple(h,w)
        fov: desired perspective image fov; int or tuple(vertical,horizontal) in degrees [0,180).
             When None, each axis is drawn uniformly from [30,90).
        rot: axis-wise rotations; tuple(pitch,yaw,roll) in degrees [0,360].
             When None, drawn uniformly from pitch [-10,10), yaw [-135,90), roll [-20,20).
    Returns:
        sampled_img numpy.array (h,w,3:<rgb>) uint8, fov, rot (numpy.array)
    """
    if fov is None:
        fov = torch.tensor([30, 30]) + torch.tensor([60, 60]) * torch.rand(2)  # (v-fov,h-fov)
        fov = (fov[0].item(), fov[1].item())
    if rot is None:
        # rot w.r.t (x,y,z) aka (pitch,yaw,roll)
        rot = -torch.tensor([10, 135, 20]) + torch.tensor([20, 225, 40]) * torch.rand(3)
    else:
        rot = torch.tensor(rot)

    pano = torch.tensor(pano_img, dtype=torch.uint8).moveaxis(-1, 0)  # (h,w,3) -> (3,h,w)
    out_dtype = pano.dtype
    pano = pano.to(torch.float)  # interpolate in floating point

    # one unit ray per output pixel, rotated into the panorama frame
    ray_template = torch.empty(output_shape, dtype=pano.dtype, device=pano.device)
    img_rays = sample_img_rays(ray_template, img_fov=fov)  # (h,w,3)
    rot_mat = gen_rotation_matrix(rot.to(pano.dtype))[None, None, :]  # (3,3) -> (1,1,3,3)
    rot_img_rays = torch.matmul(rot_mat, img_rays.unsqueeze(-1)).squeeze(-1)
    spher_rot_img_rays = cart_2_spherical(rot_img_rays)  # (h,w,3)

    # sample img
    sampled = sample_pano_img(pano, spher_rot_img_rays)  # (3,h,w)

    # Round to nearest before the integer cast: a plain float -> uint8 cast
    # truncates, so e.g. 254.99998 from float round-off would become 254.
    sampled = sampled.round().clamp(0, 255)
    return sampled.moveaxis(0, -1).to(out_dtype).numpy(), fov, rot.numpy()