IsshikiHugh's picture
feat: CPU demo
5ac1897
import torch
# Scale factor of 1000; judging by the name, presumably converts meters to millimeters
# (NOTE(review): no use of it is visible in this part of the file — confirm with callers).
m2mm = 1000.0
def L2_error(x:torch.Tensor, y:torch.Tensor):
    '''
    Euclidean (L2) distance between `x` and `y`, reduced over the last dimension.

    ### Args
    - `x`: torch.Tensor, shape (..., D)
    - `y`: torch.Tensor, shape (..., D)

    ### Returns
    - torch.Tensor, shape (...)
        - The norm of `x - y` taken along the last dimension.
    '''
    difference = x - y
    return torch.norm(difference, dim=-1)
def similarity_align_to(
    S1 : torch.Tensor,
    S2 : torch.Tensor,
):
    '''
    Computes the similarity transform (sR, t) that brings a set of 3D points S1 closest
    to a set of 3D points S2 (in the least-squares sense), and returns the transformed S1.
    R is a 3x3 rotation matrix, t a 3x1 translation and s a scale, i.e. this solves the
    orthogonal Procrustes problem.
    The code was modified from [WHAM](https://github.com/yohanshin/WHAM/blob/d1ade93ae83a91855902fdb8246c129c4b3b8a40/lib/eval/eval_utils.py#L201-L252).

    ### Args
    - `S1`: torch.Tensor, shape (...B, N, 3)
        - The points to be aligned. Moved to the device of `S2` before computing.
    - `S2`: torch.Tensor, shape (...B, N, 3)
        - The target points.

    ### Returns
    - torch.Tensor, shape (...B, N, 3)
        - `S1` after applying the estimated scale, rotation and translation.

    ### Raises
    - AssertionError if the last dimension is not 3, or if the batch dimensions or the
      number of points of `S1` and `S2` differ.

    ### Notes
    - The scale is divided by the variance of the centered `S1`, so a degenerate `S1`
      (all points identical) produces non-finite values.
    '''
    assert (S1.shape[-1] == 3 and S2.shape[-1] == 3), 'The last dimension of `S1` and `S2` must be 3.'
    assert (S1.shape[:-2] == S2.shape[:-2]), 'The batch size of `S1` and `S2` must be the same.'
    # Without this check a different N in S2 could be silently folded into the batch dim by the reshape below.
    assert (S1.shape[-2] == S2.shape[-2]), 'The number of points of `S1` and `S2` must be the same.'
    original_BN3 = S1.shape
    N = original_BN3[-2]
    S1 = S1.reshape(-1, N, 3) # (B', N, 3) <- (...B, N, 3)
    S2 = S2.reshape(-1, N, 3) # (B', N, 3) <- (...B, N, 3)
    B = S1.shape[0]

    S1 = S1.transpose(-1, -2) # (B', 3, N) <- (B', N, 3)
    S2 = S2.transpose(-1, -2) # (B', 3, N) <- (B', N, 3)

    _device = S2.device
    S1 = S1.to(_device)

    # 1. Remove mean.
    mu1 = S1.mean(dim=-1, keepdim=True) # (B', 3, 1)
    mu2 = S2.mean(dim=-1, keepdim=True) # (B', 3, 1)
    X1 = S1 - mu1 # (B', 3, N)
    X2 = S2 - mu2 # (B', 3, N)

    # 2. Compute variance of X1 used for scales.
    var1 = (X1 ** 2).sum(dim=(-2, -1)) # (B',)

    # 3. The outer product of X1 and X2.
    K = X1 @ X2.transpose(-1, -2) # (B', 3, 3)

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
    #    `torch.linalg.svd` returns V already transposed (Vh), unlike the deprecated `torch.svd`.
    U, _, Vh = torch.linalg.svd(K) # (B', 3, 3), (B', 3), (B', 3, 3)
    V = Vh.transpose(-1, -2) # (B', 3, 3)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    # Z follows the dtype of K so that non-float32 inputs (e.g. float64) work as well.
    Z = torch.eye(3, device=_device, dtype=K.dtype)[None].repeat(B, 1, 1) # (B', 3, 3)
    Z[:, -1, -1] *= (U @ Vh).det().sign()

    # Construct R.
    R = V @ (Z @ U.transpose(-1, -2)) # (B', 3, 3)

    # 5. Recover scales: trace(R @ K) / var1, computed for the whole batch at once.
    traces = (R @ K).diagonal(dim1=-2, dim2=-1).sum(dim=-1) # (B',)
    scales = (traces / var1)[..., None, None] # (B', 1, 1)

    # 6. Recover translation.
    t = mu2 - (scales * (R @ mu1)) # (B', 3, 1)

    # 7. Apply the transform to S1 and restore the input layout.
    S1_aligned = scales * (R @ S1) + t # (B', 3, N)
    S1_aligned = S1_aligned.transpose(-1, -2) # (B', N, 3) <- (B', 3, N)
    S1_aligned = S1_aligned.reshape(original_BN3) # (...B, N, 3)
    return S1_aligned # (...B, N, 3)
