Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
import math | |
import os | |
from functools import partial | |
import torch | |
import torch.distributed as dist | |
from torch import nn | |
from torch.nn import functional as F | |
from torchvision.models import inception_v3 | |
from cleanfid.features import feature_extractor | |
from cleanfid.resize import build_resizer | |
from imaginaire.evaluation.lpips import get_lpips_model | |
from imaginaire.evaluation.segmentation import get_segmentation_hist_model, get_miou | |
from imaginaire.evaluation.caption import get_image_encoder, get_r_precision | |
from imaginaire.evaluation.pretrained import TFInceptionV3, InceptionV3, Vgg16, SwAV | |
from imaginaire.utils.distributed import (dist_all_gather_tensor, get_rank, | |
get_world_size, is_master, | |
is_local_master) | |
from imaginaire.utils.distributed import master_only_print | |
from imaginaire.utils.misc import apply_imagenet_normalization, to_cuda | |
def compute_all_metrics(act_dir,
                        data_loader,
                        net_G,
                        key_real='images',
                        key_fake='fake_images',
                        sample_size=None,
                        preprocess=None,
                        is_video=False,
                        few_shot_video=False,
                        kid_num_subsets=1,
                        kid_subset_size=None,
                        key_prefix='',
                        prdc_k=5,
                        metrics=None,
                        dataset_name='',
                        aws_credentials=None,
                        **kwargs):
    r"""Compute the requested evaluation metrics for a generator.

    Outputs for generated images are collected with ``get_outputs`` and
    activations for real images with ``load_or_compute_activations`` (cached
    in ``activations_real.pt`` inside ``act_dir``). The metric values are only
    computed on the master process.

    Args:
        act_dir (string): Path to a directory to temporarily save feature activations.
        data_loader (obj): PyTorch dataloader object.
        net_G (obj): The generator module.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        sample_size (int or None): How many samples to use for FID.
        preprocess (func or None): Pre-processing function to use.
        is_video (bool): Whether we are handling video sequences. Only
            forwarded to the real-activation computation.
        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
        kid_num_subsets (int): Number of subsets for KID evaluation.
        kid_subset_size (int or None): The number of samples in each subset for KID evaluation.
        key_prefix (string): Add this string before all keys of the output dictionary.
        prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage.
        metrics (list of strings): Which metrics we want to evaluate.
        dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to
            use for segmentation evaluation.
        aws_credentials (obj or None): Forwarded to ``get_segmentation_hist_model``
            and ``get_image_encoder``.
        **kwargs: Forwarded to ``get_outputs`` and ``load_or_compute_activations``.
    Returns:
        (dict): Maps ``key_prefix + <metric key>`` to the metric value. The
        dictionary is empty on every process other than the master.
    Raises:
        NotImplementedError: If ``metrics`` contains an unknown metric name
            (raised on the master process).
    """
    from imaginaire.evaluation.fid import _calculate_frechet_distance
    from imaginaire.evaluation.kid import _polynomial_mmd_averages
    from imaginaire.evaluation.prdc import _get_prdc
    from imaginaire.evaluation.msid import _get_msid
    from imaginaire.evaluation.knn import _get_1nn_acc
    if metrics is None:
        metrics = []
    act_path = os.path.join(act_dir, 'activations_real.pt')

    # Get feature activations and other outputs computed from fake images.
    output_module_dict = nn.ModuleDict()
    if "seg_mIOU" in metrics:
        output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials)
    if "caption_rprec" in metrics:
        output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials)
    if "LPIPS" in metrics:
        output_module_dict["LPIPS"] = get_lpips_model()

    fake_outputs = get_outputs(
        data_loader, key_real, key_fake, net_G, sample_size, preprocess,
        output_module_dict=output_module_dict, **kwargs
    )
    fake_act = fake_outputs["activations"]

    # Get feature activations computed from real images.
    real_act = load_or_compute_activations(
        act_path, data_loader, key_real, key_fake, None,
        sample_size, preprocess, is_video=is_video,
        few_shot_video=few_shot_video, **kwargs
    )

    # Metrics that compare real and fake feature activations.
    metrics_from_activations = {
        "1NN": _get_1nn_acc,
        "MSID": _get_msid,
        "FID": _calculate_frechet_distance,
        "KID": partial(_polynomial_mmd_averages,
                       n_subsets=kid_num_subsets,
                       subset_size=kid_subset_size,
                       ret_var=True),
        "PRDC": partial(_get_prdc, nearest_k=prdc_k)
    }
    # Metrics computed from the per-metric outputs gathered by get_outputs.
    other_metrics = {
        "seg_mIOU": get_miou,
        "caption_rprec": get_r_precision,
        "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()}
    }

    all_metrics = {}
    if is_master():
        for metric in metrics:
            if metric in metrics_from_activations:
                metric_function = metrics_from_activations[metric]
                metric_dict = metric_function(real_act, fake_act)
            elif metric in other_metrics:
                if fake_outputs[metric] is None:
                    # Nothing was collected for this metric. Skip it instead
                    # of falling through with an undefined (or stale)
                    # ``metric_dict`` from a previous iteration.
                    continue
                metric_function = other_metrics[metric]
                metric_dict = metric_function(fake_outputs[metric])
            else:
                print(f"{metric} is not implemented!")
                raise NotImplementedError
            for k, v in metric_dict.items():
                all_metrics[key_prefix + k] = v
    # Keep all processes in step while the master computes the metrics.
    if dist.is_initialized():
        dist.barrier()
    return all_metrics
