File size: 1,329 Bytes
dbd3377
 
eecf3d6
 
 
 
ec7ee9c
eecf3d6
 
1ea41a7
 
 
 
eecf3d6
 
 
1ea41a7
 
 
6c17cac
 
 
 
dbd3377
6c17cac
dbd3377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import itertools

import numpy as np
from PIL import Image
from sklearn.cluster import KMeans

from myapp.colorutils import add_hsv_saturation, get_hsv_value


def extract_color_clusters(
    image_array: np.ndarray | Image.Image, n_clusters=2, *, random_state=None
):
    """Cluster the pixel colors of an image with k-means.

    Args:
        image_array: Image data, either a PIL image or a NumPy array. Arrays
            are expected as ``(height, width, channels)``; a 2-D
            ``(height, width)`` array is treated as single-channel data.
        n_clusters: Number of color clusters to fit.
        random_state: Forwarded to ``KMeans``; pass an int for reproducible
            clustering. ``None`` (the default) matches ``KMeans``'s default.

    Returns:
        The fitted ``KMeans`` estimator. Each row of its ``cluster_centers_``
        is one cluster's mean pixel, with one entry per input channel.

    Raises:
        ValueError: If the image data is neither 2-D nor 3-D.
    """
    if not isinstance(image_array, np.ndarray):
        image_array = np.array(image_array)

    if image_array.ndim == 2:
        # Single-channel data (e.g. grayscale): add an explicit channel axis
        # so it flattens into one feature per pixel below.
        image_array = image_array[:, :, np.newaxis]
    elif image_array.ndim != 3:
        raise ValueError(
            f"expected 2-D or 3-D image data, got shape {image_array.shape}"
        )

    # Collapse the spatial axes so each row is one pixel's channel vector.
    pixels = image_array.reshape(-1, image_array.shape[-1])

    return KMeans(n_clusters=n_clusters, random_state=random_state).fit(pixels)


def sort_color_clusters(k_means: KMeans):
    """Return the fitted cluster centers ordered by ``get_hsv_value``.

    The centers are sorted in ascending order of the key computed by
    ``get_hsv_value`` (presumably the HSV "value" channel; confirm in
    ``myapp.colorutils``). A new list holding one array per cluster is
    returned; ``k_means`` itself is not modified.
    """
    ordered_centers = list(k_means.cluster_centers_)
    ordered_centers.sort(key=get_hsv_value)
    return ordered_centers


def iter_color_shades(k_means: KMeans, shades: tuple[float, ...]):
    """Yield one adjusted color per (shade, cluster center) combination.

    Shades drive the outer loop and the sorted cluster centers the inner one.
    Every cluster is therefore emitted for the first shade before moving on
    to the next. Each yielded color is ``add_hsv_saturation(center, delta)``.
    """
    ordered_centers = sort_color_clusters(k_means)

    for saturation_delta in shades:
        for center in ordered_centers:
            yield add_hsv_saturation(center, saturation_delta)


def generate_palette_image(k_means: KMeans, size=40, shades=(0.0,)):
    """Render the (shaded) cluster colors as a grid of solid squares.

    Colors come from ``iter_color_shades``. Each row of the grid corresponds
    to one entry of ``shades``, and each column to one cluster center.

    Args:
        k_means: A fitted ``KMeans`` whose cluster centers are colors.
        size: Edge length, in pixels, of each square cell.
        shades: Values passed as ``delta`` to ``add_hsv_saturation``; the
            palette gets one row per entry.

    Returns:
        A new RGB ``PIL.Image.Image`` of size
        ``(n_clusters * size, len(shades) * size)``.
    """
    num_cluster_centers = len(k_means.cluster_centers_)
    image = Image.new("RGB", (num_cluster_centers * size, len(shades) * size))

    for i, color in enumerate(iter_color_shades(k_means, shades)):
        # Cluster centers are float means: round to the nearest level and
        # clamp to the 8-bit range instead of truncating toward zero.
        rgb = tuple(int(channel) for channel in np.clip(np.rint(color), 0, 255))
        row, col = divmod(i, num_cluster_centers)
        left, top = col * size, row * size
        # Fill the cell in place instead of allocating a temporary tile image.
        image.paste(rgb, (left, top, left + size, top + size))

    return image