def align_pcl(Y: torch.Tensor, X: torch.Tensor, weight=None, fixed_scale=False):
    '''
    Estimate the similarity transform that aligns X with Y using the Umeyama method,
    i.e. X' = s * R * X + t is aligned with Y.
    The code was copied from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/geometry/pcl.py#L10-L60).

    ### Args
    - `Y`: torch.Tensor, shape (*, N, 3) first trajectory
    - `X`: torch.Tensor, shape (*, N, 3) second trajectory
    - `weight`: torch.Tensor, shape (*, N, 1) optional weight of valid correspondences
        - The weight is multiplied in both before and after centering, so points with
          weight 0 drop out of the means, the correlation and the variance.
    - `fixed_scale`: bool, default = False
        - If True, the scale is not estimated and `s` is all ones.

    ### Returns
    - `s` (*, 1)
    - `R` (*, 3, 3)
    - `t` (*, 3)

    ### Notes
    - Helper tensors are created on the device of `Y`, in the dtype of the computation
      (float32 for float32 inputs, float64 for float64 inputs).
    - When the scale is estimated it is divided by the variance of the centered `X`, so
      a degenerate `X` (all points identical) produces non-finite values.
    '''
    *dims, N, _ = Y.shape
    # Point count as a tensor living next to the data; at least float32, so that
    # integer / half inputs keep being computed in float32.
    count_dtype = torch.promote_types(Y.dtype, torch.float32)
    N = torch.full((*dims, 1, 1), float(N), dtype=count_dtype, device=Y.device) # (*, 1, 1)

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True) # (*, 1, 1)

    # subtract mean
    my = Y.sum(dim=-2) / N[..., 0] # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :] # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        # Re-apply the weight: zero-weight rows became `-mean` after centering.
        y0 = y0 * weight
        x0 = x0 * weight

    # correlation
    C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3)

    # S = diag(1, 1, +-1) flips the last axis where U @ Vh would be a reflection.
    # It is built with the dtype/device of the SVD factors so the matmul below is valid.
    det_sign = torch.det(U) * torch.det(Vh.transpose(-1, -2)) # (*)
    flip = torch.where(det_sign < 0, -torch.ones_like(det_sign), torch.ones_like(det_sign))
    S = torch.eye(3, dtype=U.dtype, device=U.device).expand(*dims, 3, 3).clone()
    S[..., 2, 2] = flip

    R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3)

    if fixed_scale:
        s = torch.ones(*dims, 1, device=C.device, dtype=C.dtype)
    else:
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1)
        # trace(diag(D) @ S) without materializing the diagonal matrix.
        trace_DS = (D * torch.diagonal(S, dim1=-2, dim2=-1)).sum(dim=-1, keepdim=True) # (*, 1)
        s = trace_DS / var[..., 0] # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3)

    return s, R, t
def first_k_frames_align_to(
    S1 : torch.Tensor,
    S2 : torch.Tensor,
    k_f : int,
):
    '''
    Estimate the transformation between the first `k_f` frames of S1 and S2, then apply
    that single transformation to the whole S1 sequence to align it to S2.
    The code was modified from [SLAHMR](https://github.com/vye16/slahmr/blob/58518fec991877bc4911e260776589185b828fe9/slahmr/eval/tools.py#L68-L81).

    ### Args
    - `S1`: torch.Tensor, shape (..., L, N, 3)
        - The sequence to be aligned.
    - `S2`: torch.Tensor, shape (..., L, N, 3)
        - The target sequence. It is reshaped using the `L` and `N` read from `S1`.
    - `k_f`: int
        - The number of frames to use for alignment.

    ### Returns
    - `S1_aligned`: torch.Tensor, shape (..., L, N, 3)
        - The aligned S1.
    '''
    assert (len(S1.shape) >= 3 and len(S2.shape) >= 3), 'The input tensors must have at least 3 dimensions.'
    input_shape = S1.shape # (..., L, N, 3)
    L, N = input_shape[-3], input_shape[-2]

    # Collapse all leading dims into a single batch dim.
    seq1 = S1.reshape(-1, L, N, 3) # (B, L, N, 3)
    seq2 = S2.reshape(-1, L, N, 3) # (B, L, N, 3)
    B = seq1.shape[0]

    # 1. Merge the points of the first k_f frames into one cloud per batch element.
    head1 = seq1[:, :k_f].reshape(B, -1, 3) # (B, k_f * N, 3)
    head2 = seq2[:, :k_f].reshape(B, -1, 3) # (B, k_f * N, 3)

    # 2. Get the transformation that maps the S1 cloud onto the S2 cloud.
    scale, rot, trans = align_pcl(Y=head2, X=head1) # (B, 1), (B, 3, 3), (B, 3)
    scale = scale.reshape(B, 1, 1, 1) # (B, 1, 1, 1)
    trans = trans.reshape(B, 1, 1, 3) # (B, 1, 1, 3)

    # 3. Rotate, scale and translate every point of every frame of the sequence.
    rotated = torch.einsum('Bij,BLNj->BLNi', rot, seq1) # (B, L, N, 3)
    S1_aligned = scale * rotated + trans # (B, L, N, 3)

    return S1_aligned.reshape(input_shape) # (..., L, N, 3)