File size: 2,464 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import cv2
import numpy as np

from ..image_utils import MAX_VALUES_BY_DTYPE, as_3d


def _as_float32(image: np.ndarray) -> np.ndarray:
    """Return ``image`` as a float32 array.

    A float32 input is handed back unchanged (same object, no copy). Any
    other dtype is cast to float32 and divided by that dtype's entry in
    ``MAX_VALUES_BY_DTYPE``.
    """
    if image.dtype != np.float32:
        # Presumably this rescales integer images into [0, 1]; the exact
        # divisor is defined by MAX_VALUES_BY_DTYPE, keyed by dtype name.
        dtype_max = MAX_VALUES_BY_DTYPE[image.dtype.name]
        converted = image.astype(np.float32)
        return converted / dtype_max
    return image


def distinct_colors_palette(image: np.ndarray) -> np.ndarray:
    """Build a palette containing every distinct color found in ``image``.

    The image is passed through ``as_3d``, flattened to one row per pixel and
    de-duplicated row-wise. The result has shape ``(1, n_colors, channels)``,
    with colors in the order produced by ``np.unique``.
    """
    pixels = as_3d(image)
    channels = pixels.shape[2]

    # One row per pixel so that np.unique can compare whole colors.
    rows = pixels.reshape((-1, channels))
    unique_rows = np.unique(rows, axis=0)

    return unique_rows.reshape((1, -1, channels))


def kmeans_palette(image: np.ndarray, num_colors: int) -> np.ndarray:
    """Build a ``num_colors`` palette by k-means clustering the pixel colors.

    The image is passed through ``as_3d``, flattened to one row per pixel and
    converted with ``_as_float32`` before clustering. The cluster centers are
    returned as an array of shape ``(1, n_centers, channels)``.
    """
    image = as_3d(image)
    channels = image.shape[2]
    samples = _as_float32(image.reshape((-1, channels)))

    # Stop each attempt after 10 iterations or once the epsilon of 1.0 is met.
    termination = (
        cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER,
        10,
        1.0,
    )
    n_attempts = 10

    # Fixed seed so that repeated calls on the same image give the same palette.
    cv2.setRNGSeed(0)
    result = cv2.kmeans(
        samples, num_colors, None, termination, n_attempts, cv2.KMEANS_PP_CENTERS  # type: ignore
    )
    centers = result[2]

    return centers.reshape((1, -1, channels))


class MedianCutBucket:
    """A group of pixels that the median-cut algorithm can split in two.

    ``data`` is a 2D array with one row per pixel and one column per channel.
    Per-channel minimum, maximum and range are computed once on construction.
    """

    def __init__(self, data: np.ndarray):
        self.data = data
        # Unpacking into two names enforces that ``data`` is exactly 2D.
        self.n_pixels, self.n_channels = data.shape
        self.min_values = np.min(data, axis=0)
        self.max_values = np.max(data, axis=0)
        self.channel_ranges = self.max_values - self.min_values
        # Spread of the widest channel; 0 means every pixel is identical.
        self.biggest_range = np.max(self.channel_ranges)

    def split(self) -> "tuple[MedianCutBucket, MedianCutBucket]":
        """Split the bucket in two along its widest channel.

        Pixels whose value in the widest channel is strictly greater than the
        median go into the first returned bucket, all others into the second.
        If that leaves the first bucket empty (the median equals the maximum),
        the mean is used as the threshold instead, and finally the channel
        minimum.

        Returns:
            ``(upper, lower)`` buckets that together contain every pixel.

        Raises:
            ValueError: if no threshold separates the pixels, e.g. because all
                values in the widest channel are equal.
        """
        widest_channel = np.argmax(self.channel_ranges)
        values = self.data[:, widest_channel]

        mask = values > np.median(values)
        if not mask.any():
            mask = values > np.mean(values)
        if not mask.any():
            # The mean can round up to the maximum in low precision; the
            # channel minimum always separates when the range is non-zero.
            mask = values > self.min_values[widest_channel]
        if not mask.any():
            raise ValueError("Cannot split a bucket with no spread in any channel.")

        # Index with the boolean array itself. ``mask is True`` is an identity
        # test that yields the scalar False and selects no pixels at all.
        return MedianCutBucket(self.data[mask]), MedianCutBucket(self.data[~mask])

    def average(self):
        """Return the per-channel mean color of the pixels in this bucket."""
        return np.mean(self.data, axis=0)


def median_cut_palette(image: np.ndarray, num_colors: int) -> np.ndarray:
    """Build a palette of up to ``num_colors`` colors using median cut.

    Starting from a single bucket holding every pixel, the bucket with the
    largest channel range is repeatedly replaced by its two halves. Splitting
    stops once there are ``num_colors`` buckets or no bucket has any spread
    left. Each bucket's average color becomes one palette entry; the result
    has shape ``(1, n_buckets, channels)``.
    """
    image = as_3d(image)
    channels = image.shape[2]
    pixels = _as_float32(image.reshape((-1, channels)))

    buckets = [MedianCutBucket(pixels)]
    while len(buckets) < num_colors:
        # Index of the first bucket with the largest spread.
        widest_idx = max(
            range(len(buckets)), key=lambda i: buckets[i].biggest_range
        )
        if buckets[widest_idx].biggest_range == 0:
            # Every remaining bucket holds a single color; nothing to split.
            break

        widest = buckets.pop(widest_idx)
        upper, lower = widest.split()
        buckets.append(upper)
        buckets.append(lower)

    averages = [bucket.average() for bucket in buckets]
    return np.stack(averages, axis=0).reshape((1, -1, channels))