Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,620 Bytes
17cd746 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import torch
import numpy as np
class NME:
    """Normalized Mean Error (NME) between predicted and ground-truth landmarks.

    For every sample the per-landmark Euclidean error is divided by a
    normalisation distance measured on the ground-truth landmarks, and the
    result is averaged over the landmarks of that sample.

    The normalisation distance is the distance between a "left" and a
    "right" reference point.  Each reference point is selected by an index
    that is either

    * a single integer (Python ``int`` or NumPy integer): that landmark is
      used directly, or
    * a list / tuple / 1-D index array: the mean of the selected landmarks
      is used.

    The two indices are resolved independently, so they do not have to be
    of the same kind.
    """

    def __init__(self, nme_left_index, nme_right_index):
        """Store the indices of the left and right reference landmarks.

        Args:
            nme_left_index: Integer or sequence of integers selecting the
                left reference landmark(s).
            nme_right_index: Integer or sequence of integers selecting the
                right reference landmark(s).
        """
        self.nme_left_index = nme_left_index
        self.nme_right_index = nme_right_index

    def __repr__(self):
        return "NME()"

    @staticmethod
    def _reference_point(landmarks, index):
        """Return the reference point selected by ``index`` from ``landmarks``.

        Raises:
            NotImplementedError: If ``index`` is neither an integer nor a
                list / tuple / NumPy array.
        """
        if isinstance(index, (int, np.integer)):
            return landmarks[index]
        if isinstance(index, (list, tuple, np.ndarray)):
            # A tuple inside ``[...]`` would be read by NumPy as a
            # multi-dimensional index, so turn it into a list first.
            rows = list(index) if isinstance(index, tuple) else index
            return landmarks[rows, :].mean(0)
        raise NotImplementedError(
            "unsupported landmark index type: {}".format(type(index).__name__)
        )

    @staticmethod
    def _to_numpy(batch):
        """Convert a tensor-like or array-like batch to a NumPy array."""
        # Duck-typed so that anything exposing the tensor
        # ``detach().cpu().numpy()`` chain is handled; ``detach()`` replaces
        # the legacy ``.data`` attribute access.
        if hasattr(batch, "detach"):
            return batch.detach().cpu().numpy()
        return np.asarray(batch)

    def get_norm_distance(self, landmarks):
        """Return the normalisation distance for one sample.

        Args:
            landmarks: Array of landmark coordinates for a single sample,
                indexed as ``landmarks[point, coordinate]``.

        Returns:
            Euclidean distance between the right and the left reference
            points of ``landmarks``.

        Raises:
            NotImplementedError: If either stored index has an unsupported
                type.
        """
        right_point = self._reference_point(landmarks, self.nme_right_index)
        left_point = self._reference_point(landmarks, self.nme_left_index)
        return np.linalg.norm(right_point - left_point)

    def test(self, label_pd, label_gt):
        """Compute the NME of every sample in a batch.

        Args:
            label_pd: Predicted landmarks, one entry per sample along the
                first axis.  Either a tensor exposing
                ``detach().cpu().numpy()`` or anything ``np.asarray``
                accepts.
            label_gt: Ground-truth landmarks with the same layout as
                ``label_pd``.

        Returns:
            A list with one NME value per sample of ``label_gt``.

        Raises:
            NotImplementedError: If either stored index has an unsupported
                type.

        Note:
            A zero normalisation distance is not guarded against; the
            division then follows NumPy's floating-point semantics.
        """
        label_pd = self._to_numpy(label_pd)
        label_gt = self._to_numpy(label_gt)

        nme_list = []
        for i, landmarks_gt in enumerate(label_gt):
            # Index the predictions explicitly so that a prediction batch
            # shorter than the ground truth still fails loudly.
            landmarks_pv = label_pd[i]
            norm_distance = self.get_norm_distance(landmarks_gt)
            per_point_error = np.linalg.norm(landmarks_pv - landmarks_gt, axis=1)
            nme_list.append((per_point_error / norm_distance).mean())
        return nme_list
|