# SoundingStreet / utils.py
# Uploaded by FQiao ("Upload 70 files", commit 3324de2, verified).
import torch
import numpy as np
from einops import rearrange
def sample_img_rays(x, img_fov=45):
    """
    Build a unit-length viewing ray for every pixel of an image grid.

    The grid lies on the plane z = 1. Along each image axis it spans
    [-tan(fov/2), +tan(fov/2)], so the first and last row/column make an
    angle of half the field of view with the +z axis. Row index increases
    with y and column index increases with x.

    Args:
        x: tensor whose last two dimensions (h, w) give the grid size; only its
           shape, dtype and device are used.
        img_fov: field of view in degrees; a scalar used for both axes or a
           (vertical, horizontal) pair.
    Returns:
        img_rays: tensor (h, w, 3:<x, y, z>) of unit vectors.
    """
    height, width = x.shape[-2:]
    opts = dict(dtype=x.dtype, device=x.device)

    # Half field of view in radians, broadcast to [vertical, horizontal].
    half_fov = 2 * torch.pi * torch.tensor(img_fov) / 2 / 360
    secant = (1 / half_fov.cos()).expand(2)
    # sqrt(sec^2 - 1) == tan(half_fov): half-extent of the grid on the z=1 plane.
    extent = (secant ** 2 - 1) ** .5

    rows = torch.linspace(-extent[0], extent[0], height, **opts)
    cols = torch.linspace(-extent[1], extent[1], width, **opts)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')

    rays = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1)
    return rays / rays.norm(dim=-1, keepdim=True)
def gen_rotation_matrix(angles):
    """
    Generate rotation matrices from per-axis Euler angles.

    The matrix is R = Rz(angles[..., 2]) @ Ry(angles[..., 1]) @ Rx(angles[..., 0]),
    i.e. a vector is rotated about x first, then y, then z (fixed axes).

    Args:
        angles: rotations about the (x, y, z) axes in degrees, shape (..., 3).
            A tensor or anything ``torch.as_tensor`` accepts (list, tuple, ndarray).
    Returns:
        rot_mat (..., 3, 3)
    """
    angles = torch.as_tensor(angles)
    dims = angles.shape[:-1]
    radians = 2 * torch.pi * angles / 360  # degrees -> radians
    # Per-axis cosines and sines, each of shape (...).
    cx, cy, cz = radians.cos().unbind(-1)
    sx, sy, sz = radians.sin().unbind(-1)
    rot_mat = torch.stack([
        cy * cz, sx * sy * cz - cx * sz, cx * sy * cz + sx * sz,
        cy * sz, sx * sy * sz + cx * cz, cx * sy * sz - sx * cz,
        -sy, sx * cy, cx * cy,
    ], dim=-1).reshape(*dims, 3, 3)  # (..., 9) -> (..., 3, 3)
    return rot_mat
def cart_2_spherical(pts):
    """
    Convert Cartesian to spherical coordinates.

    theta (azimuth) is the angle in the x-z plane measured from +z toward +x,
    in [-pi, pi]. phi (inclination) is the angle of the point toward +y
    relative to the x-z plane, in [-pi/2, pi/2]. r is the Euclidean norm.

    Both angles use atan2. The previous sign(x) * arccos(z / |(x, z)|) form
    returned 0 instead of pi for points with x == 0, z < 0. It also returned
    NaN for points on the y axis (and arcsin(y / r) was NaN at the origin).

    Args:
        pts: input pts (..., <x, y, z>)
    Returns:
        ret (..., <theta, phi, r>) (<azimuth, inclination, radius>) (radians)
    """
    x, y, z = pts.unbind(-1)
    r = pts.norm(dim=-1)
    phi = torch.atan2(y, torch.hypot(x, z))
    theta = torch.atan2(x, z)
    return torch.stack([theta, phi, r], dim=-1)
