Spaces:
Sleeping
Sleeping
from typing import Optional | |
import torch | |
from .utils import * | |
''' | |
All MPxE-like metrics are implemented here.
- Local metrics: the translation of the input motions should be removed (or may be removed automatically).
    - MPxE: call `eval_MPxE()`
    - PA-MPxE: call `eval_PA_MPxE()`
- Global metrics: the translation of the input motions should be kept.
    - G-MPxE: call `eval_MPxE()`
    - W2-MPxE: call `eval_Wk_MPxE()` with `k_f = 2`
    - WA-MPxE: call `eval_WA_MPxE()`
''' | |
def eval_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
):
    '''
    Mean Per <X> Error, where <X> is either joint positions (MPJPE) or mesh vertices (MPVE).

    For every element of the (possibly multi-dimensional) batch, the point-wise errors between
    `pred` and `gt` (as computed by `L2_error`) are averaged over the points axis and then
    multiplied by `scale`.

    ### Args
    - `pred`: torch.Tensor
        - shape = (...B, N, 3); B is the multi-dim batch size, N the number of points per element
        - predicted joint/vertex positions
    - `gt`: torch.Tensor
        - shape = (...B, N, 3)
        - ground-truth joint/vertex positions
    - `scale`: float, default = `m2mm`
        - multiplier applied to the averaged error (presumably a meters -> millimeters factor
          defined in `.utils`; confirm there)

    ### Returns
    - torch.Tensor
        - shape = (...B)
    '''
    # NOTE(review): assumes L2_error reduces the last (xyz) axis, yielding (...B, N).
    per_point_error = L2_error(pred, gt)
    # Average over the points axis to obtain one value per batch element.
    mean_error = per_point_error.mean(dim=-1)  # (...B,)
    return mean_error * scale
def eval_PA_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
):
    '''
    Procrustes-Aligned Mean Per <X> Error, where <X> is joint positions (PA-MPJPE) or mesh
    vertices (PA-MPVE).

    `pred` is first aligned to `gt` with `similarity_align_to` (per batch element), and the
    MPxE between the aligned prediction and the ground truth is then returned.

    ### Args
    - `pred`: torch.Tensor
        - shape = (...B, N, 3); B is the multi-dim batch size, N the number of points per element
        - predicted joint/vertex positions
    - `gt`: torch.Tensor
        - shape = (...B, N, 3)
        - ground-truth joint/vertex positions
    - `scale`: float, default = `m2mm`
        - multiplier forwarded to `eval_MPxE`

    ### Returns
    - torch.Tensor
        - shape = (...B)
    '''
    # Align the prediction onto the ground truth, then reuse the plain MPxE on the result.
    return eval_MPxE(similarity_align_to(pred, gt), gt, scale=scale)  # (...B,)
def eval_Wk_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
    k_f   : int   = 2,
):
    '''
    First-k-frames aligned (World-aligned, W-k) Mean Per <X> Error, where <X> is joint positions
    (W-MPJPE) or mesh vertices (W-MPVE).

    The predicted sequence is aligned to the ground truth using only its first `k_f` frames
    (via `first_k_frames_align_to`), and then the per-frame MPxE is computed.

    ### Args
    - `pred`: torch.Tensor
        - shape = (..., L, N, 3); L is the sequence length, N the number of points per frame
        - predicted joint/vertex positions
    - `gt`: torch.Tensor
        - shape = (..., L, N, 3), with the same L as `pred`
        - ground-truth joint/vertex positions
    - `scale`: float, default = `m2mm`
        - multiplier forwarded to `eval_MPxE`
    - `k_f`: int, default = 2
        - number of leading frames used for the alignment

    ### Returns
    - torch.Tensor
        - shape = (..., L)

    ### Raises
    - AssertionError: if `pred` and `gt` have different sequence lengths, or if the sequence
      has fewer than 2 frames.
    '''
    L_pred = pred.shape[-3]
    L_gt = gt.shape[-3]
    # Require matching lengths. The previous `max(...)` check let e.g. a length-1 sequence pass
    # against a longer one, where the per-frame error could then broadcast silently.
    assert L_pred == L_gt, f'Length of the sequence should be the same, but got {L_pred} and {L_gt}.'
    L = L_gt
    assert L >= 2, f'Length of the sequence should be at least 2, but got {L}.'
    # NOTE(review): k_f > L is not rejected here; what happens then depends on
    # `first_k_frames_align_to` -- confirm the intended behavior.
    # Align the whole prediction using only its first k_f frames.
    pred_aligned = first_k_frames_align_to(pred, gt, k_f)  # (..., L, N, 3)
    # Per-frame error of the aligned prediction.
    return eval_MPxE(pred_aligned, gt, scale)  # (..., L)
def eval_WA_MPxE(
    pred  : torch.Tensor,
    gt    : torch.Tensor,
    scale : float = m2mm,
):
    '''
    All-frames aligned (World-All aligned, WA) Mean Per <X> Error, where <X> is joint positions
    (WA-MPJPE) or mesh vertices (WA-MPVE).

    The predicted sequence is aligned to the ground truth using every frame, and then the
    per-frame MPxE is computed. This is `eval_Wk_MPxE` with `k_f` equal to the sequence length.

    ### Args
    - `pred`: torch.Tensor
        - shape = (..., L, N, 3); L is the sequence length, N the number of points per frame
        - predicted joint/vertex positions
    - `gt`: torch.Tensor
        - shape = (..., L, N, 3), with the same L as `pred`
        - ground-truth joint/vertex positions
    - `scale`: float, default = `m2mm`
        - multiplier forwarded to `eval_Wk_MPxE`

    ### Returns
    - torch.Tensor
        - shape = (..., L)

    ### Raises
    - AssertionError: if `pred` and `gt` have different sequence lengths.
    '''
    n_frames_pred, n_frames_gt = pred.shape[-3], gt.shape[-3]
    assert n_frames_pred == n_frames_gt, \
        f'Length of the sequence should be the same, but got {n_frames_pred} and {n_frames_gt}.'
    # Aligning on all frames == aligning on the first L frames.
    return eval_Wk_MPxE(pred, gt, scale, k_f=n_frames_gt)