'''
All MPxE-like metrics are implemented here.

- Local metrics: the input motion's translation should be removed
  (or may be removed automatically).
    - MPxE:    call `eval_MPxE()`
    - PA-MPxE: call `eval_PA_MPxE()`
- Global metrics: the input motion's translation should be kept.
    - G-MPxE:  call `eval_MPxE()`
    - W2-MPxE: call `eval_Wk_MPxE()` with `k_f = 2` (the default)
    - WA-MPxE: call `eval_WA_MPxE()`
'''

from typing import Optional

import torch

from .utils import *


def eval_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
) -> torch.Tensor:
    '''
    Calculate the Mean Per-point Error between predicted and ground-truth points.
    The points may be joint positions (MPJPE) or mesh vertices (MPVE).
    One error value is produced for each element of the (multi-dim) batch.

    ### Args
    - `pred`: torch.Tensor
        - shape = (...B, N, 3), where B is the multi-dim batch size and N is
          the number of points per batch element
        - the predicted joint/vertex positions
    - `gt`: torch.Tensor
        - shape = (...B, N, 3), same layout as `pred`
        - the ground-truth joint/vertex positions
    - `scale`: float, default = `m2mm`
        - multiplier applied to the averaged error (presumably converting
          metres to millimetres — see `m2mm` in `.utils`)

    ### Returns
    - torch.Tensor
        - shape = (...B), the mean error over the N points of each batch element
        - shape = () when the inputs have no batch dims, i.e. are (N, 3)
    '''
    # Assumes `L2_error` yields per-point distances of shape (...B, N);
    # averaging over the last dim reduces them to one value per batch element.
    return L2_error(pred, gt).mean(dim=-1) * scale  # (...B,)


def eval_PA_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
) -> torch.Tensor:
    '''
    Calculate the Procrustes-Aligned Mean Per-point Error.
    The points may be joint positions (PA-MPJPE) or mesh vertices (PA-MPVE).
    `pred` is first similarity-aligned (Procrustes) to `gt` independently for
    each batch element, and then the per-element MPxE is computed.

    ### Args
    - `pred`: torch.Tensor
        - shape = (...B, N, 3), where B is the multi-dim batch size and N is
          the number of points per batch element
        - the predicted joint/vertex positions
    - `gt`: torch.Tensor
        - shape = (...B, N, 3), same layout as `pred`
        - the ground-truth joint/vertex positions
    - `scale`: float, default = `m2mm`
        - multiplier applied to the averaged error

    ### Returns
    - torch.Tensor
        - shape = (...B), the aligned mean error of each batch element
        - shape = () when the inputs have no batch dims, i.e. are (N, 3)
    '''
    # Align the prediction onto the ground truth, then measure the residual error.
    pred_aligned = similarity_align_to(pred, gt)  # (...B, N, 3)
    return eval_MPxE(pred_aligned, gt, scale)  # (...B,)


def eval_Wk_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
    k_f   : int   = 2,
) -> torch.Tensor:
    '''
    Calculate the first-k-frames-aligned (World-aligned) Mean Per-point Error.
    The points may be joint positions (Wk-MPJPE) or mesh vertices (Wk-MPVE).
    The alignment is estimated from the first `k_f` frames of the sequence and
    applied to the whole `pred` sequence; the per-frame MPxE is then computed.

    ### Args
    - `pred`: torch.Tensor
        - shape = (..., L, N, 3), where L is the sequence length and N is the
          number of points per frame
        - the predicted joint/vertex positions
    - `gt`: torch.Tensor
        - shape = (..., L, N, 3), same layout as `pred`
        - the ground-truth joint/vertex positions
    - `scale`: float, default = `m2mm`
        - multiplier applied to the averaged error
    - `k_f`: int, default = 2
        - the number of leading frames used to estimate the alignment

    ### Returns
    - torch.Tensor
        - shape = (..., L), the error of every frame after alignment

    ### Raises
    - AssertionError: if the longer of the two sequences has fewer than 2 frames.
    '''
    # The check uses the longer of the two sequence lengths.
    # NOTE(review): it does not compare against `k_f`; what happens when `k_f`
    # exceeds L is up to `first_k_frames_align_to` — confirm.
    L = max(pred.shape[-3], gt.shape[-3])
    assert L >= 2, f'Length of the sequence should be at least 2, but got {L}.'

    # Align `pred` to `gt` using the first `k_f` frames.
    pred_aligned = first_k_frames_align_to(pred, gt, k_f)  # (..., L, N, 3)
    # Measure the per-frame error of the aligned sequence.
    return eval_MPxE(pred_aligned, gt, scale)  # (..., L)


def eval_WA_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
) -> torch.Tensor:
    '''
    Calculate the all-frames-aligned (World-All-aligned) Mean Per-point Error.
    The points may be joint positions (WA-MPJPE) or mesh vertices (WA-MPVE).
    The alignment is estimated from every frame of the sequence, and then the
    per-frame MPxE is computed.

    ### Args
    - `pred`: torch.Tensor
        - shape = (..., L, N, 3), where L is the sequence length and N is the
          number of points per frame
        - the predicted joint/vertex positions
    - `gt`: torch.Tensor
        - shape = (..., L, N, 3), same layout as `pred`
        - the ground-truth joint/vertex positions
    - `scale`: float, default = `m2mm`
        - multiplier applied to the averaged error

    ### Returns
    - torch.Tensor
        - shape = (..., L), the error of every frame after alignment

    ### Raises
    - AssertionError: if `pred` and `gt` have different sequence lengths, or
      (through `eval_Wk_MPxE`) if the sequence has fewer than 2 frames.
    '''
    L_pred = pred.shape[-3]
    L_gt = gt.shape[-3]
    assert (L_pred == L_gt), f'Length of the sequence should be the same, but got {L_pred} and {L_gt}.'

    # WA-MPxE is simply Wk-MPxE with the alignment window covering all L frames.
    return eval_Wk_MPxE(pred, gt, scale, k_f=L_gt)