m3g4p0p commited on
Commit
1ea41a7
·
1 Parent(s): eecf3d6

add array to hex

Browse files
Files changed (2) hide show
  1. myapp/palette.py +10 -3
  2. myapp/pall_app.py +8 -9
myapp/palette.py CHANGED
@@ -22,9 +22,16 @@ def generate_palette_image(model: KMeans, size=40):
22
  return image
23
 
24
 
25
- def extract_color_clusters(image_array: np.ndarray):
 
 
 
26
  w, h, d = image_array.shape
27
  pixels = image_array.reshape(w * h, d)
28
- model = KMeans(n_clusters=4).fit(pixels)
29
 
30
- return model
 
 
 
 
 
 
22
  return image
23
 
24
 
25
+ def extract_color_clusters(image_array: np.ndarray | Image.Image, n_clusters=2):
26
+ if not isinstance(image_array, np.ndarray):
27
+ image_array = np.array(image_array)
28
+
29
  w, h, d = image_array.shape
30
  pixels = image_array.reshape(w * h, d)
 
31
 
32
+ return KMeans(n_clusters=n_clusters).fit(pixels)
33
+
34
+
35
+ def array_to_hex(values: np.ndarray):
36
+ values = np.round(values).astype(int)
37
+ return "#" + ("{:02X}" * len(values)).format(*values)
myapp/pall_app.py CHANGED
@@ -1,20 +1,19 @@
1
  import gradio as gr
2
  import numpy as np
3
- from sklearn.cluster import KMeans
 
4
 
5
  with gr.Blocks() as demo:
6
  image = gr.Image("vulture.webp")
7
- n_colors = gr.Slider(1, 16, 4)
8
  button = gr.Button()
9
- plot = gr.Plot()
10
 
11
- @button.click(inputs=image, outputs=gr.JSON())
12
- def get_palette(image_array: np.ndarray):
13
- w, h, d = image_array.shape
14
- pixels = image_array.reshape(w * h, d)
15
- model = KMeans(n_clusters=4).fit(pixels)
16
 
17
- return model.cluster_centers_
 
18
 
19
 
20
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import numpy as np
3
+
4
+ from myapp.palette import array_to_hex, extract_color_clusters
5
 
6
  with gr.Blocks() as demo:
7
  image = gr.Image("vulture.webp")
8
+ n_colors = gr.Slider(1, 16, 4, step=1)
9
  button = gr.Button()
 
10
 
11
+ @gr.render(inputs=[image, n_colors])
12
+ def render_palette(image_array: np.ndarray, n_clusers: int):
13
+ model = extract_color_clusters(image_array, n_clusers)
 
 
14
 
15
+ for cluster in model.cluster_centers_:
16
+ gr.ColorPicker(array_to_hex(cluster))
17
 
18
 
19
  if __name__ == "__main__":