File size: 5,393 Bytes
3324de2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
import numpy as np
from einops import rearrange

def sample_img_rays(x, img_fov=45):
    """
    Samples a unit ray for each pixel in image

    Rays start at the camera centre and pass through a regular pixel grid on the
    image plane z=1. Along each axis the grid spans [-tan(fov/2), tan(fov/2)], so
    the outermost rows/columns lie exactly on the field-of-view boundary.

    Args:
        x: tensor whose last two dims give (h,w); only its shape, dtype and device are used
        img_fov: assumed image fov for ray calculation, in degrees; scalar or tuple(h,w)
            (i.e. (vertical, horizontal))

    Returns:
        img_rays (h,w,3) 3:<x,y,z>, unit-norm, with the dtype/device of x
    """
    h, w = x.shape[-2:]
    dtype, device = x.dtype, x.device
    # Half-angle in radians, in float64 so the extent below is accurate before the
    # final cast to x's dtype. as_tensor avoids a copy/warning for tensor input.
    half_fov = torch.as_tensor(img_fov, dtype=torch.float64) * (torch.pi / 360)
    # Half-extent of the image plane at z=1 is |tan(fov/2)|. Computing tan directly
    # avoids the catastrophic cancellation of sqrt(sec^2 - 1) for small fovs.
    # A scalar fov is broadcast to both axes.
    y_lim, x_lim = half_fov.tan().abs().expand(2).tolist()  # [y, x]
    y_coords = torch.linspace(-y_lim, y_lim, h, dtype=dtype, device=device)
    x_coords = torch.linspace(-x_lim, x_lim, w, dtype=dtype, device=device)
    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing='ij')
    xyz = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1)  # (h,w,<x,y,z>)
    img_rays = xyz / xyz.norm(dim=-1, keepdim=True)
    return img_rays

def gen_rotation_matrix(angles):
    """
    Generate rotation matrix from angles

    Builds R = Rz(gamma) @ Ry(beta) @ Rx(alpha) for column vectors, where
    (alpha, beta, gamma) = (angles[...,0], angles[...,1], angles[...,2]). The
    rotation about x is therefore applied first, then y, then z.

    Args:
        angles: axis-wise rotations about (x,y,z) in degrees (...,3); any real value
            is accepted (e.g. [0,360] or negative angles)

    Returns:
        rot_mat (...,3,3)
    """
    dims = angles.shape[:-1]
    rad = 2 * torch.pi * angles / 360  # degrees -> radians
    cx, cy, cz = rad.cos().unbind(-1)
    sx, sy, sz = rad.sin().unbind(-1)
    rot_mat = torch.stack([
        cy * cz, sx * sy * cz - cx * sz, cx * sy * cz + sx * sz,
        cy * sz, sx * sy * sz + cx * cz, cx * sy * sz - sx * cz,
        -sy, sx * cy, cx * cy,
    ], dim=-1).reshape(*dims, 3, 3)  # (...,9) -> (...,3,3)
    return rot_mat

def cart_2_spherical(pts):
    """
    Convert Cartesian to spherical coordinates

    theta is the azimuth measured in the x-z plane from +z towards +x, in (-pi, pi].
    phi is the inclination of the point above the x-z plane towards +y, in [-pi/2, pi/2].
    atan2 is used so that rays pointing straight back (x=0, z<0 -> theta=pi) and
    along the y axis (x=z=0 -> theta=0) are handled without NaNs, and so rounding
    cannot push an arcsin argument outside [-1, 1].

    Args:
        pts: input pts (...,<x,y,z>)

    Returns:
        ret (...,<theta,phi,r>) (<azimuth,inclination,radius>) (radians)
    """
    x, y, z = pts.unbind(-1)
    r = pts.norm(dim=-1)
    theta = torch.atan2(x, z)
    phi = torch.atan2(y, torch.hypot(x, z))
    ret = torch.stack([theta, phi, r], dim=-1)
    return ret

