Spaces:
Build error
Build error
r""" Evaluates CHMNet with PCK """
import torch | |
class Evaluator:
    r""" Computes evaluation metrics of PCK """

    @classmethod
    def initialize(cls, alpha):
        r""" Store the PCK tolerance factor(s) on the class.

        Args:
            alpha: sequence of tolerance factors accepted by ``torch.tensor``
                (e.g. ``[0.05, 0.1, 0.15]``). It is stored as a column
                tensor (one row per factor) in ``cls.alpha``.
        """
        cls.alpha = torch.tensor(alpha).unsqueeze(1)

    @classmethod
    def evaluate(cls, prd_kps, batch):
        r""" Compute percentage of correct key-points (PCK) for every alpha set by ``initialize``.

        Args:
            prd_kps: iterable of per-sample predicted key-point tensors, with
                coordinates along dim 0 and key-points along dim 1.
            batch: mapping that provides ``'trg_kps'`` (target key-points,
                same layout as ``prd_kps``), ``'pckthres'`` (per-sample
                threshold tensor) and ``'n_pts'`` (per-sample number of
                valid key-points).

        Returns:
            ``{'pck': [...]}`` with one entry per sample: a tensor holding
            one PCK value per alpha, or a 0-dim tensor when only a single
            alpha is configured.
        """
        pcks = []
        for idx, (pk, tk) in enumerate(zip(prd_kps, batch['trg_kps'])):
            pckthres = batch['pckthres'][idx]
            npt = batch['n_pts'][idx]

            # Only the first `npt` key-points are valid; the rest is ignored.
            # Distinct local names keep the `prd_kps` argument untouched.
            valid_prd = pk[:, :npt]
            valid_trg = tk[:, :npt]

            # Euclidean distance per key-point, replicated once per alpha.
            l2dist = (valid_prd - valid_trg).pow(2).sum(dim=0).pow(0.5)
            l2dist = l2dist.unsqueeze(0).repeat(len(cls.alpha), 1)

            # `cls.alpha` is created on the default device; align it with the
            # distances so the product below also works for non-CPU inputs.
            alpha = cls.alpha.to(l2dist.device)
            thres = pckthres.expand_as(l2dist).float() * alpha

            pck = torch.le(l2dist, thres).sum(dim=1) / float(npt)
            if len(pck) == 1:
                # A single alpha yields a scalar instead of a 1-element tensor.
                pck = pck[0]
            pcks.append(pck)

        return {'pck': pcks}