Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
import torch | |
from imaginaire.evaluation.common import compute_nn | |
def _get_1nn_acc(data_x, data_y, k=1): | |
device = data_x.device | |
n0 = data_x.size(0) | |
n1 = data_y.size(0) | |
data_all = torch.cat((data_x, data_y), dim=0) | |
val, idx = compute_nn(data_all, k) | |
label = torch.cat((torch.ones(n0, device=device), | |
torch.zeros(n1, device=device))) | |
count = torch.zeros(n0 + n1, device=device) | |
for i in range(0, k): | |
count = count + label.index_select(0, idx[:, i]) | |
pred = torch.ge(count, (float(k) / 2) * | |
torch.ones(n0 + n1, device=device)).float() | |
tp = (pred * label).sum() | |
fp = (pred * (1 - label)).sum() | |
fn = ((1 - pred) * label).sum() | |
tn = ((1 - pred) * (1 - label)).sum() | |
acc_r = (tp / (tp + fn)).item() | |
acc_f = (tn / (tn + fp)).item() | |
acc = torch.eq(label, pred).float().mean().item() | |
return {'1NN_acc': acc, | |
'1NN_acc_real': acc_r, | |
'1NN_acc_fake': acc_f} | |