Spaces:
Sleeping
Sleeping
import torch | |
# Scale factor from meters to millimeters (multiply a value in m by this to get mm).
m2mm = 1e3
def L2_error(x: torch.Tensor, y: torch.Tensor):
    '''
    Euclidean (L2) distance between `x` and `y`, measured along their last dimension.

    ### Args
    - `x`: torch.Tensor, shape (..., D)
    - `y`: torch.Tensor, shape (..., D)

    ### Returns
    - torch.Tensor, shape (...)
        - The L2 norm of `x - y` over the last axis.
    '''
    residual = x - y
    return torch.linalg.vector_norm(residual, ord=2, dim=-1)
def similarity_align_to(
    S1 : torch.Tensor,
    S2 : torch.Tensor,
):
    '''
    Align the point set S1 to S2 with the similarity transform (s, R, t) that minimizes
    || s * R @ S1 + t - S2 ||^2, i.e. solve the orthogonal Procrustes problem with scale.
    R is a 3x3 rotation matrix (det(R) = +1), t a 3D translation and s a scalar scale;
    a separate transform is estimated for every batch item.
    The code was modified from [WHAM](https://github.com/yohanshin/WHAM/blob/d1ade93ae83a91855902fdb8246c129c4b3b8a40/lib/eval/eval_utils.py#L201-L252).

    ### Args
    - `S1`: torch.Tensor, shape (...B, N, 3)
        - The points to be aligned. They are moved to the device and dtype of `S2`.
    - `S2`: torch.Tensor, shape (...B, N, 3)
        - The target points.

    ### Returns
    - torch.Tensor, shape (...B, N, 3)
        - `S1` after applying the estimated similarity transform.
    '''
    assert (S1.shape[-1] == 3 and S2.shape[-1] == 3), 'The last dimension of `S1` and `S2` must be 3.'
    assert (S1.shape[:-2] == S2.shape[:-2]), 'The batch size of `S1` and `S2` must be the same.'
    # Without this check a mismatched N would be silently folded into the batch dim by reshape.
    assert (S1.shape[-2] == S2.shape[-2]), 'The number of points in `S1` and `S2` must be the same.'

    original_shape = S1.shape  # (...B, N, 3)
    N = original_shape[-2]
    device, dtype = S2.device, S2.dtype

    # Flatten the leading batch dims and switch to column-vector layout: (B', 3, N).
    # S1 also takes S2's dtype: matmul does not promote dtypes, so mixed inputs would fail.
    S1 = S1.to(device=device, dtype=dtype).reshape(-1, N, 3).transpose(-1, -2)
    S2 = S2.reshape(-1, N, 3).transpose(-1, -2)
    B = S1.shape[0]

    # 1. Remove the centroids.
    mu1 = S1.mean(dim=-1, keepdim=True)  # (B', 3, 1)
    mu2 = S2.mean(dim=-1, keepdim=True)  # (B', 3, 1)
    X1 = S1 - mu1  # (B', 3, N)
    X2 = S2 - mu2  # (B', 3, N)

    # 2. Total variance of the centered S1, used to recover the scale.
    var1 = (X1 ** 2).sum(dim=(-2, -1))  # (B',)

    # 3. Cross-covariance of the centered point sets.
    K = X1 @ X2.transpose(-1, -2)  # (B', 3, 3)

    # 4. With K = U diag(s) V^T, R = V Z U^T maximizes trace(R K).
    U, _, Vh = torch.linalg.svd(K)  # (B', 3, 3), (B', 3), (B', 3, 3)
    V = Vh.transpose(-1, -2)
    # Z = diag(1, 1, sign(det(U V^T))) removes a possible reflection so that det(R) = +1.
    # It is built with K's dtype: a default float32 identity breaks matmul for float64 inputs.
    Z = torch.eye(3, device=device, dtype=K.dtype).repeat(B, 1, 1)  # (B', 3, 3)
    Z[:, -1, -1] *= (U @ Vh).det().sign()
    R = V @ (Z @ U.transpose(-1, -2))  # (B', 3, 3)

    # 5. Scale = trace(R K) / var1, with the batched trace taken from the diagonal.
    scales = torch.diagonal(R @ K, dim1=-2, dim2=-1).sum(dim=-1) / var1  # (B',)
    scales = scales[..., None, None]  # (B', 1, 1)

    # 6. Translation mapping the scaled, rotated centroid of S1 onto the centroid of S2.
    t = mu2 - scales * (R @ mu1)  # (B', 3, 1)

    # 7. Apply the transform and restore the input layout.
    S1_aligned = scales * (R @ S1) + t  # (B', 3, N)
    S1_aligned = S1_aligned.transpose(-1, -2)  # (B', N, 3)
    return S1_aligned.reshape(original_shape)  # (...B, N, 3)