def compute_all_metrics_data(data_loader_a,
                             data_loader_b,
                             key_a='images',
                             key_b='images',
                             sample_size=None,
                             preprocess=None,
                             kid_num_subsets=1,
                             kid_subset_size=None,
                             key_prefix='',
                             prdc_k=5,
                             metrics=None,
                             dataset_name='',
                             aws_credentials=None,
                             **kwargs):
    r"""Compute the requested evaluation metrics between two datasets.

    No generator is involved: activations are extracted directly from the
    batches of both data loaders. The metric values are only computed on the
    master process.

    Args:
        data_loader_a (obj): PyTorch dataloader object for dataset a.
        data_loader_b (obj): PyTorch dataloader object for dataset b.
        key_a (str): Dictionary key value for the data in dataset a.
        key_b (str): Dictionary key value for the data in dataset b.
        sample_size (int or None): How many samples to use. Clipped to the
            size of the smaller dataset; ``None`` uses that size.
        preprocess (func or None): Pre-processing function to use.
        kid_num_subsets (int): Number of subsets for KID evaluation.
        kid_subset_size (int or None): The number of samples in each subset for KID evaluation.
        key_prefix (string): Add this string before all keys of the output dictionary.
        prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage.
        metrics (list of strings): Which metrics we want to evaluate.
        dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to
            use for segmentation evaluation.
        aws_credentials (obj or None): Forwarded to ``get_segmentation_hist_model``
            and ``get_image_encoder``.
        **kwargs: Forwarded to ``get_outputs`` and ``load_or_compute_activations``.
    Returns:
        (dict): Maps ``key_prefix + <metric key>`` to the metric value. The
        dictionary is empty on every process other than the master.
    Raises:
        NotImplementedError: If ``metrics`` contains an unknown metric name
            (raised on the master process).
    """
    from imaginaire.evaluation.fid import _calculate_frechet_distance
    from imaginaire.evaluation.kid import _polynomial_mmd_averages
    from imaginaire.evaluation.prdc import _get_prdc
    from imaginaire.evaluation.msid import _get_msid
    from imaginaire.evaluation.knn import _get_1nn_acc
    if metrics is None:
        metrics = []

    # Never ask for more samples than the smaller dataset can provide.
    min_data_size = min(len(data_loader_a.dataset),
                        len(data_loader_b.dataset))
    if sample_size is None:
        sample_size = min_data_size
    else:
        sample_size = min(sample_size, min_data_size)

    # Get feature activations and other outputs computed from dataset b.
    output_module_dict = nn.ModuleDict()
    if "seg_mIOU" in metrics:
        output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials)
    if "caption_rprec" in metrics:
        output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials)
    if "LPIPS" in metrics:
        # NOTE(review): get_outputs asserts that a generator is given when it
        # evaluates LPIPS, but it is called with generator=None below, so
        # requesting "LPIPS" here looks unsupported -- confirm.
        output_module_dict["LPIPS"] = get_lpips_model()

    # NOTE(review): with generator=None, get_outputs reads ``data[key_real]``,
    # so ``key_a`` is also used to index the batches of data_loader_b and
    # ``key_b`` is never read. Harmless with the default keys; confirm the
    # intent before relying on distinct keys.
    fake_outputs = get_outputs(
        data_loader_b, key_a, key_b, None, sample_size, preprocess,
        output_module_dict=output_module_dict, **kwargs
    )
    act_b = fake_outputs["activations"]

    # act_path=None: activations of dataset a are always recomputed.
    act_a = load_or_compute_activations(
        None, data_loader_a, key_a, key_b, None, sample_size, preprocess,
        output_module_dict=output_module_dict, **kwargs
    )

    # Metrics that compare the feature activations of the two datasets.
    metrics_from_activations = {
        "1NN": _get_1nn_acc,
        "MSID": _get_msid,
        "FID": _calculate_frechet_distance,
        "KID": partial(_polynomial_mmd_averages,
                       n_subsets=kid_num_subsets,
                       subset_size=kid_subset_size,
                       ret_var=True),
        "PRDC": partial(_get_prdc, nearest_k=prdc_k)
    }
    # Metrics computed from the per-metric outputs gathered by get_outputs.
    other_metrics = {
        "seg_mIOU": get_miou,
        "caption_rprec": get_r_precision,
        "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()}
    }

    all_metrics = {}
    if is_master():
        for metric in metrics:
            if metric in metrics_from_activations:
                metric_function = metrics_from_activations[metric]
                metric_dict = metric_function(act_a, act_b)
            elif metric in other_metrics:
                if fake_outputs[metric] is None:
                    # Nothing was collected for this metric. Skip it instead
                    # of falling through with an undefined (or stale)
                    # ``metric_dict`` from a previous iteration.
                    continue
                metric_function = other_metrics[metric]
                metric_dict = metric_function(fake_outputs[metric])
            else:
                print(f"{metric} is not implemented!")
                raise NotImplementedError
            for k, v in metric_dict.items():
                all_metrics[key_prefix + k] = v
    # Keep all processes in step while the master computes the metrics.
    if dist.is_initialized():
        dist.barrier()
    return all_metrics
