"""Extract dominant colours from an image with k-means and render them as a palette."""

import itertools

import numpy as np
from PIL import Image
from sklearn.cluster import KMeans

from myapp.colorutils import add_hsv_saturation, get_hsv_value


def extract_color_clusters(
    image_array: np.ndarray | Image.Image, n_clusters=2, *, random_state=None
):
    """Cluster the pixels of an image into ``n_clusters`` colour groups.

    Args:
        image_array: A 3-D array shaped ``(height, width, channels)``, or a
            PIL image (converted with ``np.array``).
        n_clusters: Number of k-means clusters to fit.
        random_state: Forwarded to ``KMeans``. Pass an int for reproducible
            clusters; ``None`` (the default) keeps scikit-learn's default.

    Returns:
        The fitted ``KMeans`` estimator. Its ``cluster_centers_`` holds one
        channel vector per cluster.

    Raises:
        ValueError: If the image array is not 3-D (for example a
            single-band image).
    """
    if not isinstance(image_array, np.ndarray):
        image_array = np.array(image_array)
    if image_array.ndim != 3:
        raise ValueError(
            "expected a (height, width, channels) image array, "
            f"got shape {image_array.shape}"
        )
    # Image arrays are row-major: (height, width, channels).
    height, width, channels = image_array.shape
    # One row per pixel, one column per channel: the (n_samples, n_features)
    # layout that KMeans.fit expects.
    pixels = image_array.reshape(height * width, channels)
    return KMeans(n_clusters=n_clusters, random_state=random_state).fit(pixels)


def sort_color_clusters(k_means: KMeans):
    """Return the fitted cluster centres sorted ascending by ``get_hsv_value``.

    ``get_hsv_value`` lives in ``myapp.colorutils``. Presumably it returns the
    HSV "value" (brightness) of a colour, which would make this a
    dark-to-bright ordering; confirm against that helper.
    """
    return sorted(k_means.cluster_centers_, key=get_hsv_value)


def iter_color_shades(k_means: KMeans, shades: tuple[float, ...]):
    """Yield each sorted cluster centre adjusted by each saturation delta.

    The output is shade-major: all centres for ``shades[0]``, then all centres
    for ``shades[1]``, and so on. ``generate_palette_image`` relies on this
    order to lay out one row per shade.

    Args:
        k_means: A fitted ``KMeans`` estimator.
        shades: Deltas passed to ``add_hsv_saturation`` for every centre.

    Yields:
        Whatever ``add_hsv_saturation(center, delta)`` returns, presumably an
        adjusted colour vector.
    """
    cluster_centers = sort_color_clusters(k_means)
    for delta, cluster_center in itertools.product(shades, cluster_centers):
        yield add_hsv_saturation(cluster_center, delta)


def generate_palette_image(k_means: KMeans, size=40, shades=(0.0,)):
    """Render the cluster colours as a grid of solid square swatches.

    Columns are the cluster centres in ``sort_color_clusters`` order. Rows are
    the entries of ``shades``, in order.

    Args:
        k_means: A fitted ``KMeans`` estimator.
        size: Edge length of each swatch, in pixels.
        shades: Saturation deltas, one palette row per entry.

    Returns:
        An RGB ``PIL.Image.Image`` of size
        ``(n_clusters * size, len(shades) * size)``.
    """
    num_cluster_centers = len(k_means.cluster_centers_)
    image = Image.new("RGB", (num_cluster_centers * size, len(shades) * size))
    for i, color in enumerate(iter_color_shades(k_means, shades)):
        # Colours arrive shade-major, so the quotient is the row (shade) and
        # the remainder is the column (cluster).
        row, col = divmod(i, num_cluster_centers)
        left, top = col * size, row * size
        # Float channel values are truncated to ints for Pillow. The swatch is
        # filled in place, so no temporary image is allocated per cell.
        image.paste(tuple(map(int, color)), (left, top, left + size, top + size))
    return image