import numpy as np
from skimage.metrics import structural_similarity as ssim

from starvector.metrics.base_metric import BaseMetric


class SSIMDistanceCalculator(BaseMetric):
    """Structural-similarity (SSIM) metric between two images.

    Thin wrapper around ``skimage.metrics.structural_similarity`` that is
    registered as this object's ``metric`` callable.
    """

    def __init__(self, config=None):
        """Store ``config`` and register ``compute_SSIM`` as the metric."""
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        self.metric = self.compute_SSIM

    @staticmethod
    def _default_data_range(dtype):
        """Return the full value span of a bool/integer dtype, else ``None``.

        For integer dtypes this is the same span scikit-image would pick by
        itself; it has to be captured *before* any float conversion because
        scikit-image cannot guess a sensible range for floating-point data.
        """
        if dtype.kind == 'b':
            return 1
        if np.issubdtype(dtype, np.integer):
            limits = np.iinfo(dtype)
            # Python ints avoid overflow when subtracting 64-bit limits.
            return int(limits.max) - int(limits.min)
        return None

    @staticmethod
    def _to_grayscale(image):
        """Average a 3-D image over axis 2; return other ranks untouched."""
        return image.mean(axis=2) if image.ndim == 3 else image

    def compute_SSIM(self, **kwargs):
        """Compute the mean SSIM between ``gt_im`` and ``gen_im``.

        Keyword Args:
            gt_im: Reference image (anything ``np.asarray`` accepts).
            gen_im: Image to compare against ``gt_im``.
            win_size: Side length of the sliding window (default 11).
            channel_axis: Channel axis used for 3-channel images
                (default -1).
            sigma: Gaussian standard deviation (default 1.5). scikit-image
                only uses it when ``gaussian_weights`` is true.
            gaussian_weights: Use Gaussian instead of uniform window weights
                (default False, which keeps previously reported scores).
            data_range: Value span of the images. When omitted it is derived
                from the integer/bool dtype of ``gt_im``; floating-point
                inputs should pass it explicitly.

        Returns:
            The mean SSIM score returned by scikit-image.

        Raises:
            ValueError: If ``gt_im`` or ``gen_im`` is missing. scikit-image
                raises its own errors for mismatched shapes, windows larger
                than the image, or float input without ``data_range``.
        """
        image1 = kwargs.get('gt_im')
        image2 = kwargs.get('gen_im')
        if image1 is None or image2 is None:
            raise ValueError("compute_SSIM requires both 'gt_im' and 'gen_im'")

        reference = np.asarray(image1)
        candidate = np.asarray(image2)

        data_range = kwargs.get('data_range')
        if data_range is None:
            # Must be read from the original dtype: the grayscale fallback
            # below produces float arrays, for which no range can be guessed.
            data_range = self._default_data_range(reference.dtype)

        ssim_options = {
            'win_size': kwargs.get('win_size', 11),
            'sigma': kwargs.get('sigma', 1.5),
            'gaussian_weights': kwargs.get('gaussian_weights', False),
            'data_range': data_range,
        }

        if reference.ndim == 3 and reference.shape[2] == 3:
            # Three-channel image: let scikit-image average per-channel SSIM.
            ssim_options['channel_axis'] = kwargs.get('channel_axis', -1)
        else:
            # Anything else is compared as a single-channel image; each image
            # is collapsed independently so a 2-D/3-D pairing does not fail
            # on a missing axis.
            reference = self._to_grayscale(reference)
            candidate = self._to_grayscale(candidate)

        # Only the scalar score is needed, so the full SSIM map is not
        # requested.
        return ssim(reference, candidate, **ssim_options)