# HSMR/lib/evaluation/metrics/mpxe_like.py
from typing import Optional
import torch
from .utils import *

'''
All MPxE-like metrics are implemented here.
- Local Metrics: the input motion's translation should be removed (or may be removed automatically).
    - MPxE: call `eval_MPxE()`
    - PA-MPxE: call `eval_PA_MPxE()`
- Global Metrics: the input motion's translation should be kept.
    - G-MPxE: call `eval_MPxE()`
    - W2-MPxE: call `eval_Wk_MPxE()` with `k_f = 2`
    - WA-MPxE: call `eval_WA_MPxE()`

A minimal usage sketch is provided at the bottom of this file.
'''

def eval_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
):
    '''
    Calculate the Mean Per <X> Error. <X> might be joint positions (MPJPE) or vertices (MPVE).
    The result is the MPxE of each element of the multi-dim batch.

    ### Args
    - `pred`: torch.Tensor
        - shape = (...B, N, 3), where B is the multi-dim batch size, N is the points count in one batch
        - the predicted joints/vertices position data
    - `gt`: torch.Tensor
        - shape = (...B, N, 3), where B is the multi-dim batch size, N is the points count in one batch
        - the ground truth joints/vertices position data
    - `scale`: float, default = `m2mm`
        - the factor applied to the error, e.g. `m2mm` converts metres to millimetres

    ### Returns
    - torch.Tensor
        - shape = (...B,), or () if there are no batch dimensions
        - the MPxE of each batch element
    '''
    # Calculate the MPxE: per-point L2 error averaged over the N points of each batch element.
    ret = L2_error(pred, gt).mean(dim=-1) * scale  # (...B,)
    return ret
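
# [Illustrative sketch] `L2_error` is imported from `.utils` and is not shown in this file. Judging
# from the shapes in the docstring above, it is assumed to compute the per-point Euclidean distance;
# a one-line reference under that assumption (the helper name below is hypothetical, not part of the repo):
def _l2_error_reference(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    ''' pred, gt: (...B, N, 3) -> per-point distances of shape (...B, N). '''
    return torch.linalg.norm(pred - gt, dim=-1)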

def eval_PA_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
):
    '''
    Calculate the Procrustes-Aligned Mean Per <X> Error. <X> might be joint positions (PA-MPJPE) or
    vertices (PA-MPVE). The predictions are first Procrustes-aligned to the ground truth, and then
    the MPxE is calculated.
    The result is the PA-MPxE of each element of the multi-dim batch.

    ### Args
    - `pred`: torch.Tensor
        - shape = (...B, N, 3), where B is the multi-dim batch size, N is the points count in one batch
        - the predicted joints/vertices position data
    - `gt`: torch.Tensor
        - shape = (...B, N, 3), where B is the multi-dim batch size, N is the points count in one batch
        - the ground truth joints/vertices position data
    - `scale`: float, default = `m2mm`
        - the factor applied to the error, e.g. `m2mm` converts metres to millimetres

    ### Returns
    - torch.Tensor
        - shape = (...B,), or () if there are no batch dimensions
        - the PA-MPxE of each batch element
    '''
    # Perform Procrustes alignment of the predictions onto the ground truth.
    pred_aligned = similarity_align_to(pred, gt)  # (...B, N, 3)
    # Calculate the PA-MPxE.
    return eval_MPxE(pred_aligned, gt, scale)  # (...B,)
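
# [Illustrative sketch] `similarity_align_to` is imported from `.utils` and is not shown in this file.
# It is assumed to solve the similarity Procrustes problem, i.e. to find the scale, rotation and
# translation that best map `pred` onto `gt` in the least-squares sense (Umeyama's method). A minimal
# batched reference under that assumption; both helper names below are hypothetical, not part of the repo.
def _fit_similarity_reference(pred: torch.Tensor, gt: torch.Tensor):
    ''' Fit scale `s`, rotation `R` and the centroids of pred/gt, both of shape (...B, N, 3). '''
    mu_p = pred.mean(dim=-2, keepdim=True)      # (...B, 1, 3)
    mu_g = gt.mean(dim=-2, keepdim=True)        # (...B, 1, 3)
    p_c, g_c = pred - mu_p, gt - mu_g           # centered point sets
    cov = g_c.transpose(-1, -2) @ p_c           # (...B, 3, 3) cross-covariance
    U, S, Vh = torch.linalg.svd(cov)
    d = torch.ones_like(S)                      # (...B, 3)
    d[..., -1] = torch.sign(torch.det(U @ Vh))  # avoid reflections, keep det(R) = +1
    R = U @ torch.diag_embed(d) @ Vh            # (...B, 3, 3) optimal rotation
    s = (S * d).sum(dim=-1) / (p_c ** 2).sum(dim=(-1, -2))  # (...B,) optimal scale
    return s, R, mu_p, mu_g

def _similarity_align_reference(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    ''' Apply the fitted similarity transform to `pred`: (...B, N, 3) -> (...B, N, 3). '''
    s, R, mu_p, mu_g = _fit_similarity_reference(pred, gt)
    return s[..., None, None] * ((pred - mu_p) @ R.transpose(-1, -2)) + mu_g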

def eval_Wk_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
    k_f   : int   = 2,
):
    '''
    Calculate the first-k-frames-aligned (World-aligned) Mean Per <X> Error. <X> might be joint
    positions (Wk-MPJPE) or vertices (Wk-MPVE). The predictions are aligned to the ground truth using
    the first `k_f` frames, and then the per-frame MPxE is calculated.
    The result is the MPxE of each frame of the sequence.

    ### Args
    - `pred`: torch.Tensor
        - shape = (..., L, N, 3), where L is the length of the sequence, N is the points count in one frame
        - the predicted joints/vertices position data
    - `gt`: torch.Tensor
        - shape = (..., L, N, 3), where L is the length of the sequence, N is the points count in one frame
        - the ground truth joints/vertices position data
    - `scale`: float, default = `m2mm`
        - the factor applied to the error, e.g. `m2mm` converts metres to millimetres
    - `k_f`: int, default = 2
        - the number of leading frames to use for alignment

    ### Returns
    - torch.Tensor
        - shape = (..., L), or (L,) if there are no extra batch dimensions
        - the MPxE of each frame
    '''
    L = max(pred.shape[-3], gt.shape[-3])
    assert L >= k_f, f'Length of the sequence should be at least k_f = {k_f}, but got {L}.'
    # Perform alignment on the first k_f frames.
    pred_aligned = first_k_frames_align_to(pred, gt, k_f)  # (..., L, N, 3)
    # Calculate the Wk-MPxE.
    return eval_MPxE(pred_aligned, gt, scale)  # (..., L)
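
# [Illustrative sketch] `first_k_frames_align_to` is imported from `.utils` and is not shown in this
# file. It is assumed to fit a single transform on the first `k_f` frames only and then apply it to
# the whole sequence, so that drift in the remaining frames is still penalized. A rough sketch under
# that assumption, reusing the hypothetical `_fit_similarity_reference` above (the real helper may
# differ, e.g. it may use a rigid transform without scale):
def _first_k_frames_align_reference(pred: torch.Tensor, gt: torch.Tensor, k_f: int) -> torch.Tensor:
    ''' pred, gt: (..., L, N, 3) -> pred aligned by a transform fitted on frames [0, k_f). '''
    # Pool the points of the first k_f frames into one point set per batch element: (..., k_f * N, 3).
    p_k = pred[..., :k_f, :, :].flatten(-3, -2)
    g_k = gt[..., :k_f, :, :].flatten(-3, -2)
    s, R, mu_p, mu_g = _fit_similarity_reference(p_k, g_k)
    # Broadcast the fitted transform over the frame dimension and apply it to all L frames.
    pred_c = pred - mu_p[..., None, :, :]  # (..., L, N, 3)
    aligned = s[..., None, None, None] * (pred_c @ R[..., None, :, :].transpose(-1, -2))
    return aligned + mu_g[..., None, :, :]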

def eval_WA_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
):
    '''
    Calculate the all-frames-aligned (World-All-aligned) Mean Per <X> Error. <X> might be joint
    positions (WA-MPJPE) or vertices (WA-MPVE). The predictions are aligned to the ground truth using
    all frames of the sequence, and then the per-frame MPxE is calculated.
    The result is the MPxE of each frame of the sequence.

    ### Args
    - `pred`: torch.Tensor
        - shape = (..., L, N, 3), where L is the length of the sequence, N is the points count in one frame
        - the predicted joints/vertices position data
    - `gt`: torch.Tensor
        - shape = (..., L, N, 3), where L is the length of the sequence, N is the points count in one frame
        - the ground truth joints/vertices position data
    - `scale`: float, default = `m2mm`
        - the factor applied to the error, e.g. `m2mm` converts metres to millimetres

    ### Returns
    - torch.Tensor
        - shape = (..., L), or (L,) if there are no extra batch dimensions
        - the MPxE of each frame
    '''
    L_pred = pred.shape[-3]
    L_gt   = gt.shape[-3]
    assert (L_pred == L_gt), f'Lengths of the sequences should be the same, but got {L_pred} and {L_gt}.'
    # WA-MPxE is just Wk-MPxE with k_f = L.
    return eval_Wk_MPxE(pred, gt, scale, L_gt)
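
# [Illustrative sketch] A minimal smoke test of the metrics above on random data. Shapes follow the
# docstrings; `m2mm` is assumed to be the metre-to-millimetre factor defined in `.utils`. Run it as a
# module (e.g. `python -m lib.evaluation.metrics.mpxe_like` from the project root) so that the
# relative import resolves.
if __name__ == '__main__':
    L, N = 16, 24                              # sequence length, points per frame
    pred = torch.randn(L, N, 3)
    gt   = pred + 0.01 * torch.randn(L, N, 3)  # perturbed copy used as a fake ground truth
    print('MPxE    :', eval_MPxE(pred, gt).mean().item())
    print('PA-MPxE :', eval_PA_MPxE(pred, gt).mean().item())
    print('W2-MPxE :', eval_Wk_MPxE(pred, gt, k_f=2).mean().item())
    print('WA-MPxE :', eval_WA_MPxE(pred, gt).mean().item())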