# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch

from imaginaire.evaluation.common import compute_nn


def _get_1nn_acc(data_x, data_y, k=1):
    """Compute nearest-neighbour two-sample classification accuracies.

    Rows of ``data_x`` are labelled 1 and rows of ``data_y`` are labelled 0.
    Every row of the concatenated data is then predicted as 1 when at least
    ``k / 2`` of the neighbours returned by ``compute_nn`` carry label 1.

    Args:
        data_x: Tensor whose rows receive label 1 (reported as "real").
        data_y: Tensor whose rows receive label 0 (reported as "fake").
            NOTE(review): presumably both are 2-D feature matrices with the
            same feature size; confirm against ``compute_nn``.
        k (int): Number of neighbour columns of the ``compute_nn`` index
            output that take part in the vote.

    Returns:
        dict: ``'1NN_acc'`` (overall accuracy), ``'1NN_acc_real'`` (fraction
        of label-1 rows predicted 1) and ``'1NN_acc_fake'`` (fraction of
        label-0 rows predicted 0), all as Python floats. The key names do not
        change with ``k``.
    """
    device = data_x.device
    num_x, num_y = data_x.size(0), data_y.size(0)

    # Only the index output is used; the first return value is discarded.
    # NOTE(review): nn_idx is assumed to hold, per row, the row indices of
    # its neighbours inside the concatenated tensor -- confirm in compute_nn.
    _, nn_idx = compute_nn(torch.cat((data_x, data_y), dim=0), k)

    label = torch.cat((torch.ones(num_x, device=device),
                       torch.zeros(num_y, device=device)))

    # Accumulate, for every sample, how many of its k neighbours have label 1.
    votes = torch.zeros(num_x + num_y, device=device)
    for col in range(k):
        votes = votes + label.index_select(0, nn_idx[:, col])

    # A tie (exactly k / 2 votes) counts as a label-1 prediction.
    threshold = torch.full_like(votes, float(k) / 2)
    pred = torch.ge(votes, threshold).float()

    is_real, is_fake = label, 1 - label
    said_real, said_fake = pred, 1 - pred

    tp = (said_real * is_real).sum()
    fp = (said_real * is_fake).sum()
    fn = (said_fake * is_real).sum()
    tn = (said_fake * is_fake).sum()

    return {
        '1NN_acc': torch.eq(label, pred).float().mean().item(),
        '1NN_acc_real': (tp / (tp + fn)).item(),
        '1NN_acc_fake': (tn / (tn + fp)).item(),
    }