import cv2 import numpy as np from ..image_utils import MAX_VALUES_BY_DTYPE, as_3d def _as_float32(image: np.ndarray) -> np.ndarray: if image.dtype == np.float32: return image max_value = MAX_VALUES_BY_DTYPE[image.dtype.name] return image.astype(np.float32) / max_value def distinct_colors_palette(image: np.ndarray) -> np.ndarray: image = as_3d(image) return np.unique(image.reshape((-1, image.shape[2])), axis=0).reshape( (1, -1, image.shape[2]) ) def kmeans_palette(image: np.ndarray, num_colors: int) -> np.ndarray: image = as_3d(image) flat_image = _as_float32(image.reshape((-1, image.shape[2]))) max_iter = 10 epsilon = 1.0 criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, max_iter, epsilon) attempts = 10 cv2.setRNGSeed(0) _, _, center = cv2.kmeans( flat_image, num_colors, None, criteria, attempts, cv2.KMEANS_PP_CENTERS # type: ignore ) return center.reshape((1, -1, image.shape[2])) class MedianCutBucket: def __init__(self, data: np.ndarray): self.data = data self.n_pixels, self.n_channels = data.shape self.min_values = np.min(data, axis=0) self.max_values = np.max(data, axis=0) self.channel_ranges = self.max_values - self.min_values self.biggest_range = np.max(self.channel_ranges) def split(self): widest_channel = np.argmax(self.channel_ranges) median = np.median(self.data[:, widest_channel]) mask = self.data[:, widest_channel] > median if mask.sum() == 0: mean = np.mean(self.data[:, widest_channel]) mask = self.data[:, widest_channel] > mean return MedianCutBucket(self.data[mask is True]), MedianCutBucket( self.data[mask is False] ) def average(self): return np.mean(self.data, axis=0) def median_cut_palette(image: np.ndarray, num_colors: int) -> np.ndarray: image = as_3d(image) flat_image = _as_float32(image.reshape((-1, image.shape[2]))) buckets = [MedianCutBucket(flat_image)] while len(buckets) < num_colors: bucket_idx, bucket = max(enumerate(buckets), key=lambda x: x[1].biggest_range) if bucket.biggest_range == 0: break buckets.pop(bucket_idx) buckets.extend(bucket.split()) return np.stack([bucket.average() for bucket in buckets], axis=0).reshape( (1, -1, image.shape[2]) )