Spaces:
Running
Running
File size: 1,275 Bytes
72f684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from starvector.metrics.base_metric import BaseMetric
import torch
class L2DistanceCalculator(BaseMetric):
    """Pixel-wise mean squared error between a ground-truth and a generated image.

    Despite the class name, the value returned is the *mean squared* error
    (``F.mse_loss`` with its default ``'mean'`` reduction); no square root is
    taken.

    Args:
        config: Stored on ``self.config``; not read by this class.
        masked_l2: When True, pixels that are white in either image are zeroed
            in both tensors before the error is computed (see ``l2_distance``).
    """

    def __init__(self, config=None, masked_l2=False):
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        # Presumably consumed by BaseMetric as the per-sample scoring callable
        # — TODO confirm against BaseMetric.
        self.metric = self.l2_distance
        self.masked_l2 = masked_l2

    def l2_distance(self, **kwargs):
        """Return the MSE between ``kwargs['gt_im']`` and ``kwargs['gen_im']``.

        Both images are converted with ``torchvision.transforms.ToTensor``, so
        they must be inputs it accepts (PIL images or ``numpy`` arrays), and
        the resulting tensors must have compatible shapes for ``F.mse_loss``.

        Args:
            gt_im: Ground-truth image (keyword only).
            gen_im: Generated image (keyword only).
            **kwargs: Any other keywords are ignored.

        Returns:
            float: The mean squared error as a Python float.

        Raises:
            TypeError: If ``gt_im`` or ``gen_im`` is missing or ``None``.
        """
        image1 = kwargs.get('gt_im')
        image2 = kwargs.get('gen_im')
        # Fail with a clear message instead of ToTensor's generic
        # "pic should be PIL Image or ndarray" TypeError (same exception type).
        missing = [
            name
            for name, image in (('gt_im', image1), ('gen_im', image2))
            if image is None
        ]
        if missing:
            raise TypeError(
                "l2_distance() missing required image keyword(s): "
                + ", ".join(missing)
            )

        to_tensor = ToTensor()
        image1_tensor = to_tensor(image1)
        image2_tensor = to_tensor(image2)

        if self.masked_l2:
            # A pixel is "white" only if every channel equals exactly 1.0. For
            # 8-bit inputs this means 255 after ToTensor's /255 scaling.
            # ToTensor yields (C, H, W), so dim=0 reduces over channels.
            mask1 = (image1_tensor != 1).any(dim=0).float()
            mask2 = (image2_tensor != 1).any(dim=0).float()
            # Intersection: keep only pixels that are non-white in BOTH images.
            # Add a leading axis so the (H, W) mask broadcasts over channels.
            combined_mask = (mask1 * mask2).unsqueeze(0)
            # NOTE(review): because this is an intersection, content present
            # in only one image is zeroed out, so an all-white gen_im scores
            # 0.0. The mean below also still divides by every pixel, including
            # the masked-out zeros. Confirm both behaviours are intended.
            image1_tensor = image1_tensor * combined_mask
            image2_tensor = image2_tensor * combined_mask

        mse = F.mse_loss(image1_tensor, image2_tensor)
        return mse.item()
|