Spaces:
Running
Running
import itertools | |
import numpy as np | |
from PIL import Image | |
from sklearn.cluster import KMeans | |
from myapp.colorutils import add_hsv_saturation, get_hsv_value | |
def extract_color_clusters(
    image_array: np.ndarray | Image.Image, n_clusters=2, *, random_state=None
):
    """Fit a k-means model to the pixel colours of an image.

    Args:
        image_array: Image pixels as a NumPy array of shape
            (height, width, channels) or (height, width), or anything
            ``np.array`` can convert to one (e.g. a PIL image).
        n_clusters: Number of colour clusters to fit.
        random_state: Forwarded to ``KMeans``. Pass an int for reproducible
            palettes; the default ``None`` is KMeans' own default.

    Returns:
        The fitted ``KMeans`` instance. Its ``cluster_centers_`` hold one
        colour per cluster, with one component per image channel.

    Raises:
        ValueError: If the image array is neither 2-D nor 3-D.
    """
    if not isinstance(image_array, np.ndarray):
        image_array = np.array(image_array)
    if image_array.ndim == 2:
        # Single-channel image (e.g. greyscale): one component per pixel.
        image_array = image_array[:, :, np.newaxis]
    if image_array.ndim != 3:
        raise ValueError(
            f"expected a 2-D or 3-D image array, got shape {image_array.shape}"
        )
    # Flatten the spatial axes: one row per pixel, one column per channel.
    pixels = image_array.reshape(-1, image_array.shape[-1])
    return KMeans(n_clusters=n_clusters, random_state=random_state).fit(pixels)
def sort_color_clusters(k_means: KMeans):
    """Return the fitted cluster centres as a list ordered by ``get_hsv_value``.

    The sort is ascending and stable, so centres with equal keys keep the
    order in which KMeans reported them. ``get_hsv_value`` is presumably the
    HSV "value" (brightness) component; confirm in ``myapp.colorutils``.
    """
    ordered_centers = list(k_means.cluster_centers_)
    ordered_centers.sort(key=get_hsv_value)
    return ordered_centers
def iter_color_shades(k_means: KMeans, shades: tuple[float, ...]):
    """Yield each sorted cluster colour adjusted by each saturation delta.

    Output is grouped by delta: every centre (in ``sort_color_clusters``
    order) with the first delta, then every centre with the second, and so
    on. Each colour is produced by ``add_hsv_saturation(center, delta)``.
    """
    ordered_centers = sort_color_clusters(k_means)
    # Collect all deltas before yielding anything.
    deltas = tuple(shades)
    for delta in deltas:
        for center in ordered_centers:
            yield add_hsv_saturation(center, delta)
def generate_palette_image(k_means: KMeans, size=40, shades=(0.0,)):
    """Render the fitted cluster colours as a grid of solid square swatches.

    The grid has one column per cluster centre and one row per entry in
    ``shades``. Cells are filled in the order ``iter_color_shades`` yields
    colours: left to right across a row, then top to bottom.

    Args:
        k_means: A fitted KMeans model whose cluster centres are colours.
        size: Edge length, in pixels, of each square swatch.
        shades: Deltas passed to ``add_hsv_saturation``; one grid row each.

    Returns:
        A new RGB ``PIL.Image.Image`` of size
        ``(n_clusters * size, len(shades) * size)``.
    """
    num_cluster_centers = len(k_means.cluster_centers_)
    image = Image.new("RGB", (num_cluster_centers * size, len(shades) * size))
    for i, color in enumerate(iter_color_shades(k_means, shades)):
        # Centres are floats. Round to nearest rather than truncate (which
        # biases every swatch darker), then clamp to the 8-bit channel range.
        pixel = tuple(min(255, max(0, int(round(channel)))) for channel in color)
        row, column = divmod(i, num_cluster_centers)
        left, top = column * size, row * size
        # paste() accepts a colour and fills the box directly, so no
        # per-swatch temporary image is needed.
        image.paste(pixel, (left, top, left + size, top + size))
    return image