|
import torch
|
|
import itertools
|
|
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
|
|
from utils.utils import denorm
|
|
|
|
def compute_lpips_variability(samples: torch.Tensor,
                              net: str = 'alex',
                              device: str = 'cuda') -> float:
    """Return the mean LPIPS distance over all unordered pairs of samples.

    Every pair ``(i, j)`` with ``i < j`` along dim 0 is scored with LPIPS and
    the scores are averaged, so a higher value means a more perceptually
    diverse batch.

    Args:
        samples: Image batch with the sample index on dim 0. torchmetrics LPIPS
            expects ``(N, 3, H, W)`` images. If the batch minimum is ``>= 0``
            the batch is treated as ``[0, 1]`` data and rescaled to ``[-1, 1]``
            before being clamped to ``[-1, 1]``.
        net: Backbone forwarded to LPIPS as ``net_type`` (e.g. ``'alex'``).
        device: Device that the metric network and the samples are moved to.

    Returns:
        Mean pairwise LPIPS distance as a Python float.

    Raises:
        ValueError: if fewer than two samples are given (no pair exists).
    """
    n = samples.size(0)
    if n < 2:
        raise ValueError(
            f"compute_lpips_variability needs at least 2 samples, got {n}")

    loss_fn = LPIPS(net_type=net).to(device)
    loss_fn.eval()

    # Move the batch once instead of re-transferring each sample for every pair.
    samples = samples.to(device)

    # NOTE(review): `min() >= 0` is a heuristic for "[0, 1] input"; a [-1, 1]
    # batch that happens to be entirely non-negative would be rescaled too.
    if samples.min() >= 0.0:
        samples = samples * 2 - 1
    # Clamping once up front is identical to clamping each slice in the loop.
    samples = samples.clamp(-1, 1)

    # NOTE(review): `denorm` (utils.utils) is not visible here; it presumably
    # maps [-1, 1] to [0, 1]. torchmetrics LPIPS defaults to normalize=False,
    # which expects [-1, 1] input — confirm these two agree.
    total = 0.0
    num_pairs = 0
    with torch.no_grad():  # metric evaluation only; no autograd graph needed
        for i, j in itertools.combinations(range(n), 2):
            dist = loss_fn(denorm(samples[i:i + 1]), denorm(samples[j:j + 1]))
            total += dist.item()
            num_pairs += 1

    return total / num_pairs
|
|
|
|
def compute_pixelwise_correlation(samples: torch.Tensor) -> float:
    """Return the mean per-channel Pearson correlation over all sample pairs.

    For every unordered pair ``(i, j)`` with ``i < j`` the Pearson correlation
    is computed separately for each channel over the flattened spatial
    positions. Those values are averaged over channels, and the per-pair
    results are averaged over all pairs.

    Args:
        samples: Batch of shape ``(N, C, H, W)``. Non-contiguous tensors are
            accepted. Non-floating dtypes are cast to ``float32`` first.

    Returns:
        Mean pairwise correlation as a Python float.

    Raises:
        ValueError: if fewer than two samples are given (no pair exists).
    """
    n, c, _, _ = samples.shape
    if n < 2:
        raise ValueError(
            f"compute_pixelwise_correlation needs at least 2 samples, got {n}")
    if not samples.is_floating_point():
        # Tensor.mean rejects integer dtypes.
        samples = samples.float()

    with torch.no_grad():
        # reshape (not view) so permuted / non-contiguous inputs work.
        flat = samples.reshape(n, c, -1)
        centered = flat - flat.mean(dim=2, keepdim=True)            # (N, C, HW)
        norms = centered.norm(dim=2)                                # (N, C)
        # numer[i, j, k] = <centered[i, k], centered[j, k]> for all pairs at once.
        numer = torch.einsum('ikp,jkp->ijk', centered, centered)    # (N, N, C)
        # The 1e-8 keeps constant channels (zero norm) from dividing by zero.
        denom = norms[:, None, :] * norms[None, :, :] + 1e-8
        pair_corr = (numer / denom).mean(dim=2)                     # (N, N)
        # Keep only i < j, matching itertools.combinations(range(n), 2).
        iu = torch.triu_indices(n, n, offset=1, device=pair_corr.device)
        return pair_corr[iu[0], iu[1]].mean().item()
|
|
|
|
def compute_dynamic_range(samples: torch.Tensor) -> float:
    """Return the mean spread of values across the sample axis.

    For each position in the non-batch dimensions, this takes the difference
    between the largest and smallest value over dim 0, then averages those
    spreads into a single number.

    Args:
        samples: Tensor whose dim 0 indexes samples (e.g. ``(N, C, H, W)``).
            Non-floating dtypes (e.g. ``uint8`` images) are cast to
            ``float32``, because ``Tensor.mean`` rejects integer tensors.

    Returns:
        Mean max-minus-min spread as a Python float.
    """
    if not samples.is_floating_point():
        samples = samples.float()
    # amax/amin return only the values; the argmax/argmin indices were unused.
    spread = samples.amax(dim=0) - samples.amin(dim=0)
    return spread.mean().item()