@torch.no_grad()
def get_activations(data_loader, key_real, key_fake,
                    generator=None, sample_size=None, preprocess=None,
                    align_corners=True, network='inception', **kwargs):
    r"""Compute feature activations for real or generated images.

    The whole function runs under ``torch.no_grad()``: this is pure
    inference, and the per-batch features are accumulated over the entire
    data loader, so autograd graphs must not be kept alive.

    Args:
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        generator (obj): PyTorch trainer network. If ``None``, activations
            are computed from ``data[key_real]``; otherwise from the
            ``key_fake`` entry of the generator output.
        sample_size (int): How many samples to use for FID.
        preprocess (func): Pre-processing function to use.
        align_corners (bool): The ``'align_corners'`` parameter to be used for
            `torch.nn.functional.interpolate`.
        network (str): Feature extractor to use: ``'tf_inception'``,
            ``'inception'``, ``'vgg16'``, ``'swav'`` or ``'clean_inception'``.
        **kwargs: Forwarded to ``generator`` when one is given.
    Returns:
        batch_y (tensor): Feature activations gathered from all processes
        with ``dist_all_gather_tensor``, truncated to ``sample_size`` when it
        is given.
    Raises:
        NotImplementedError: If ``network`` is not one of the supported names.
    """
    if dist.is_initialized() and not is_local_master():
        # Make sure only the first process in distributed training downloads
        # the model, and the others will use the cache
        # noinspection PyUnresolvedReferences
        dist.barrier()
    if network == 'tf_inception':
        model = TFInceptionV3()
    elif network == 'inception':
        model = InceptionV3()
    elif network == 'vgg16':
        model = Vgg16()
    elif network == 'swav':
        model = SwAV()
    elif network == 'clean_inception':
        model = CleanInceptionV3()
    else:
        raise NotImplementedError(f'Network "{network}" is not supported!')
    if dist.is_initialized() and is_local_master():
        # The local master has built the model; release the processes that
        # are waiting at the barrier above.
        # noinspection PyUnresolvedReferences
        dist.barrier()
    model = model.to('cuda').eval()

    world_size = get_world_size()
    batch_y = []

    # Iterate through the dataset to compute the activation.
    for it, data in enumerate(data_loader):
        data = to_cuda(data)
        # Preprocess the data.
        if preprocess is not None:
            data = preprocess(data)
        # Load real data if the generator is not specified.
        if generator is None:
            images = data[key_real]
        else:
            # Compute the generated image.
            net_G_output = generator(data, **kwargs)
            images = net_G_output[key_fake]
        # Clamp the image for models that do not set the output to between
        # -1, 1. For models that employ tanh, this has no effect.
        images.clamp_(-1, 1)
        y = model(images, align_corners=align_corners)
        batch_y.append(y)
        if sample_size is not None and \
                data_loader.batch_size * world_size * (it + 1) >= sample_size:
            # Reach the number of samples we need.
            break

    batch_y = torch.cat(dist_all_gather_tensor(torch.cat(batch_y)))
    if sample_size is not None:
        batch_y = batch_y[:sample_size]
    print(f"Computed feature activations of size {batch_y.shape}")
    return batch_y
