# interfacegan_pp / models /pggan_tf_official /metrics /frechet_inception_distance.py
# ybelkada's picture
# commit files
# 4d6b877
#!/usr/bin/env python3
#
# Copyright 2017 Martin Heusel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from the original implementation by Martin Heusel.
# Source https://github.com/bioinf-jku/TTUR/blob/master/fid.py
''' Calculates the Frechet Inception Distance (FID) to evaluate GANs.
The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.
When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).
The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.
See --help to see further details.
'''
from __future__ import absolute_import, division, print_function
import numpy as np
import scipy as sp
import os
import gzip, pickle
import tensorflow as tf
from scipy.misc import imread
import pathlib
import urllib
class InvalidFIDException(Exception):
    """Exception type reserved for invalid FID results (e.g. NaN).

    NOTE(review): the only visible raise site for this exception is
    commented out, so nothing in this file currently raises it.
    """
def create_inception_graph(pth):
    """Read a serialized GraphDef from `pth` and import it.

    The imported nodes are registered under the name 'FID_Inception_Net',
    which is the prefix the rest of this module uses to look up tensors.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(pth, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)
    tf.import_graph_def(graph_def, name='FID_Inception_Net')
#-------------------------------------------------------------------------------
# code for handling inception net derived from
# https://github.com/openai/improved-gan/blob/master/inception_score/model.py
def _get_inception_layer(sess):
    """Return the pool_3 tensor after relaxing batch dimensions in its graph.

    Every operation output in the graph that has a known static shape with a
    leading dimension of exactly 1 gets that leading dimension replaced by
    None, so the network can be fed batches of more than one image.
    """
    pool3 = sess.graph.get_tensor_by_name('FID_Inception_Net/pool_3:0')
    for op in pool3.graph.get_operations():
        for out in op.outputs:
            static_shape = out.get_shape()
            if static_shape._dims is None:
                # Unknown rank: nothing to patch.
                continue
            dims = [dim.value for dim in static_shape]
            relaxed = [None if (idx == 0 and size == 1) else size
                       for idx, size in enumerate(dims)]
            try:
                out._shape = tf.TensorShape(relaxed)
            except ValueError:
                # Fallback kept for compatibility with tensorflow 1.6.0,
                # where the private attribute is named differently.
                out._shape_val = tf.TensorShape(relaxed)
    return pool3
#-------------------------------------------------------------------------------
def get_activations(images, sess, batch_size=50, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- images      : Numpy array of dimension (n_images, hi, wi, 3). The values
                     are expected to lie between 0 and 255.
    -- sess        : current session
    -- batch_size  : the images numpy array is split into batches with batch
                     size batch_size. A reasonable batch size depends on the
                     available hardware. Capped to n_images (with a warning)
                     when it is larger.
    -- verbose     : If set to True, the number of the batch currently being
                     propagated is printed.
    Returns:
    -- A numpy array of dimension (n_batches * batch_size, 2048) holding the
       flattened pool_3 activations. Only full batches are evaluated, so up
       to batch_size - 1 trailing images are skipped when n_images is not a
       multiple of batch_size.
    Raises:
    -- ValueError if `images` is empty or `batch_size` is smaller than 1.
    """
    d0 = images.shape[0]
    # Fail early with a clear message: an empty input would otherwise set
    # batch_size to 0 and die with a ZeroDivisionError further down.
    if d0 == 0:
        raise ValueError("get_activations() requires at least one image")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

    inception_layer = _get_inception_layer(sess)
    if batch_size > d0:
        print("warning: batch size is bigger than the data size. setting batch size to data size")
        batch_size = d0
    n_batches = d0 // batch_size
    n_used_imgs = n_batches * batch_size
    pred_arr = np.empty((n_used_imgs, 2048))
    for i in range(n_batches):
        if verbose:
            print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
        start = i * batch_size
        end = start + batch_size
        batch = images[start:end]
        pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
        # Flatten each sample's activation into one row of the result.
        pred_arr[start:end] = pred.reshape(batch_size, -1)
    if verbose:
        print(" done")
    return pred_arr
#-------------------------------------------------------------------------------
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Params:
    -- mu1    : Mean vector of the first distribution (numpy array).
    -- sigma1 : Covariance matrix of the first distribution (numpy array).
    -- mu2    : Mean vector of the second distribution (numpy array).
    -- sigma2 : Covariance matrix of the second distribution (numpy array).
    Returns:
    -- dist   : The real part of the Frechet Distance. The matrix square root
                can be complex-valued; any imaginary part is discarded.
    Note:
    -- The former NaN check (raising InvalidFIDException) is intentionally
       disabled; a NaN result is returned to the caller as is.
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee on every SciPy version that `scipy.linalg` is loaded.
    from scipy import linalg

    mean_term = np.square(mu1 - mu2).sum()
    cov_product = np.dot(sigma1, sigma2)
    try:
        # disp=False suppresses the diagnostic printout and returns
        # (sqrtm, error_estimate) instead of just the matrix.
        covmean, _ = linalg.sqrtm(cov_product, disp=False)
    except TypeError:
        # NOTE(review): recent SciPy releases deprecate the `disp` argument;
        # fall back to the plain call where it is no longer accepted.
        covmean = linalg.sqrtm(cov_product)
    dist = mean_term + np.trace(sigma1 + sigma2 - 2 * covmean)
    return np.real(dist)
