File size: 1,195 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch

from imaginaire.evaluation.common import compute_nn


def _get_1nn_acc(data_x, data_y, k=1):
    """Compute the k-nearest-neighbour two-sample classification accuracy.

    Samples from ``data_x`` are labelled 1 ("real") and samples from
    ``data_y`` are labelled 0 ("fake"). Both sets are concatenated, the
    neighbour indices of every sample are obtained from ``compute_nn``, and
    each sample is predicted "real" when at least half of its ``k`` nearest
    neighbours are real. Ties (only possible for even ``k``) therefore count
    as "real".

    NOTE(review): a meaningful score requires ``compute_nn`` to exclude each
    sample from its own neighbour list; that helper is not visible here —
    confirm.

    Args:
        data_x (tensor): Features of the first ("real") set. Concatenated
            with ``data_y`` along dim 0, presumably ``(n0, feature_dim)``.
        data_y (tensor): Features of the second ("fake") set, on the same
            device as ``data_x``.
        k (int): Number of nearest neighbours that vote on each prediction.

    Returns:
        dict: ``'1NN_acc'``, the accuracy over all samples;
        ``'1NN_acc_real'``, the accuracy over the ``data_x`` samples; and
        ``'1NN_acc_fake'``, the accuracy over the ``data_y`` samples.

    Raises:
        ValueError: If ``k < 1`` or either set is empty. The per-set
            accuracies would otherwise be NaN or meaningless.
    """
    n0 = data_x.size(0)
    n1 = data_y.size(0)
    if k < 1:
        raise ValueError('k must be a positive integer, got {}.'.format(k))
    if n0 == 0 or n1 == 0:
        raise ValueError('Both feature sets must be non-empty to compute '
                         'the 1-NN accuracy (got {} and {} samples).'
                         .format(n0, n1))

    device = data_x.device
    data_all = torch.cat((data_x, data_y), dim=0)
    _, idx = compute_nn(data_all, k)
    # 1 marks samples from data_x ("real"), 0 samples from data_y ("fake").
    label = torch.cat((torch.ones(n0, device=device),
                       torch.zeros(n1, device=device)))

    # Number of "real" samples among the first k neighbour columns of each
    # row. Slicing keeps this correct even if compute_nn returns extra
    # columns.
    count = label[idx[:, :k]].sum(dim=1)
    # Majority vote; a tie is resolved in favour of "real".
    pred = (count >= k / 2).float()

    tp = (pred * label).sum()
    fp = (pred * (1 - label)).sum()
    fn = ((1 - pred) * label).sum()
    tn = ((1 - pred) * (1 - label)).sum()
    # tp + fn == n0 and tn + fp == n1, both non-zero after validation.
    acc_r = (tp / (tp + fn)).item()
    acc_f = (tn / (tn + fp)).item()
    acc = torch.eq(label, pred).float().mean().item()

    return {'1NN_acc': acc,
            '1NN_acc_real': acc_r,
            '1NN_acc_fake': acc_f}