# Refer https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html
# from torchmetrics.image.fid import FrechetInceptionDistance
from PIL import Image
from starvector.metrics.base_metric import BaseMetric
import torch
import clip
from torch.nn.functional import adaptive_avg_pool2d
from starvector.metrics.inception import InceptionV3
import numpy as np
from tqdm import tqdm
from scipy import linalg
import torchvision.transforms as TF

class FIDCalculator(BaseMetric):
    def __init__(self, model_name='InceptionV3'):
        self.class_name = self.__class__.__name__
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_name = model_name
        if self.model_name == 'ViT-B/32':
            self.dims = 512
            model, preprocess = clip.load('ViT-B/32', device=self.device)
        elif self.model_name == 'InceptionV3':
            self.dims = 2048
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
            model = InceptionV3([block_idx])
            preprocess = TF.Compose([TF.ToTensor()])
        else:
            raise ValueError(f'Unsupported model_name: {model_name}')

        self.model = model.to(self.device)
        self.model.eval()  # inference mode so dropout/batch-norm do not perturb the features
        self.preprocess = preprocess

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """Numpy implementation of the Frechet Distance.
        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
                d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

        Stable version by Dougal J. Sutherland.

        Params:
        -- mu1   : Numpy array containing the activations of a layer of the
                inception net (like returned by the function 'get_activations')
                for generated samples.
        -- mu2   : The sample mean over activations, precalculated on a
                representative data set.
        -- sigma1: The covariance matrix over activations for generated samples.
        -- sigma2: The covariance matrix over activations, precalculated on a
                representative data set.

        Returns:
        --   : The Frechet Distance.
        """

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'

        diff = mu1 - mu2

        # Product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return (diff.dot(diff) + np.trace(sigma1)
                + np.trace(sigma2) - 2 * tr_covmean)
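
    # Sanity check of the formula above: for identical Gaussians (mu1 == mu2,
    # sigma1 == sigma2 == C), diff is zero and sqrt(C @ C) == C for PSD C, so
    # d^2 = Tr(C + C - 2C) = 0, i.e. identical feature distributions give FID 0.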

    def get_activations(self, images):
        """Compute feature activations for a list of images with the chosen backbone."""
        dataset = ImageDataset(images, self.preprocess)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=False, num_workers=4)
        pred_arr = np.empty((len(images), self.dims))
        start_idx = 0
        for batch in tqdm(dataloader):
            batch = batch.to(self.device)
            
            with torch.no_grad():
                if self.model_name == 'ViT-B/32':
                    pred = self.model.encode_image(batch).cpu().numpy()
                elif self.model_name == 'InceptionV3':
                    pred = self.model(batch)[0]

                    # If model output is not scalar, apply global spatial average pooling.
                    # This happens if you choose a dimensionality not equal 2048.
                    if pred.size(2) != 1 or pred.size(3) != 1:
                        pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

                    pred = pred.squeeze(3).squeeze(2).cpu().numpy()
                pred_arr[start_idx:start_idx + pred.shape[0]] = pred
                start_idx = start_idx + pred.shape[0]

        return pred_arr

    def calculate_activation_statistics(self, images):
        act = self.get_activations(images)
        mu = np.mean(act, axis=0)
        sigma = np.cov(act, rowvar=False)
        return mu, sigma

    def pil_images_to_tensor(self, images_list):
        """Convert a list of PIL Images to a single BxCxHxW torch.Tensor on the active device."""
        tensors_list = [self.preprocess(img) for img in images_list]
        return torch.stack(tensors_list).to(self.device)  # BxCxHxW format

    def calculate_score(self, batch):
        m1, s1 = self.calculate_activation_statistics(batch['gt_im'])
        m2, s2 = self.calculate_activation_statistics(batch['gen_im'])
        fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
        return fid_value

    def reset(self):
        pass

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, images, processor=None):
        self.images = images
        self.processor = processor

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img = self.images[i]
        if self.processor is not None:
            img = self.processor(img)
        return img
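

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes two equally sized lists of RGB PIL images; with this few samples
# the covariance estimate is degenerate, so the printed value is just a smoke test,
# not a meaningful FID score.
if __name__ == '__main__':
    gt_images = [Image.new('RGB', (299, 299), color=(i, i, i)) for i in range(8)]
    gen_images = [Image.new('RGB', (299, 299), color=(i, i + 5, i)) for i in range(8)]

    calculator = FIDCalculator(model_name='InceptionV3')
    fid_value = calculator.calculate_score({'gt_im': gt_images, 'gen_im': gen_images})
    print(f'FID: {fid_value:.4f}')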