class CleanInceptionV3(nn.Module):
    r"""Feature extractor wrapping the ``torchscript_inception`` model
    returned by ``cleanfid.features.feature_extractor``.
    """

    def __init__(self):
        super().__init__()
        # forward() resizes the images with ``clean_resize`` itself, hence
        # resize_inside=False (presumably disabling the extractor's own
        # resizing -- confirm against clean-fid).
        self.model = feature_extractor(name="torchscript_inception", resize_inside=False)

    def forward(self, img_batch, transform=True, **_kwargs):
        r"""Extract features from a batch of images.

        Args:
            img_batch (tensor): Batch of images.
            transform (bool): If ``True``, the input is assumed to be in
                (-1, 1) and is mapped to (0, 255) and rounded to the closest
                integer before resizing.
            **_kwargs: Ignored.
        Returns:
            Output of the wrapped model on the resized batch.
        """
        pixels = img_batch
        if transform:
            pixels = torch.round((0.5 * img_batch + 0.5) * 255)
        return self.model(clean_resize(pixels))
def clean_resize(img_batch):
    r"""Resize images from arbitrary resolutions to 299x299.

    Each image goes through numpy on the CPU and the ``'clean'`` resizer
    built by ``cleanfid.resize.build_resizer``.

    Args:
        img_batch (tensor): Batch of images, indexed as (batch, channel,
            height, width) by the code below.
    Returns:
        (tensor): A ``batch x 3 x 299 x 299`` tensor on the ``'cuda'`` device.
    """
    resize = build_resizer('clean')
    num_images = img_batch.size(0)
    images_np = img_batch.cpu().numpy()
    output = torch.zeros(num_images, 3, 299, 299, device='cuda')
    for idx, image_chw in enumerate(images_np):
        # The resizer is fed channel-last images; convert there and back.
        resized_hwc = resize(image_chw.transpose((1, 2, 0)))
        output[idx] = torch.tensor(resized_hwc.transpose((2, 0, 1)), device='cuda')
    return output.cuda()
