File size: 1,461 Bytes
72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from starvector.metrics.base_metric import BaseMetric
from skimage.metrics import structural_similarity as ssim
import numpy as np

class SSIMDistanceCalculator(BaseMetric):
    """Structural similarity (SSIM) between a ground-truth and a generated image.

    Despite the class name, ``compute_SSIM`` returns scikit-image's mean SSIM
    *similarity* score (higher means more similar), not a distance.
    """

    def __init__(self, config=None):
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        # Expose compute_SSIM as this metric's scoring callable.
        self.metric = self.compute_SSIM

    @staticmethod
    def _infer_data_range(image):
        """Return the nominal value span of ``image``'s dtype, or None.

        Integer dtypes map to their full representable span (255 for uint8)
        and booleans to 1. Any other dtype (e.g. floating point) returns None,
        leaving scikit-image's own handling in place.
        """
        if image.dtype == np.bool_:
            return 1.0
        if np.issubdtype(image.dtype, np.integer):
            info = np.iinfo(image.dtype)
            return float(info.max - info.min)
        return None

    def compute_SSIM(self, **kwargs):
        """Compute the mean SSIM between ``gt_im`` and ``gen_im``.

        Keyword Args:
            gt_im: Reference image (anything ``np.array`` accepts, e.g. a PIL image).
            gen_im: Image compared against ``gt_im``; must have the same shape.
            win_size: Side length of the sliding window (odd). Defaults to 11.
            channel_axis: Channel axis used for 3-channel images. Defaults to -1.
            sigma: Gaussian standard deviation. scikit-image only uses it when
                ``gaussian_weights`` is True. Defaults to 1.5.
            gaussian_weights: Weight the window with a Gaussian instead of a
                uniform filter. Defaults to False (the historical behaviour).
            data_range: Value span of the inputs. When omitted it is derived
                from ``gt_im``'s integer/bool dtype (e.g. 255 for uint8).

        Returns:
            The mean SSIM score.

        Raises:
            ValueError: If ``gt_im`` or ``gen_im`` is missing.
        """
        image1 = kwargs.get('gt_im')
        image2 = kwargs.get('gen_im')
        if image1 is None or image2 is None:
            raise ValueError("compute_SSIM requires both 'gt_im' and 'gen_im'")

        win_size = kwargs.get('win_size', 11)
        channel_axis = kwargs.get('channel_axis', -1)
        sigma = kwargs.get('sigma', 1.5)
        gaussian_weights = kwargs.get('gaussian_weights', False)

        # Convert images to numpy arrays if they aren't already.
        img1_np = np.array(image1)
        img2_np = np.array(image2)

        # Resolve data_range from the ORIGINAL dtype: the channel averaging
        # below produces float64, from which scikit-image cannot recover the
        # range of the source pixel data.
        data_range = kwargs.get('data_range')
        if data_range is None:
            data_range = self._infer_data_range(img1_np)

        ssim_kwargs = dict(
            win_size=win_size,
            data_range=data_range,
            gaussian_weights=gaussian_weights,
            sigma=sigma,
        )

        if img1_np.ndim == 3 and img1_np.shape[2] == 3:
            # Three-channel image: scikit-image averages per-channel SSIM.
            return ssim(img1_np, img2_np, channel_axis=channel_axis, **ssim_kwargs)

        if img1_np.ndim == 3:
            # Any other channel count (including 4-channel images) is collapsed
            # to a single channel by averaging every channel.
            img1_np = np.mean(img1_np, axis=2)
            img2_np = np.mean(img2_np, axis=2)

        return ssim(img1_np, img2_np, **ssim_kwargs)