qr-code / myapp /palette.py
m3g4p0p's picture
add colorutils module
ec7ee9c
raw
history blame
1.33 kB
import itertools
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
from myapp.colorutils import add_hsv_saturation, get_hsv_value
def extract_color_clusters(
    image_array: np.ndarray | Image.Image, n_clusters=2, **kmeans_kwargs
):
    """Fit a k-means model to the pixel colours of an image.

    Args:
        image_array: A PIL image, or anything ``np.array`` can turn into an
            array of shape ``(height, width, channels)`` or ``(height, width)``.
            PIL images whose mode is not "RGB" or "RGBA" (e.g. "1", "L", "P")
            are converted to "RGB" first; 2-D arrays are treated as greyscale
            and replicated into three channels.
        n_clusters: Number of colour clusters to find.
        **kmeans_kwargs: Extra keyword arguments forwarded to
            ``sklearn.cluster.KMeans`` (e.g. ``random_state`` for reproducible
            clusters, or ``n_init``).

    Returns:
        The fitted ``KMeans`` instance; ``cluster_centers_`` holds one colour
        per cluster.

    Raises:
        ValueError: If the resulting array is neither 2-D nor 3-D.
    """
    if isinstance(image_array, Image.Image) and image_array.mode not in ("RGB", "RGBA"):
        # Bilevel/greyscale/palette images become 2-D arrays of intensities or
        # palette *indices*; clustering those is meaningless, so normalise to RGB.
        image_array = image_array.convert("RGB")
    if not isinstance(image_array, np.ndarray):
        image_array = np.array(image_array)
    if image_array.ndim == 2:
        # Greyscale array: replicate into three channels so the cluster centres
        # have the same RGB shape as those of colour images.
        image_array = np.repeat(image_array[:, :, np.newaxis], 3, axis=2)
    if image_array.ndim != 3:
        raise ValueError(
            f"expected a 2-D or 3-D image array, got shape {image_array.shape}"
        )
    # NumPy image arrays are (height, width, channels); flatten to one row per pixel.
    pixels = image_array.reshape(-1, image_array.shape[-1])
    return KMeans(n_clusters=n_clusters, **kmeans_kwargs).fit(pixels)
def sort_color_clusters(k_means: KMeans):
    """Return the fitted cluster centres as a list, ordered by ``get_hsv_value``.

    Sorting is ascending and stable. ``get_hsv_value`` comes from
    ``myapp.colorutils``; presumably it returns the HSV "value" (brightness)
    channel, so darker centres would come first — confirm there.
    """
    centers = list(k_means.cluster_centers_)
    centers.sort(key=get_hsv_value)
    return centers
def iter_color_shades(k_means: KMeans, shades: tuple[float, ...]):
    """Yield every sorted cluster centre with each saturation delta applied.

    Output is grouped by shade: for the first delta in ``shades`` all centres
    (in ``sort_color_clusters`` order) are yielded, then the next delta, and so
    on. Each colour is ``add_hsv_saturation(center, delta)``.
    """
    ordered_centers = sort_color_clusters(k_means)
    # product() advances its last iterable fastest, so all centres for one
    # shade are emitted before moving on to the next shade.
    shade_center_pairs = itertools.product(shades, ordered_centers)
    yield from (
        add_hsv_saturation(center, delta) for delta, center in shade_center_pairs
    )
def generate_palette_image(k_means: KMeans, size=40, shades=(0.0,)):
    """Render the k-means colour clusters as a grid of solid square swatches.

    Each row corresponds to one entry of ``shades`` (in order) and each column
    to one cluster centre, in the order produced by ``iter_color_shades``.

    Args:
        k_means: A fitted ``KMeans`` model; its centres are drawn as RGB colours.
        size: Edge length, in pixels, of each square swatch.
        shades: Saturation deltas passed to ``add_hsv_saturation``; one row is
            drawn per delta.

    Returns:
        A new "RGB" PIL image that is ``len(cluster_centers_) * size`` pixels
        wide and ``len(shades) * size`` pixels tall.
    """
    num_cluster_centers = len(k_means.cluster_centers_)
    image = Image.new("RGB", (num_cluster_centers * size, len(shades) * size))
    for i, color in enumerate(iter_color_shades(k_means, shades)):
        row, col = divmod(i, num_cluster_centers)
        # Cluster centres are float means: round to the nearest 8-bit value and
        # clamp, rather than truncating with int() (which biases colours darker).
        rgb = tuple(int(channel) for channel in np.clip(np.rint(color), 0, 255))
        left, top = col * size, row * size
        # Fill the swatch region directly instead of allocating a temporary image.
        image.paste(rgb, (left, top, left + size, top + size))
    return image