def sample_pano_img(img, pts, h_fov_ratio=1, w_fov_ratio=1):
    """
    Bilinearly sample a panoramic image at spherical directions.

    Inclination in [-pi/2, pi/2] maps to rows and azimuth in [-pi, pi] maps to
    columns: index = (normalized angle mod 1) * size / fov_ratio. An integer
    index i lands exactly on pixel i.

    When w_fov_ratio == 1 the panorama is treated as a full 360 degree image.
    Samples between the last column and the seam then blend with column 0
    instead of being clamped to the last column.

    Args:
        img: pano-image (..., c, h, w); any number of channels c.
        pts: spherical points to sample from img
            (..., sh, sw, 3:<azimuth, inclination, radius>) in radians;
            the radius is ignored. The leading (...) dims must match img's.
        *_fov_ratio: ratio of full fov for pano
    Returns:
        sampled_img (..., c, sh, sw)
    """
    h, w = img.shape[-2:]
    sh, sw = pts.shape[-3:-1]
    h_conv, w_conv = h / h_fov_ratio, w / w_fov_ratio

    # Flatten spatial dims: img -> (..., n, c), pts -> (..., m, 3).
    img = img.flatten(-2).transpose(-1, -2)
    pts = pts.flatten(-3, -2)
    c = img.shape[-1]

    # Convert angles (radians) to fractional pixel indices.
    h_inds = (((pts[..., 1] + torch.pi / 2) / torch.pi) % 1) * h_conv  # inclination (-pi/2, +pi/2)
    w_inds = (((pts[..., 0] + torch.pi) / (2 * torch.pi)) % 1) * w_conv  # azimuth (-pi, +pi)

    # Integer corner indices for bilinear interpolation (int64 so w*h cannot overflow).
    h_l = h_inds.to(torch.long).clamp(0, h - 1)
    w_l = w_inds.to(torch.long).clamp(0, w - 1)
    h_r = (h_l + 1).clamp(0, h - 1)  # rows do not wrap (poles)
    if w_fov_ratio == 1:
        # Full 360 degrees: the column after w-1 is column 0.
        w_r = (w_l + 1) % w
    else:
        w_r = (w_l + 1).clamp(0, w - 1)

    # Interpolation weights.
    h_p_r, w_p_r = h_inds - h_l, w_inds - w_l
    h_p_l, w_p_l = 1 - h_p_r, 1 - w_p_r

    # Four corners in order (l,l), (l,r), (r,l), (r,r): flat indices and weights, each (4, ..., m).
    inds = (torch.stack([w * h_l, w * h_r], dim=-1)[..., :, None]
            + torch.stack([w_l, w_r], dim=-1)[..., None, :]).flatten(-2).moveaxis(-1, 0)
    weights = (torch.stack([h_p_l, h_p_r], dim=-1)[..., :, None]
               * torch.stack([w_p_l, w_p_r], dim=-1)[..., None, :]).flatten(-2).moveaxis(-1, 0)

    # Gather the four corner values per sample: (4, ..., m, c).
    img_extract = img.unsqueeze(0).expand(4, *img.shape).gather(
        -2, inds.unsqueeze(-1).expand(*inds.shape, c))
    sampled_img = (weights.unsqueeze(-1) * img_extract).sum(0)  # (4, ..., m, c) -> (..., m, c)

    # Back to (..., c, sh, sw).
    sampled_img = sampled_img.transpose(-1, -2)
    return sampled_img.reshape(*sampled_img.shape[:-1], sh, sw)
def sample_perspective_img(pano_img, output_shape, fov=None, rot=None):
    """
    Sample perspective image from panoramic.

    Args:
        pano_img: pano-image numpy.array (h, w, 3:<rgb>); converted to a uint8 tensor.
        output_shape: output image dimensions tuple(h, w)
        fov: desired perspective image fov; int or tuple(vertical, horizontal) in degrees [0, 180).
            If None, each component is drawn uniformly from [30, 90).
        rot: axis-wise rotations; tuple(pitch, yaw, roll) in degrees, i.e. about (x, y, z).
            If None, drawn uniformly: pitch in [-10, 10), yaw in [-135, 90), roll in [-20, 20).
    Returns:
        sampled_img numpy.array (h, w, 3:<rgb>) uint8, the fov used, the rotation used (numpy.array)
    """
    if fov is None:
        fov = torch.tensor([30, 30]) + torch.tensor([60, 60]) * torch.rand(2)  # (v-fov, h-fov)
        fov = (fov[0].item(), fov[1].item())
    if rot is None:
        # rot w.r.t (x, y, z) aka (pitch, yaw, roll)
        rot = -torch.tensor([10, 135, 20]) + torch.tensor([20, 225, 40]) * torch.rand(3)
    else:
        rot = torch.tensor(rot)

    pano = torch.tensor(pano_img, dtype=torch.uint8).moveaxis(-1, 0)  # (3, h, w)
    out_dtype = pano.dtype
    pano = pano.to(torch.float)

    # Unit rays of the virtual pinhole camera, rotated into the panorama frame.
    img_rays = sample_img_rays(torch.empty(output_shape, dtype=pano.dtype, device=pano.device), img_fov=fov)
    rot_mat = gen_rotation_matrix(rot.to(pano.dtype))[None, None, :]  # (3, 3) -> (1, 1, 3, 3)
    rot_img_rays = torch.matmul(rot_mat, img_rays.unsqueeze(-1)).squeeze(-1)
    spher_rot_img_rays = cart_2_spherical(rot_img_rays)  # (h, w, 3)

    sampled = sample_pano_img(pano, spher_rot_img_rays)
    # The bilinear weights are floats that may sum to slightly less than 1, so a
    # flat 128 region can come out as 127.99998. Round to nearest instead of
    # truncating, and clamp into the output dtype's range before casting.
    info = torch.iinfo(out_dtype)
    sampled = sampled.round().clamp(info.min, info.max)
    return sampled.moveaxis(0, -1).to(out_dtype).numpy(), fov, rot.numpy()