#-------------------------------------------------------------------------------
def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
    """Compute the mean and covariance of pool_3 activations used by the FID.

    Params:
    -- images      : Numpy array of dimension (n_images, hi, wi, 3). The values
                     must lie between 0 and 255.
    -- sess        : current session
    -- batch_size  : forwarded to get_activations; number of images evaluated
                     per session run.
    -- verbose     : forwarded to get_activations; enables progress output.
    Returns:
    -- mu    : Mean over samples of the activations returned by
               get_activations.
    -- sigma : Covariance matrix of those activations (samples are rows).
    """
    activations = get_activations(images, sess, batch_size, verbose)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
# The following functions aren't needed for calculating the FID
# they're just here to make this module work as a stand-alone script
# for calculating FID scores
#-------------------------------------------------------------------------------
def check_or_download_inception(inception_path):
    """Return the path of the Inception graph file, downloading it if absent.

    Params:
    -- inception_path : Directory expected to contain
                        'classify_image_graph_def.pb'. If None, '/tmp' is used.
    Returns:
    -- The path to 'classify_image_graph_def.pb' inside that directory, as a
       string. When the file does not exist yet, the Inception archive is
       downloaded and the graph file is extracted into the directory first.
    """
    INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    if inception_path is None:
        inception_path = '/tmp'
    inception_path = pathlib.Path(inception_path)
    model_file = inception_path / 'classify_image_graph_def.pb'
    if not model_file.exists():
        print("Downloading Inception model")
        from urllib import request
        import tarfile
        # Without a target filename urlretrieve() stores the archive in a
        # temporary file that nobody else cleans up.
        fn, _ = request.urlretrieve(INCEPTION_URL)
        try:
            with tarfile.open(fn, mode='r') as f:
                f.extract('classify_image_graph_def.pb', str(model_file.parent))
        finally:
            try:
                os.remove(fn)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask
                # the real outcome of the extraction.
                pass
    return str(model_file)
def _handle_path(path, sess):
if path.endswith('.npz'):
f = np.load(path)
m, s = f['mu'][:], f['sigma'][:]
f.close()
else:
path = pathlib.Path(path)
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
m, s = calculate_activation_statistics(x, sess)
return m, s
def calculate_fid_given_paths(paths, inception_path):
    """Compute the FID between the two entries of `paths`.

    Params:
    -- paths          : Sequence whose first two entries are each either a
                        '.npz' statistics file or an image directory.
    -- inception_path : Directory of the Inception graph file (or None), as
                        accepted by check_or_download_inception.
    Returns:
    -- The Frechet distance between the statistics of both paths.
    Raises:
    -- RuntimeError for the first entry of `paths` that does not exist.
    """
    model_file = check_or_download_inception(inception_path)
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise RuntimeError("Invalid path: %s" % missing[0])
    create_inception_graph(str(model_file))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mu_a, cov_a = _handle_path(paths[0], sess)
        mu_b, cov_b = _handle_path(paths[1], sess)
        return calculate_frechet_distance(mu_a, cov_a, mu_b, cov_b)
# Stand-alone entry point: compare two image folders / .npz statistic files.
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument(
        "path", type=str, nargs=2,
        help='Path to the generated images or to .npz statistic files')
    cli.add_argument(
        "-i", "--inception", type=str, default=None,
        help='Path to Inception model (will be downloaded if not provided)')
    cli.add_argument(
        "--gpu", default="", type=str,
        help='GPU to use (leave blank for CPU only)')
    options = cli.parse_args()

    # Restrict the visible CUDA devices before any session is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = options.gpu
    print("FID: ", calculate_fid_given_paths(options.path, options.inception))
#----------------------------------------------------------------------------
# EDIT: added
class API:
    """Metric object that accumulates activations and reports the FID.

    Usage follows a begin(mode) / feed(mode, minibatch)* / end(mode) cycle.
    Statistics gathered in 'warmup' or 'reals' mode are stored as the
    reference against which later cycles are compared.
    """

    def __init__(self, num_images, image_shape, image_dtype, minibatch_size):
        """Locate (or download) the Inception graph and import it.

        The four constructor arguments are accepted but not used here.
        """
        import config
        self.network_dir = os.path.join(config.result_dir, '_inception_fid')
        self.network_file = check_or_download_inception(self.network_dir)
        self.sess = tf.get_default_session()
        create_inception_graph(self.network_file)

    def get_metric_names(self):
        """Return the list of metric names produced by end()."""
        return ['FID']

    def get_metric_formatting(self):
        """Return the printf-style format for each metric."""
        return ['%-10.4f']

    def begin(self, mode):
        """Start a new accumulation cycle, discarding earlier activations."""
        assert mode in ('warmup', 'reals', 'fakes')
        self.activations = []

    def feed(self, mode, minibatch):
        """Compute and store the activations of one minibatch.

        The minibatch axes are permuted with (0, 2, 3, 1) before evaluation
        -- presumably channels-first to channels-last; confirm with caller.
        The whole minibatch is evaluated as a single batch.
        """
        channels_last = minibatch.transpose(0, 2, 3, 1)
        self.activations.append(
            get_activations(channels_last, self.sess,
                            batch_size=minibatch.shape[0]))

    def end(self, mode):
        """Finish the cycle and return [FID] against the stored reference.

        In 'warmup' and 'reals' mode the freshly computed statistics become
        the reference first, so the distance is taken against themselves.
        """
        stacked = np.concatenate(self.activations)
        mu, sigma = np.mean(stacked, axis=0), np.cov(stacked, rowvar=False)
        if mode in ('warmup', 'reals'):
            self.mu_real, self.sigma_real = mu, sigma
        return [calculate_frechet_distance(mu, sigma, self.mu_real, self.sigma_real)]
#----------------------------------------------------------------------------