File size: 4,526 Bytes
796128b
86fc953
ced082c
9c9e714
d604745
7a8f70c
cfd74ee
3d78d34
cfd74ee
b5594eb
cfd74ee
 
 
23b7e6c
cfd74ee
ec7ee9c
 
796128b
cfd74ee
0894424
 
cfd74ee
0894424
 
 
cfd74ee
796128b
20654e3
4e87ee2
cfd74ee
fb1a543
 
 
 
 
d604745
 
cfd74ee
 
 
 
 
 
ced082c
fb1a543
20654e3
86fc953
cfd74ee
86fc953
4e87ee2
 
 
20654e3
23b7e6c
86fc953
 
 
 
cfd74ee
29a8662
 
23b7e6c
29a8662
fb1a543
86fc953
ff5b825
 
fb1a543
cfd74ee
fb1a543
ced082c
cfd74ee
3d78d34
86fc953
7a8f70c
 
 
86fc953
 
 
 
 
 
2ca6236
86fc953
9c9e714
23b7e6c
3d78d34
9c9e714
cfd74ee
 
9c9e714
2ca6236
9c9e714
b5594eb
86fc953
 
7a8f70c
cfd74ee
3d78d34
9c9e714
3d78d34
7a8f70c
 
 
 
 
 
6c17cac
7a8f70c
 
 
fb1a543
30b7b7a
fb1a543
 
 
 
 
 
 
 
 
 
 
 
86fc953
2ca6236
86fc953
 
 
23b7e6c
86fc953
 
 
fb1a543
86fc953
 
 
 
23b7e6c
86fc953
 
 
fb1a543
 
 
7a8f70c
29a8662
 
7a8f70c
 
 
 
796128b
 
 
 
3d78d34
 
796128b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import re
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, cast

import gradio as gr
import segno
from gradio.components import Component
from huggingface_hub import InferenceClient
from PIL import Image
from qrcode_artistic import write_artistic
from segno.consts import ERROR_MAPPING

from myapp.colorutils import array_to_hex
from myapp.palette import extract_color_clusters, sort_color_clusters
from myapp.palette_demo import demo as palette_demo

# python-dotenv is an optional dependency: when it is importable, load
# environment variables from a .env file; when it is not, keep going with
# whatever the process environment already provides.
try:
    import dotenv
except ImportError:
    pass
else:
    dotenv.load_dotenv()

def _env_flag(name: str) -> bool:
    """Return whether the environment variable *name* is switched on.

    Unset or blank values and the conventional "off" spellings ``0``,
    ``false``, ``no`` and ``off`` (case-insensitive) count as False; any
    other value counts as True. A plain ``bool(os.getenv(name))`` would
    wrongly treat ``DEBUG=0`` as enabled.
    """
    value = os.getenv(name, "").strip().lower()
    return value not in {"", "0", "false", "no", "off"}


def _static_dir() -> Path:
    """Return the ``static`` directory that sits next to this file.

    The path is expressed relative to the current working directory when
    that is possible (preserving the previous behavior). If the app is
    launched from outside its own directory, ``Path.relative_to`` raises
    ``ValueError``; in that case the absolute path is returned instead of
    crashing at import time.
    """
    # resolve() also copes with a relative __file__, which relative_to(cwd)
    # could not handle.
    static_dir = Path(__file__).resolve().parent / "static"
    try:
        return static_dir.relative_to(Path.cwd())
    except ValueError:
        return static_dir


debug = _env_flag("DEBUG")
client = InferenceClient(model="black-forest-labs/FLUX.1-schnell")
static_path = _static_dir()

# Models selectable in the UI; the first entry is the default choice.
MODELS = [
    "stabilityai/stable-diffusion-3.5-large",
    "black-forest-labs/FLUX.1-schnell",
]

# Serve the same directory the example background is loaded from, instead of
# a "static" directory resolved against whatever the working directory is.
gr.set_static_paths(static_path)

