Spaces:
Running
on
Zero
Running
on
Zero
import numpy | |
from numpy import cov | |
from numpy import trace | |
from numpy import iscomplexobj | |
from numpy.random import random | |
from scipy.linalg import sqrtm | |
def calculate_fid(act1, act2): | |
# calculate mean and covariance statistics | |
mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False) | |
mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False) | |
print("mu1 ", mu1.shape) | |
print("mu2 ", mu2.shape) | |
print("sigma1 ", sigma1.shape) | |
print("sigma2 ", sigma2.shape) | |
# calculate sum squared difference between means | |
ssdiff = numpy.sum((mu1 - mu2) * 2.0) | |
# calculate sqrt of product between cov | |
covmean = sqrtm(sigma1.dot(sigma2)) | |
# check and correct imaginary numbers from sqrt | |
if iscomplexobj(covmean): | |
covmean = covmean.real | |
# calculate score | |
fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean) | |
return fid | |
act1 = random(2048 * 2) | |
act1 = act1.reshape((2, 2048)) | |
act2 = random(2048 * 2) | |
act2 = act2.reshape((2, 2048)) | |
fid = calculate_fid(act1, act1) | |
print("FID (same): %.3f" % fid) | |
fid = calculate_fid(act1, act2) | |
print("FID (different): %.3f" % fid) | |