def align_pcl(Y: torch.Tensor, X: torch.Tensor, weight=None, fixed_scale=False):
    '''
    Estimate the similarity transform aligning X with Y using the Umeyama method:
    X' = s * R * X + t is aligned with Y. A separate transform is estimated per batch item.
    The code was copied from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/geometry/pcl.py#L10-L60),
    with all intermediate tensors created on the inputs' device/dtype.

    ### Args
    - `Y`: torch.Tensor, shape (*, N, 3) first trajectory (target)
    - `X`: torch.Tensor, shape (*, N, 3) second trajectory (source)
    - `weight`: torch.Tensor, shape (*, N, 1) optional weight of valid correspondences
        - Points with zero weight do not contribute to the estimate.
    - `fixed_scale`: bool, default = False
        - If True, the scale is fixed to 1 and only R and t are estimated.

    ### Returns
    - `s` (*, 1)
    - `R` (*, 3, 3), a proper rotation (det(R) = +1)
    - `t` (*, 3)
    '''
    *dims, n_points, _ = Y.shape
    if weight is None:
        # Correspondence count per batch item, (*, 1, 1). It is created on Y's device/dtype
        # so the divisions below never mix CPU and GPU tensors.
        N = torch.full((*dims, 1, 1), float(n_points), device=Y.device, dtype=Y.dtype)
    else:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # Subtract the (weighted) means.
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]  # (*, 3)
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]  # (*, N, 3)
    if weight is not None:
        # Re-mask after centering so that zero-weight points stay at zero.
        y0 = y0 * weight
        x0 = x0 * weight

    # Cross-covariance and its SVD.
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # S = diag(1, 1, +-1) flips the last singular direction when U @ Vh would be a reflection.
    # It must share C's device/dtype: matmul neither promotes dtypes nor moves devices.
    neg = torch.det(U) * torch.det(Vh) < 0  # (*,)
    S = torch.eye(3, device=C.device, dtype=C.dtype).repeat(*dims, 1, 1)  # (*, 3, 3)
    S[..., 2, 2] = 1.0 - 2.0 * neg.to(C.dtype)
    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    if fixed_scale:
        s = torch.ones(*dims, 1, device=C.device, dtype=C.dtype)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        # trace(diag(D) @ S) == sum_i D_i * S_ii
        trace_DS = (D * torch.diagonal(S, dim1=-2, dim2=-1)).sum(dim=-1, keepdim=True)  # (*, 1)
        s = trace_DS / var[..., 0]  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)
    return s, R, t
def first_k_frames_align_to(
    S1 : torch.Tensor,
    S2 : torch.Tensor,
    k_f : int,
):
    '''
    Compute the similarity transformation between the first `k_f` frames of S1 and S2,
    and apply that single transformation to every frame of S1 to align it to S2.
    The code was modified from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/eval/tools.py#L68-L81).

    ### Args
    - `S1`: torch.Tensor, shape (..., L, N, 3)
        - The sequence to be aligned.
    - `S2`: torch.Tensor, shape (..., L, N, 3)
        - The target sequence.
    - `k_f`: int
        - The number of leading frames to use for alignment (must be >= 1; values
          larger than L use all frames).

    ### Returns
    - `S1_aligned`: torch.Tensor, shape (..., L, N, 3)
        - The aligned S1.
    '''
    assert (len(S1.shape) >= 3 and len(S2.shape) >= 3), 'The input tensors must have at least 3 dimensions.'
    assert (S1.shape[-1] == 3 and S1.shape[-3:] == S2.shape[-3:]), 'The last three dimensions (L, N, 3) of `S1` and `S2` must match.'
    # k_f <= 0 would produce an empty point cloud and a division by zero inside the fit.
    assert (k_f >= 1), '`k_f` must be at least 1.'
    original_shape = S1.shape  # (..., L, N, 3)
    L, N, _ = original_shape[-3:]
    S1 = S1.reshape(-1, L, N, 3)  # (B, L, N, 3)
    S2 = S2.reshape(-1, L, N, 3)  # (B, L, N, 3)
    B = S1.shape[0]
    assert (S2.shape[0] == B), 'The batch size of `S1` and `S2` must be the same.'

    # 1. Merge the first k_f frames of every sequence into a single point cloud.
    S1_first = S1[:, :k_f, :, :].reshape(B, -1, 3)  # (B, min(k_f, L) * N, 3)
    S2_first = S2[:, :k_f, :, :].reshape(B, -1, 3)  # (B, min(k_f, L) * N, 3)

    # 2. Estimate the transformation mapping S1_first onto S2_first.
    s_first, R_first, t_first = align_pcl(
        X = S1_first,
        Y = S2_first,
    )  # (B, 1), (B, 3, 3), (B, 3)
    s_first = s_first.reshape(B, 1, 1, 1)  # (B, 1, 1, 1)
    t_first = t_first.reshape(B, 1, 1, 3)  # (B, 1, 1, 3)

    # 3. Apply the same transformation to the whole sequence.
    S1_aligned = s_first * torch.einsum('Bij,BLNj->BLNi', R_first, S1) + t_first  # (B, L, N, 3)
    S1_aligned = S1_aligned.reshape(original_shape)  # (..., L, N, 3)
    return S1_aligned  # (..., L, N, 3)