@torch.no_grad()
def get_outputs(data_loader, key_real, key_fake,
                generator=None, sample_size=None, preprocess=None,
                align_corners=True, network='inception',
                output_module_dict=None, **kwargs):
    r"""Compute feature activations and the outputs of extra metric modules.

    The whole function runs under ``torch.no_grad()``: this is pure
    inference, and the per-batch outputs are accumulated over the entire
    data loader, so autograd graphs must not be kept alive.

    Args:
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        generator (obj): PyTorch trainer network. If ``None``, images are
            read from ``data[key_real]``; otherwise from the ``key_fake``
            entry of the generator output.
        sample_size (int): How many samples to use for FID.
        preprocess (func): Pre-processing function to use.
        align_corners (bool): The ``'align_corners'`` parameter to be used for `torch.nn.functional.interpolate`.
        network (str): Feature extractor to use: ``'tf_inception'``,
            ``'inception'``, ``'vgg16'``, ``'swav'`` or ``'clean_inception'``.
        output_module_dict (nn.ModuleDict or None): Extra modules, keyed by
            metric name, that are evaluated on every batch. The ``'LPIPS'``
            entry is called with two generator samples and requires a
            generator; every other entry is called as
            ``module(data, images, align_corners=align_corners)``.
        **kwargs: Forwarded to ``generator`` when one is given.
    Returns:
        (dict): ``"activations"`` plus one entry per key of
        ``output_module_dict``. Each value is the tensor gathered from all
        processes with ``dist_all_gather_tensor`` and truncated to
        ``sample_size``, or ``None`` if nothing was collected for that key.
    Raises:
        NotImplementedError: If ``network`` is not one of the supported names.
    """
    if output_module_dict is None:
        output_module_dict = nn.ModuleDict()

    if dist.is_initialized() and not is_local_master():
        # Make sure only the first process in distributed training downloads
        # the model, and the others will use the cache
        # noinspection PyUnresolvedReferences
        dist.barrier()
    if network == 'tf_inception':
        model = TFInceptionV3()
    elif network == 'inception':
        model = InceptionV3()
    elif network == 'vgg16':
        model = Vgg16()
    elif network == 'swav':
        model = SwAV()
    elif network == 'clean_inception':
        model = CleanInceptionV3()
    else:
        raise NotImplementedError(f'Network "{network}" is not supported!')
    if dist.is_initialized() and is_local_master():
        # The local master has built the model; release the processes that
        # are waiting at the barrier above.
        # noinspection PyUnresolvedReferences
        dist.barrier()
    model = model.to('cuda').eval()

    world_size = get_world_size()
    output = {}
    for k in output_module_dict.keys():
        output[k] = []
    output["activations"] = []

    # Iterate through the dataset to compute the activation.
    for it, data in enumerate(data_loader):
        data = to_cuda(data)
        # Preprocess the data.
        if preprocess is not None:
            data = preprocess(data)
        # Load real data if the generator is not specified.
        if generator is None:
            images = data[key_real]
        else:
            # Compute the generated image.
            net_G_output = generator(data, **kwargs)
            images = net_G_output[key_fake]

        # The metric modules see the images before they are clamped below.
        for metric_name, metric_module in output_module_dict.items():
            if metric_module is not None:
                if metric_name == 'LPIPS':
                    # LPIPS compares two generator outputs for the same input.
                    assert generator is not None
                    net_G_output_another = generator(data, **kwargs)
                    images_another = net_G_output_another[key_fake]
                    output[metric_name].append(metric_module(images, images_another))
                else:
                    output[metric_name].append(metric_module(data, images, align_corners=align_corners))

        # Clamp the image for models that do not set the output to between
        # -1, 1. For models that employ tanh, this has no effect.
        images.clamp_(-1, 1)
        y = model(images, align_corners=align_corners)
        output["activations"].append(y)

        if sample_size is not None and data_loader.batch_size * world_size * (it + 1) >= sample_size:
            # Reach the number of samples we need.
            break

    for k, v in output.items():
        if len(v) > 0:
            # Slicing with sample_size=None keeps every row.
            output[k] = torch.cat(dist_all_gather_tensor(torch.cat(v)))[:sample_size]
        else:
            output[k] = None
    return output
