Spaces:
Running
Running
from starvector.metrics.base_metric import BaseMetric | |
from skimage.metrics import structural_similarity as ssim | |
import numpy as np | |
class SSIMDistanceCalculator(BaseMetric):
    """Structural-similarity (SSIM) metric between a ground-truth and a generated image.

    The scoring callable is exposed as ``self.metric`` (bound to
    :meth:`compute_SSIM`) for use through the ``BaseMetric`` interface.
    """

    def __init__(self, config=None):
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        self.metric = self.compute_SSIM

    @staticmethod
    def _integer_data_range(arr):
        """Return the full value span of *arr*'s dtype if it is an integer dtype, else None."""
        if np.issubdtype(arr.dtype, np.integer):
            info = np.iinfo(arr.dtype)
            return float(info.max) - float(info.min)
        return None

    def compute_SSIM(self, **kwargs):
        """Compute the mean SSIM score between ``gt_im`` and ``gen_im``.

        Keyword Args:
            gt_im: Ground-truth image; anything ``np.array`` accepts (e.g. a PIL image).
            gen_im: Generated image; must have the same shape as ``gt_im``.
            win_size (int): Side length of the sliding window (default 11). skimage
                requires it to be odd and no larger than the smaller image side.
            channel_axis (int): Channel axis used for 3-channel images (default -1).
            sigma (float): Gaussian standard deviation (default 1.5). skimage only
                uses it when ``gaussian_weights`` is True.
            gaussian_weights (bool): Weight the window with a Gaussian instead of a
                uniform filter (default False, which keeps historical scores).
            data_range (float | None): Dynamic range of the pixel values. When not
                given, the full range of an integer input dtype is used (255 for
                uint8), which is what skimage itself infers for integer inputs.

        Returns:
            The mean SSIM over the image, as returned by skimage.

        Raises:
            ValueError: If either image is missing, or if skimage rejects the
                inputs (e.g. mismatched shapes, window larger than the image,
                float input with no ``data_range``).
        """
        image1 = kwargs.get('gt_im')
        image2 = kwargs.get('gen_im')
        if image1 is None or image2 is None:
            raise ValueError("compute_SSIM requires both 'gt_im' and 'gen_im'")

        win_size = kwargs.get('win_size', 11)
        channel_axis = kwargs.get('channel_axis', -1)
        sigma = kwargs.get('sigma', 1.5)
        gaussian_weights = kwargs.get('gaussian_weights', False)

        # Convert images to numpy arrays if they aren't already.
        img1_np = np.array(image1)
        img2_np = np.array(image2)

        data_range = kwargs.get('data_range')
        if data_range is None:
            # Must be taken from the ORIGINAL dtype: the channel averaging below
            # produces float64, for which skimage cannot infer a range and raises.
            data_range = self._integer_data_range(img1_np)

        if img1_np.ndim == 3 and img1_np.shape[2] == 3:
            # Three-channel image: SSIM is computed per channel and averaged.
            return ssim(
                img1_np,
                img2_np,
                win_size=win_size,
                channel_axis=channel_axis,
                gaussian_weights=gaussian_weights,
                sigma=sigma,
                data_range=data_range,
            )

        if img1_np.ndim == 3:
            # Any other channel count is collapsed to one channel by averaging.
            # NOTE(review): this averages every channel, so a 4th channel
            # (e.g. alpha in RGBA input) contributes to the grayscale value.
            img1_np = np.mean(img1_np, axis=2)
            img2_np = np.mean(img2_np, axis=2)

        return ssim(
            img1_np,
            img2_np,
            win_size=win_size,
            gaussian_weights=gaussian_weights,
            sigma=sigma,
            data_range=data_range,
        )