def sample_pano_img(img, pts, h_fov_ratio=1, w_fov_ratio=1):
    """
    Sample points from panoramic image (bilinear interpolation)

    The pano is treated as an equirectangular grid. Inclination [-pi/2, pi/2] maps to
    fractional row index [0, h/h_fov_ratio], and azimuth [-pi, pi) maps to fractional
    column index [0, w/w_fov_ratio). Integer indices past the last row/column are
    clamped. When w_fov_ratio == 1 the pano covers the full 360 degrees horizontally,
    so the right neighbour of the last column wraps around to column 0 (no seam).

    Args:
        img: pano-image (...,c,h,w); any channel count c (e.g. 3:<rgb>)
        pts: spherical points to sample from img (...,sh,sw,3:<azimuth,inclination,radius>)
            in radians; leading batch dims must match those of img
        *_fov_ratio: ratio of full fov for pano

    Returns:
        sampled_img (...,c,sh,sw)
    """
    h, w = img.shape[-2:]
    c = img.shape[-3]
    sh, sw = pts.shape[-3:-1]
    h_conv, w_conv = h / h_fov_ratio, w / w_fov_ratio
    flat_img = img.flatten(-2).transpose(-1, -2)  # (...,c,h,w) -> (...,n,c), n=h*w
    flat_pts = pts.flatten(-3, -2)  # (...,sh,sw,3) -> (...,m,3), m=sh*sw
    # convert (pts) radians to fractional indices
    # inclination [-pi/2,+pi/2] -> [0,1]; clamped, not wrapped, so a ray pointing
    # straight down (+pi/2) samples the bottom row rather than the top row
    h_inds = ((flat_pts[..., 1] + torch.pi / 2) / torch.pi).clamp(0, 1) * h_conv
    # azimuth [-pi,+pi) -> [0,1), wrapping around horizontally
    w_inds = (((flat_pts[..., 0] + torch.pi) / (2 * torch.pi)) % 1) * w_conv
    # integer corner indices for bilinear interpolation
    h_l = h_inds.to(torch.long).clamp(0, h - 1)
    w_l = w_inds.to(torch.long).clamp(0, w - 1)
    h_r = (h_l + 1).clamp(0, h - 1)
    if w_fov_ratio == 1:
        w_r = (w_l + 1) % w  # full 360° pano: last column neighbours column 0
    else:
        w_r = (w_l + 1).clamp(0, w - 1)
    # fractional offsets -> interpolation weights
    h_p_r, w_p_r = h_inds - h_l, w_inds - w_l
    h_p_l, w_p_l = 1 - h_p_r, 1 - w_p_r
    # linearized corner indices and matching weights, corner order:
    # (h_l,w_l), (h_l,w_r), (h_r,w_l), (h_r,w_r)
    inds = torch.stack([w * h_l + w_l, w * h_l + w_r, w * h_r + w_l, w * h_r + w_r])  # (4,...,m)
    weights = torch.stack([h_p_l * w_p_l, h_p_l * w_p_r, h_p_r * w_p_l, h_p_r * w_p_r])  # (4,...,m)
    # gather the 4 corners for every channel and blend
    src = flat_img.unsqueeze(0).expand(4, *flat_img.shape)  # (4,...,n,c)
    corners = src.gather(-2, inds.unsqueeze(-1).expand(*inds.shape, c))  # (4,...,m,c)
    sampled_img = (weights.unsqueeze(-1) * corners).sum(0)  # (4,...,m,c) -> (...,m,c)
    # (...,m,c) -> (...,c,sh,sw)
    return sampled_img.transpose(-1, -2).reshape(*sampled_img.shape[:-2], c, sh, sw)

def sample_perspective_img(pano_img, output_shape, fov=None, rot=None):
    """
    Sample perspective image from panoramic

    One ray is cast per output pixel. The rays are rotated by `rot` and converted to
    spherical coordinates, and the pano is bilinearly sampled at those angles.

    Args:
        pano_img: pano-image numpy.array (h,w,3:<rgb>); cast to uint8 on input, so values
            are expected to already lie in [0,255]
        output_shape: output image dimensions tuple(h,w)
        fov: desired perspective image fov; int or tuple(vertical,horizontal) in degrees [0,180).
            If None, each component is drawn uniformly from [30,90).
        rot: axis-wise rotations; tuple(pitch,yaw,roll) in degrees.
            If None, drawn uniformly from pitch [-10,10), yaw [-135,90), roll [-20,20).

    Returns:
        sampled_img numpy.array (h,w,3:<rgb>) uint8, fov, rot (numpy.array)
    """
    if fov is None:
        fov = torch.tensor([30,30]) + torch.tensor([60,60])*torch.rand(2) # (v-fov,h-fov)
        fov = (fov[0].item(), fov[1].item())
    if rot is None:
        # rot w.r.t (x,y,z) aka (pitch,yaw,roll)
        # NOTE(review): the yaw range [-135,90) is asymmetric — confirm it is intended
        rot = (-torch.tensor([10,135,20]) + torch.tensor([20,225,40])*torch.rand(3))
    else:
        rot = torch.tensor(rot)
    pano = torch.tensor(pano_img, dtype=torch.uint8).moveaxis(-1,0)  # (h,w,c) -> (c,h,w)
    out_dtype = pano.dtype
    pano = pano.to(torch.float)

    img_rays = sample_img_rays(torch.empty(output_shape, dtype=pano.dtype, device=pano.device), img_fov=fov)
    rot_mat = gen_rotation_matrix(rot.to(pano.dtype))[None,None,:] # (3,3) -> (1,1,3,3)
    rot_img_rays = torch.matmul(rot_mat, img_rays.unsqueeze(-1)).squeeze(-1)
    spher_rot_img_rays = cart_2_spherical(rot_img_rays) # (h,w,3)
    # sample img
    sampled = sample_pano_img(pano, spher_rot_img_rays)
    # Round (rather than truncate) the interpolated values back to the integer output
    # dtype; plain .to() would floor every pixel and bias the image darker.
    limits = torch.iinfo(out_dtype)
    sampled = sampled.round().clamp(limits.min, limits.max).to(out_dtype)

    return sampled.moveaxis(0,-1).numpy(), fov, rot.numpy()