@torch.no_grad()
def get_video_activations(data_loader, key_real, key_fake, trainer=None,
                          sample_size=None, preprocess=None, few_shot=False):
    r"""Compute Inception activations for frames of video sequences.

    The video sequences are divided among the processes. The whole function
    runs under ``torch.no_grad()``: this is pure inference, and the per-frame
    features are accumulated over all sequences, so autograd graphs must not
    be kept alive.

    Args:
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        trainer (obj): Trainer. Video generation is more involved, we rely on
            the "reset" and "test_single" functions to conduct the evaluation.
            If ``None``, the last frame of ``data[key_real]`` is used.
        sample_size (tuple or None): ``(num_videos_to_test,
            num_frames_per_video)``. ``None`` means 10 videos with 5 frames
            each; ``num_videos_to_test == -1`` means all inference sequences.
        preprocess (func): The preprocess function to be applied to the data.
            Only used when ``trainer`` is ``None``.
        few_shot (bool): If ``True``, uses the few-shot setting.
    Returns:
        batch_y (tensor): Inception features. Only local master processes
        get a single concatenated tensor; the other processes get the raw
        result of ``dist_all_gather_tensor``.
    """
    inception = inception_init()
    batch_y = []

    # We divide video sequences to different GPUs for testing.
    num_sequences = data_loader.dataset.num_inference_sequences()
    if sample_size is None:
        num_videos_to_test = 10
        num_frames_per_video = 5
    else:
        num_videos_to_test, num_frames_per_video = sample_size
    if num_videos_to_test == -1:
        num_videos_to_test = num_sequences
    else:
        num_videos_to_test = min(num_videos_to_test, num_sequences)
    master_only_print('Number of videos used for evaluation: {}'.format(num_videos_to_test))
    master_only_print('Number of frames per video used for evaluation: {}'.format(num_frames_per_video))

    world_size = get_world_size()
    if num_videos_to_test < world_size:
        # Fewer videos than processes: every process runs one sequence.
        seq_to_run = [get_rank() % num_videos_to_test]
    else:
        # Round down to a multiple of world_size so that every process runs
        # the same number of sequences.
        num_videos_to_test = num_videos_to_test // world_size * world_size
        seq_to_run = range(get_rank(), num_videos_to_test, world_size)

    for sequence_idx in seq_to_run:
        data_loader = set_sequence_idx(few_shot, data_loader, sequence_idx)
        if trainer is not None:
            trainer.reset()
        for it, data in enumerate(data_loader):
            # In the few-shot setting the first item is skipped.
            if few_shot and it == 0:
                continue
            if it >= num_frames_per_video:
                break

            # The trainer's own pre-processing takes precedence over
            # ``preprocess``.
            if trainer is not None:
                data = trainer.pre_process(data)
            elif preprocess is not None:
                data = preprocess(data)
            data = to_cuda(data)

            if trainer is None:
                # Real data: take the last frame along dimension 1.
                images = data[key_real][:, -1]
            else:
                net_G_output = trainer.test_single(data)
                images = net_G_output[key_fake]
            y = inception_forward(inception, images)
            batch_y += [y]

    batch_y = torch.cat(batch_y)
    batch_y = dist_all_gather_tensor(batch_y)
    if is_local_master():
        batch_y = torch.cat(batch_y)
    return batch_y
def inception_init():
    r"""Build the torchvision Inception-v3 network used as feature extractor.

    Returns:
        The pretrained network on the ``'cuda'`` device, in eval mode, with
        its ``fc`` layer replaced by an empty ``Sequential``.
    """
    net = inception_v3(pretrained=True, transform_input=False).to('cuda')
    net.eval()
    # An empty Sequential passes its input through, so the network returns
    # the features that would have been fed to the fc layer.
    net.fc = torch.nn.Sequential()
    return net
def inception_forward(inception, images):
    r"""Run a batch of images through the Inception network.

    Args:
        inception (obj): Network returned by ``inception_init``.
        images (tensor): Batch of images. It is clamped to [-1, 1] in place,
            so the caller's tensor is modified.
    Returns:
        Output of ``inception`` on the normalized images resized to 299x299.
    """
    images.clamp_(-1, 1)
    normalized = apply_imagenet_normalization(images)
    resized = F.interpolate(
        normalized, size=(299, 299), mode='bicubic', align_corners=True
    )
    return inception(resized)
def gather_tensors(batch_y):
    r"""Concatenate local tensors and gather them across processes.

    Args:
        batch_y (list of tensors): Tensors computed on this process.
    Returns:
        On local master processes, one tensor concatenating everything
        returned by ``dist_all_gather_tensor``; on other processes, the raw
        result of ``dist_all_gather_tensor``.
    """
    gathered = dist_all_gather_tensor(torch.cat(batch_y))
    return torch.cat(gathered) if is_local_master() else gathered
def set_sequence_idx(few_shot, data_loader, sequence_idx):
    r"""Select which sequence the dataset of ``data_loader`` serves.

    Args:
        few_shot (bool): If ``True``, uses the few-shot setting.
        data_loader: dataloader object
        sequence_idx (int): which sequence to use.
    Returns:
        The same ``data_loader`` object that was passed in.
    """
    dataset = data_loader.dataset
    if not few_shot:
        dataset.set_inference_sequence_idx(sequence_idx)
        return data_loader
    # Few-shot setting: the dataset takes the index twice plus a trailing 0.
    dataset.set_inference_sequence_idx(sequence_idx, sequence_idx, 0)
    return data_loader
