hz2475's picture
init
72f684c
from starvector.metrics.base_metric import BaseMetric
from skimage.metrics import structural_similarity as ssim
import numpy as np
class SSIMDistanceCalculator(BaseMetric):
    """Scores a generated image against a ground-truth image with SSIM.

    Uses ``skimage.metrics.structural_similarity``; identical images score 1.0
    and lower values indicate less structural similarity.
    """

    def __init__(self, config=None):
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        # Callable used to score one image pair.
        self.metric = self.compute_SSIM

    def compute_SSIM(self, **kwargs):
        """Return the mean SSIM between ``gt_im`` and ``gen_im``.

        Keyword Args:
            gt_im: Ground-truth image; anything ``np.array`` accepts
                (e.g. a PIL image).
            gen_im: Generated image with the same shape as ``gt_im``.
            win_size: Side length of the sliding window. Defaults to 11.
            channel_axis: Channel axis forwarded to SSIM for 3-channel images.
                Defaults to -1. Note that the 3-channel check itself always
                inspects the last axis (``shape[2]``) of ``gt_im``.
            sigma: Gaussian standard deviation. Defaults to 1.5. scikit-image
                only uses it when ``gaussian_weights`` is True.
            gaussian_weights: Weight the window with a Gaussian instead of a
                uniform filter. Defaults to False (the historical behaviour).
            data_range: Dynamic range of pixel values. When omitted and
                ``gt_im`` has an integer dtype, the full range of that dtype
                is used (255 for ``uint8``); otherwise scikit-image's own
                default applies.

        Returns:
            The mean SSIM score as returned by scikit-image.

        Raises:
            ValueError: If ``gt_im`` or ``gen_im`` is missing, or if
                scikit-image rejects the inputs (e.g. mismatched shapes).
        """
        image1 = kwargs.get('gt_im')
        image2 = kwargs.get('gen_im')
        if image1 is None or image2 is None:
            raise ValueError("compute_SSIM requires both 'gt_im' and 'gen_im'")
        win_size = kwargs.get('win_size', 11)
        channel_axis = kwargs.get('channel_axis', -1)
        sigma = kwargs.get('sigma', 1.5)
        gaussian_weights = kwargs.get('gaussian_weights', False)
        data_range = kwargs.get('data_range')

        # Convert images to numpy arrays if they aren't already.
        img1_np = np.array(image1)
        img2_np = np.array(image2)

        # Resolve the value range from the *input* dtype before any grayscale
        # averaging: np.mean below yields float64 data that still spans the
        # original integer range (e.g. 0..255), for which scikit-image would,
        # depending on its version, either assume a [-1, 1] range or raise.
        # For integer inputs this matches scikit-image's own default.
        if data_range is None and np.issubdtype(img1_np.dtype, np.integer):
            info = np.iinfo(img1_np.dtype)
            data_range = info.max - info.min

        common = {
            'win_size': win_size,
            'sigma': sigma,
            'gaussian_weights': gaussian_weights,
            'data_range': data_range,
        }

        if img1_np.ndim == 3 and img1_np.shape[2] == 3:
            # 3-channel (RGB) images: SSIM is computed per channel and averaged.
            return ssim(img1_np, img2_np, channel_axis=channel_axis, **common)

        if img1_np.ndim == 3:
            # Any other channel count (e.g. RGBA) is collapsed to a single
            # channel by averaging. NOTE(review): for RGBA inputs this includes
            # the alpha channel in the average — confirm that is intended.
            img1_np = np.mean(img1_np, axis=2)
            img2_np = np.mean(img2_np, axis=2)
        return ssim(img1_np, img2_np, **common)