File size: 1,275 Bytes
72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from starvector.metrics.base_metric import BaseMetric
import torch

class L2DistanceCalculator(BaseMetric):
    """Pixel-wise mean squared error between a ground-truth and a generated image.

    Despite the class name, the score is the mean of squared per-element
    differences (``F.mse_loss`` with its default reduction), not a square-rooted
    Euclidean distance.

    Args:
        config: Stored on ``self.config``; not read anywhere in this class.
        masked_l2: When True, only pixels that are non-white in BOTH images are
            kept; every other pixel is zeroed in both tensors before the error
            is computed.
    """

    def __init__(self, config=None, masked_l2=False):
        super().__init__()
        self.class_name = type(self).__name__
        self.config = config
        # Callable exposed as this metric's entry point; presumably invoked by
        # BaseMetric / the evaluation harness (not visible here — confirm).
        self.metric = self.l2_distance
        self.masked_l2 = masked_l2

    def l2_distance(self, **kwargs):
        """Return the MSE between ``kwargs['gt_im']`` and ``kwargs['gen_im']``.

        Both images go through torchvision's ``ToTensor`` (channel-first
        tensors; PIL / uint8 inputs are scaled to [0, 1]). A missing key
        produces ``None``, which ``ToTensor`` rejects with an error.

        Returns:
            The mean squared error as a Python float.
        """
        to_tensor = ToTensor()
        gt = to_tensor(kwargs.get('gt_im'))
        gen = to_tensor(kwargs.get('gen_im'))

        if self.masked_l2:
            def has_ink(t):
                # Per-pixel flag: True unless every channel is exactly 1.0 (white).
                return (t != 1).any(dim=0)

            # Intersection of the two content regions, as a 0/1 float mask with
            # a leading channel axis so it broadcasts over all channels.
            overlap = (has_ink(gt) & has_ink(gen)).float().unsqueeze(0)
            gt = gt * overlap
            gen = gen * overlap
            # NOTE(review): mse_loss below still divides by the total element
            # count, so zeroed pixels lower the score (an all-white image gives
            # 0). This is kept as-is to preserve the existing metric values.

        return F.mse_loss(gt, gen).item()