def load_or_compute_activations(act_path, data_loader, key_real, key_fake,
                                generator=None, sample_size=None,
                                preprocess=None,
                                is_video=False, few_shot_video=False,
                                **kwargs):
    r"""Load feature activations from ``act_path`` if the file exists.
    Otherwise, compute the activations and, when ``act_path`` is given, save
    them there from the local master process.

    Args:
        act_path (str or None): Location of the file to store or to load
            the activations. ``None`` disables loading and saving.
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        generator (obj): PyTorch trainer network.
        sample_size (int): How many samples to be used for computing the KID.
        preprocess (func): The preprocess function to be applied to the data.
        is_video (bool): Whether we are handling video sequences.
        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
        **kwargs: Forwarded to ``get_activations``. Ignored for video
            sequences, because ``get_video_activations`` accepts no extra
            keyword arguments.
    Returns:
        (torch.Tensor) Feature activations.
    """
    if act_path is not None and os.path.exists(act_path):
        # Loading precomputed activations.
        print('Load activations from {}'.format(act_path))
        act = torch.load(act_path, map_location='cpu').cuda()
    else:
        # Compute activations.
        if is_video:
            # Do not forward **kwargs: get_video_activations takes no extra
            # keyword arguments, so forwarding any would raise a TypeError.
            act = get_video_activations(
                data_loader, key_real, key_fake, generator,
                sample_size, preprocess, few_shot_video
            )
        else:
            act = get_activations(
                data_loader, key_real, key_fake, generator,
                sample_size, preprocess, **kwargs
            )
        if act_path is not None and is_local_master():
            print('Save activations to {}'.format(act_path))
            act_dir = os.path.dirname(act_path)
            # A bare file name has an empty dirname, and os.makedirs('')
            # raises; only create the directory when there is one.
            if act_dir:
                os.makedirs(act_dir, exist_ok=True)
            torch.save(act, act_path)
    return act
def compute_pairwise_distance(data_x, data_y=None, num_splits=10):
    r"""Compute all pairwise distances with ``torch.cdist``, processing the
    rows of ``data_x`` in ``num_splits`` chunks.

    Args:
        data_x: tensor of size [N, feature_dim].
        data_y: tensor of size [N, feature_dim]. Defaults to ``data_x``.
        num_splits (int): Number of chunks the rows of ``data_x`` are
            divided into.
    Returns:
        CPU tensor of size [N, N] of pairwise distances.
    """
    if data_y is None:
        data_y = data_x
    num_samples = data_x.shape[0]
    assert data_x.shape[0] == data_y.shape[0]

    def _rows_of(split):
        # Rows of data_x belonging to the given chunk (may be empty).
        rows_per_split = math.ceil(num_samples / num_splits)
        first = split * rows_per_split
        return data_x[first:min(first + rows_per_split, num_samples)]

    # Every chunk is moved to the CPU as soon as it is computed.
    blocks = [
        torch.cdist(_rows_of(split), data_y).cpu()
        for split in range(num_splits)
    ]
    return torch.cat(blocks, dim=0)
def compute_nn(input_features, k, num_splits=50):
    r"""Find the ``k`` nearest neighbors of every sample among all the other
    samples, processing the samples in ``num_splits`` chunks.

    Args:
        input_features: tensor of size [N, feature_dim].
        k (int): Number of neighbors to return for every sample.
        num_splits (int): Number of chunks the samples are divided into.
    Returns:
        (tuple): Distances to the ``k`` nearest neighbors and their indices,
        both of size [N, k].
    """
    num_samples = input_features.shape[0]
    neighbor_values, neighbor_indices = [], []
    for split in range(num_splits):
        rows_per_split = math.ceil(num_samples / num_splits)
        first = split * rows_per_split
        last = min(first + rows_per_split, num_samples)
        pairwise = torch.cdist(input_features[first:last], input_features)
        # Add infinity to the distance of every sample to itself, so that a
        # sample is never returned as its own neighbor.
        self_mask = torch.diag(
            torch.full((pairwise.size(0),), float('inf'), device=pairwise.device)
        )
        pairwise[:, first:last] += self_mask
        nearest = torch.topk(pairwise, k, dim=-1, largest=False)
        neighbor_indices.append(nearest.indices)
        neighbor_values.append(nearest.values)
    return torch.cat(neighbor_values, dim=0), torch.cat(neighbor_indices, dim=0)