m3g4p0p commited on
Commit
dbd3377
·
1 Parent(s): 16c06bf

add generate palette command

Browse files
Files changed (2) hide show
  1. myapp/cli.py +8 -2
  2. myapp/palette.py +29 -14
myapp/cli.py CHANGED
@@ -2,6 +2,7 @@ import click
2
  import dotenv
3
  from huggingface_hub import InferenceClient
4
 
 
5
  from myapp.types import ImageParamType
6
 
7
  dotenv.load_dotenv()
@@ -32,5 +33,10 @@ def generate_image(prompt, target, model, width, height):
32
 
33
  @cli.command()
34
  @click.option("--image", type=ImageParamType(), required=True)
35
- def generate_palette(image):
36
- pass
 
 
 
 
 
 
2
  import dotenv
3
  from huggingface_hub import InferenceClient
4
 
5
+ from myapp.palette import extract_color_clusters, generate_palette_image
6
  from myapp.types import ImageParamType
7
 
8
  dotenv.load_dotenv()
 
33
 
34
  @cli.command()
35
  @click.option("--image", type=ImageParamType(), required=True)
36
+ @click.option("--target", type=click.Path(dir_okay=False), required=True)
37
+ @click.option("--n-colors", default=4)
38
+ @click.option("--shade", "shades", default=(0.0,), multiple=True)
39
+ def generate_palette(image, target, n_colors, shades):
40
+ k_means = extract_color_clusters(image, n_colors)
41
+ palette = generate_palette_image(k_means, shades=shades)
42
+ palette.save(target)
myapp/palette.py CHANGED
@@ -1,25 +1,20 @@
 
 
 
1
  import numpy as np
2
  from PIL import Image
3
  from sklearn.cluster import KMeans
4
 
5
 
6
- def join_images(a: Image.Image, b: Image.Image):
7
- result = Image.new(a.mode, (a.width + b.width, max(a.height, b.height)))
8
- result.paste(a)
9
- result.paste(b, (a.width, 0))
10
-
11
- return result
12
-
13
 
14
- def generate_palette_image(model: KMeans, size=40):
15
- image = Image.new("RGB", (0, size))
16
 
17
- for cluster_center in model.cluster_centers_:
18
- color = tuple(map(int, cluster_center))
19
- part = Image.new("RGB", (40, 40), color)
20
- image = join_images(image, part)
21
 
22
- return image
23
 
24
 
25
  def extract_color_clusters(image_array: np.ndarray | Image.Image, n_clusters=2):
@@ -32,6 +27,26 @@ def extract_color_clusters(image_array: np.ndarray | Image.Image, n_clusters=2):
32
  return KMeans(n_clusters=n_clusters).fit(pixels)
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def array_to_hex(values: np.ndarray):
36
  values = np.round(values).astype(int)
37
  return "#" + ("{:02X}" * len(values)).format(*values)
 
1
+ import colorsys
2
+ import itertools
3
+
4
  import numpy as np
5
  from PIL import Image
6
  from sklearn.cluster import KMeans
7
 
8
 
9
+ def get_hsv_value(cluster: np.ndarray):
10
+ return colorsys.rgb_to_hsv(*cluster / 255)[2]
 
 
 
 
 
11
 
 
 
12
 
13
+ def add_hsv_saturation(cluster: np.ndarray, delta: float):
14
+ h, s, v = colorsys.rgb_to_hsv(*cluster / 255)
15
+ s = max(0, min(1, s + delta))
 
16
 
17
+ return np.array(colorsys.hsv_to_rgb(h, s, v)) * 255
18
 
19
 
20
  def extract_color_clusters(image_array: np.ndarray | Image.Image, n_clusters=2):
 
27
  return KMeans(n_clusters=n_clusters).fit(pixels)
28
 
29
 
30
+ def iter_color_shades(k_means: KMeans, shades: tuple[float, ...]):
31
+ cluster_centers = sorted(k_means.cluster_centers_, key=get_hsv_value)
32
+
33
+ for delta, cluster_center in itertools.product(shades, cluster_centers):
34
+ yield add_hsv_saturation(cluster_center, delta)
35
+
36
+
37
+ def generate_palette_image(k_means: KMeans, size=40, shades=(0.0,)):
38
+ num_cluster_centers = len(k_means.cluster_centers_)
39
+ image = Image.new("RGB", (num_cluster_centers * size, len(shades) * size))
40
+
41
+ for i, color in enumerate(iter_color_shades(k_means, shades)):
42
+ color = tuple(map(int, color))
43
+ part = Image.new("RGB", (size, size), color)
44
+ position = (i % num_cluster_centers * size, i // num_cluster_centers * size)
45
+ image.paste(part, position)
46
+
47
+ return image
48
+
49
+
50
  def array_to_hex(values: np.ndarray):
51
  values = np.round(values).astype(int)
52
  return "#" + ("{:02X}" * len(values)).format(*values)