import itertools

import numpy as np
from PIL import Image
from sklearn.cluster import KMeans

from myapp.colorutils import add_hsv_saturation, get_hsv_value


def extract_color_clusters(
    image_array: np.ndarray | Image.Image, n_clusters=2, *, random_state=None
):
    """Cluster the pixel colours of an image with k-means.

    Args:
        image_array: A 3-D pixel array laid out as (rows, columns, channels),
            or a PIL image, which is converted with ``np.array``.
        n_clusters: Number of colour clusters to find.
        random_state: Forwarded to ``KMeans``. Pass an int for reproducible
            clusters; ``None`` (the default) matches the previous behaviour.

    Returns:
        The fitted ``KMeans`` estimator. Its ``cluster_centers_`` holds one
        centre per cluster with one value per channel.

    Raises:
        ValueError: If the array does not have exactly three dimensions
            (for example a single-channel "L" mode image).
    """
    if not isinstance(image_array, np.ndarray):
        image_array = np.array(image_array)

    if image_array.ndim != 3:
        raise ValueError(
            "expected a 3-D (rows, columns, channels) image array, "
            f"got shape {image_array.shape}"
        )

    rows, cols, channels = image_array.shape
    # Flatten the spatial axes so each row of `pixels` is one pixel's channels.
    pixels = image_array.reshape(rows * cols, channels)

    return KMeans(n_clusters=n_clusters, random_state=random_state).fit(pixels)


def sort_color_clusters(k_means: KMeans):
    """Return the cluster centres of *k_means* as a list ordered by ``get_hsv_value``.

    The sort is stable, so centres with equal keys keep their original order.
    """
    centers = list(k_means.cluster_centers_)
    centers.sort(key=get_hsv_value)
    return centers


def iter_color_shades(k_means: KMeans, shades: tuple[float, ...]):
    """Yield each sorted cluster colour adjusted by each saturation delta.

    Output is grouped by shade. First come all cluster colours (in
    ``sort_color_clusters`` order) for ``shades[0]``, then all of them for
    ``shades[1]``, and so on.
    """
    ordered_centers = sort_color_clusters(k_means)
    for saturation_delta in shades:
        for center in ordered_centers:
            yield add_hsv_saturation(center, saturation_delta)


def generate_palette_image(k_means: KMeans, size=40, shades=(0.0,)):
    """Render the cluster colours, and their shades, as a grid of swatches.

    Columns follow the order in which ``iter_color_shades`` yields cluster
    colours. There is one row per saturation delta in *shades*, in the
    given order.

    Args:
        k_means: A fitted ``KMeans`` whose ``cluster_centers_`` are colours.
        size: Edge length in pixels of each square swatch.
        shades: Saturation deltas passed through to ``iter_color_shades``,
            one row per delta. Any iterable is accepted.

    Returns:
        A new RGB ``PIL.Image.Image`` of size
        ``(n_clusters * size, len(shades) * size)``.
    """
    # Materialise once: len() is needed below, and a generator would
    # otherwise be exhausted or rejected.
    shades = tuple(shades)
    num_cluster_centers = len(k_means.cluster_centers_)
    image = Image.new("RGB", (num_cluster_centers * size, len(shades) * size))

    for i, color in enumerate(iter_color_shades(k_means, shades)):
        # Cluster centres are float means. Round to the nearest integer
        # (int() would truncate 127.9 to 127) and keep channels in 8-bit range.
        fill = tuple(int(c) for c in np.clip(np.rint(color), 0, 255))
        # The generator yields row-major: all centres for one shade, then the next.
        row, col = divmod(i, num_cluster_centers)
        left, top = col * size, row * size
        # paste() with a colour value fills the box directly; no temporary image.
        image.paste(fill, (left, top, left + size, top + size))

    return image