with gr.Blocks() as demo:
    # Two-column layout: the left column holds the QR text and the controls
    # for generating a new background; the right column shows the rendered
    # QR code and its rendering options.
    # NOTE: component creation order below is what defines on-screen order.
    with gr.Row():
        with gr.Column():
            # Content encoded into the QR code (passed to segno.make).
            text = gr.Textbox(
                "https://wheelingvultures.bandcamp.com/album/ep", label="Text"
            )
            # Prompt and model used by generate_background's text-to-image call.
            prompt = gr.TextArea("A psychedelic vulture", label="Prompt")
            model = gr.Radio(MODELS, value=MODELS[0], label="Model")
            button = gr.Button("Generate")

        with gr.Column():
            # Rendered artistic QR code.
            output = gr.Image()
            # Hidden holder for the current background, exchanged as a file
            # path; it starts out as the bundled example image.
            background = gr.Image(
                str(static_path / "example.webp"), visible=False, type="filepath"
            )
            # Module scale and error-correction level forwarded to
            # write_artistic / segno.make respectively.
            scale = gr.Slider(3, 15, 9, step=1, label="Scale")
            error = gr.Radio(list(ERROR_MAPPING), value="H", label="Error")

            with gr.Row():
                # Colors for the dark and light QR modules.
                color_dark = gr.ColorPicker("#000000", label="Dark")
                color_light = gr.ColorPicker("#FFFFFF", label="Light")

            with gr.Row():
                # Suggests Dark/Light colors extracted from the background.
                extract_colors = gr.Button("Extract")
                # NOTE(review): ClearButton clears the pickers (presumably to
                # None) rather than restoring the #000000 / #FFFFFF defaults;
                # confirm that "Reset" is meant to yield empty colors.
                gr.ClearButton([color_dark, color_light], value="Reset")

    def generate_background(data: dict[Component, Any]):
        """Generate a fresh background image from the prompt.

        ``data`` maps input components to their current values. On a
        non-empty prompt, returns ``(image, None)``: the image produced by
        the selected model, plus ``None`` to clear the stale QR output.
        With an empty prompt, both outputs are left unchanged via
        ``gr.skip()``.
        """
        user_prompt = data.get(prompt)
        if user_prompt:
            generated = client.text_to_image(user_prompt, model=data[model])
            return generated, None

        return gr.skip(), gr.skip()

    def _to_hex_format(value: str | None) -> str | None:
        """Normalize a color-picker value into a color string for write_artistic.

        Values starting with ``#`` are returned unchanged. Functional
        notation such as ``rgb(12, 34, 56)`` or ``rgba(12.5, 34, 56, 0.5)``
        becomes ``#RRGGBB``: the first three numbers are truncated to
        integers and capped at 255; any further number (alpha) is ignored.
        Empty or missing values yield ``None``.

        Raises:
            gr.Error: if the value contains fewer than three numbers.
        """
        if not value:
            return None

        if value.startswith("#"):
            return value

        matches = re.findall(r"\d+(?:\.\d+)?", value)
        if len(matches) < 3:
            raise gr.Error(f"Unsupported color value: {value!r}")

        # Cap each channel so an out-of-range value cannot yield 3 hex digits.
        r, g, b = (min(int(float(match)), 255) for match in matches[:3])

        return f"#{r:02X}{g:02X}{b:02X}"

    def generate_output(data: dict[Component, Any]):
        """Render the artistic QR code on top of the current background.

        ``data`` maps input components to their current values. Returns a
        PIL image for the output component, or ``None`` (clearing it) when
        there is no background image or no text to encode.

        Raises:
            gr.Error: if the text does not fit in a QR code at the selected
                error-correction level, or a color value cannot be parsed.
        """
        if data.get(background) is None or not data.get(text):
            return None

        try:
            qr_code = segno.make(data[text], error=data[error])
        except segno.DataOverflowError as exc:
            # Show "too much data" in the UI instead of a raw traceback.
            raise gr.Error(str(exc)) from exc

        # write_artistic is handed the background by path, so only the path
        # and detected format are needed here; close the file handle promptly.
        with Image.open(data[background]) as image:
            background_path = image.filename
            background_format = image.format

        # Deliberately not closed: Image.open() decodes lazily, so the image
        # returned below keeps reading from this in-memory buffer. Closing it
        # first (as a `with BytesIO()` block did) breaks formats that Pillow
        # does not fully read at open time.
        buffer = BytesIO()
        write_artistic(
            qr_code,
            target=buffer,
            background=background_path,
            kind=background_format,
            scale=data[scale],
            light=_to_hex_format(data[color_light]),
            dark=_to_hex_format(data[color_dark]),
            quiet_zone=cast(Any, "#FFFFFF"),
        )
        buffer.seek(0)

        return Image.open(buffer)

    def generate_palette(data: dict[Component, Any]):
        """Suggest dark/light QR colors from the current background image.

        ``data`` maps input components to their current values. Clusters
        the background into two colors and returns them as hex strings in
        the order ``sort_color_clusters`` yields them (wired to the Dark and
        Light pickers respectively), or ``(None, None)`` when there is no
        background.
        """
        background_path = data.get(background)
        if background_path is None:
            return None, None

        # Use the image as a context manager so its file handle is released
        # once clustering is finished instead of being leaked.
        with Image.open(background_path) as image:
            k_means = extract_color_clusters(image, n_clusters=2)
            clusters = sort_color_clusters(k_means)
            primary, secondary = map(array_to_hex, clusters)

        return primary, secondary

    # Background generation: disable the Generate button, run
    # generate_background, then re-enable the button.
    # NOTE(review): per Gradio's docs `.then` (unlike `.success`) runs even if
    # the previous step raised, so a failed inference call should not leave
    # the button disabled; confirm.
    gr.on(
        [button.click, prompt.submit],
        partial(gr.update, interactive=False),
        outputs=button,
    ).then(
        generate_background,
        # Inputs given as a set: the function receives a single dict keyed by
        # component, matching its `data: dict[Component, Any]` parameter.
        inputs={prompt, model},
        outputs=[background, output],
    ).then(
        partial(gr.update, interactive=True),
        outputs=button,
    )

    # Re-render the QR code on page load and whenever one of its inputs
    # changes (including when generate_background replaces the background).
    # The text box triggers only on submit, not on every keystroke.
    gr.on(
        [
            demo.load,
            text.submit,
            background.change,
            scale.change,
            error.change,
            color_light.change,
            color_dark.change,
        ],
        generate_output,
        inputs={
            text,
            background,
            scale,
            error,
            color_light,
            color_dark,
        },
        outputs=output,
    )

    # "Extract" fills the Dark and Light pickers with colors clustered from
    # the current background image.
    gr.on(
        [extract_colors.click],
        generate_palette,
        inputs={background},
        outputs=[color_dark, color_light],
    )

# In debug mode, also mount the palette demo as an extra page at /palette.
if debug:
    with demo.route("Palette", "/palette"):
        palette_demo.render()


# Start the server only when executed directly; the debug flag is forwarded
# to launch().
if __name__ == "__main__":
    demo.launch(debug=debug)