vfontech's picture
Uploading the app
587665f verified
raw
history blame
1.83 kB
import torch
import itertools
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS
from utils.utils import denorm
def compute_lpips_variability(samples: torch.Tensor,
net: str = 'alex',
device: str = 'cuda'
) -> float:
loss_fn = LPIPS(net_type=net).to(device)
loss_fn.eval()
if samples.min() >= 0.0:
samples = samples * 2 - 1 # Convertir [0, 1] β†’ [-1, 1]
N = samples.size(0)
scores = []
for i, j in itertools.combinations(range(N), 2):
x = samples[i:i+1].to(device)
y = samples[j:j+1].to(device)
dist = loss_fn(denorm(x.clamp(-1, 1)), denorm(y.clamp(-1, 1)))
scores.append(dist.item())
return sum(scores) / len(scores)
def compute_pixelwise_correlation(samples: torch.Tensor) -> float:
N, C, H, W = samples.shape
samples_flat = samples.view(N, C, -1) # (N, C, H*W)
corrs = []
for i, j in itertools.combinations(range(N), 2):
x = samples_flat[i] # (C, HW)
y = samples_flat[j] # (C, HW)
mean_x = x.mean(dim=1, keepdim=True)
mean_y = y.mean(dim=1, keepdim=True)
x_centered = x - mean_x
y_centered = y - mean_y
numerator = (x_centered * y_centered).sum(dim=1)
denominator = (x_centered.norm(dim=1) * y_centered.norm(dim=1)) + 1e-8
corr = numerator / denominator # (C,)
corrs.append(corr.mean().item())
return sum(corrs) / len(corrs)
def compute_dynamic_range(samples: torch.Tensor) -> float:
max_vals, _ = samples.max(dim=0) # (C, H, W)
min_vals, _ = samples.min(dim=0) # (C, H, W)
dynamic_range = max_vals - min_vals # (C, H, W)
return dynamic_range.mean().item()