"""Gradio app that renders an artistic QR code over a generated background.

A text-to-image model produces the background picture, ``segno`` encodes the
text, and ``qrcode_artistic.write_artistic`` blends both into the final image.
"""

import os
import re
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, cast

import gradio as gr
import segno
from gradio.components import Component
from huggingface_hub import InferenceClient
from PIL import Image
from qrcode_artistic import write_artistic
from segno.consts import ERROR_MAPPING

from myapp.colorutils import array_to_hex
from myapp.palette import extract_color_clusters, sort_color_clusters
from myapp.palette_demo import demo as palette_demo

# python-dotenv is optional: load a local .env file only when it is installed.
try:
    import dotenv

    dotenv.load_dotenv()
except ImportError:
    pass

# Any non-empty DEBUG environment variable enables debug mode.
debug = bool(os.getenv("DEBUG"))

client = InferenceClient(model="black-forest-labs/FLUX.1-schnell")

# NOTE(review): relative_to() raises ValueError when the current working
# directory is not an ancestor of this file's directory -- presumably the app
# is always launched from the project root; confirm before changing.
static_path = Path(__file__).parent.relative_to(Path.cwd()) / "static"

MODELS = [
    "stabilityai/stable-diffusion-3.5-large",
    "black-forest-labs/FLUX.1-schnell",
]

gr.set_static_paths("static")

# Matches the integer or decimal numbers inside strings such as
# "rgba(12.5, 34, 56, 1)".
_NUMBER_PATTERN = re.compile(r"\d+(?:\.\d+)?")


def _to_hex_format(value: str | None) -> str | None:
    """Convert a colour string to ``#RRGGBB`` notation.

    Args:
        value: Either ``None``, a string starting with ``#`` or a string
            containing at least three numbers (for example ``rgb(1, 2, 3)``
            or ``rgba(1.5, 2.5, 3.5, 1)``).

    Returns:
        ``None`` when *value* is ``None``, *value* unchanged when it starts
        with ``#``, otherwise an upper-case ``#RRGGBB`` string built from the
        first three numbers found in *value*.

    Raises:
        ValueError: If fewer than three numbers can be found in *value*.
    """
    if value is None:
        return None
    if value.startswith("#"):
        return value

    numbers = _NUMBER_PATTERN.findall(value)
    if len(numbers) < 3:
        raise ValueError(f"Cannot parse a color from {value!r}")

    # Fractional channels are truncated; values above 255 are clamped so every
    # channel always formats as exactly two hex digits.
    r, g, b = (min(255, int(float(number))) for number in numbers[:3])
    return f"#{r:02X}{g:02X}{b:02X}"


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                "https://wheelingvultures.bandcamp.com/album/ep", label="Text"
            )
            prompt = gr.TextArea("A psychedelic vulture", label="Prompt")
            model = gr.Radio(MODELS, value=MODELS[0], label="Model")
            button = gr.Button("Generate")
        with gr.Column():
            output = gr.Image()
            # Hidden component holding the background image as a file path.
            background = gr.Image(
                str(static_path / "example.webp"), visible=False, type="filepath"
            )
            scale = gr.Slider(3, 15, 9, step=1, label="Scale")
            error = gr.Radio(list(ERROR_MAPPING), value="H", label="Error")
            with gr.Row():
                color_dark = gr.ColorPicker("#000000", label="Dark")
                color_light = gr.ColorPicker("#FFFFFF", label="Light")
            with gr.Row():
                extract_colors = gr.Button("Extract")
                gr.ClearButton([color_dark, color_light], value="Reset")

    def generate_background(data: dict[Component, Any]):
        """Generate a new background image from the prompt.

        Args:
            data: Component-to-value mapping containing ``prompt`` and
                ``model``.

        Returns:
            A ``(background, output)`` pair. Both are ``gr.skip()`` when the
            prompt is empty; otherwise the generated image and ``None``,
            which clears the current output.
        """
        if not data.get(prompt):
            return gr.skip(), gr.skip()
        return client.text_to_image(data[prompt], model=data[model]), None

    def generate_output(data: dict[Component, Any]):
        """Render the QR code for the text on top of the background image.

        Args:
            data: Component-to-value mapping containing ``text``,
                ``background``, ``scale``, ``error``, ``color_light`` and
                ``color_dark``.

        Returns:
            The rendered QR code as a fully loaded PIL image, or ``None``
            when there is no background or no text to encode.
        """
        if data.get(background) is None:
            return None
        if not data.get(text):
            return None

        # The image is opened only to read its path and format; close the
        # file handle right away instead of leaking it.
        with Image.open(data[background]) as image:
            background_path = image.filename
            background_format = image.format

        qr_code = segno.make(data[text], error=data[error])
        with BytesIO() as buffer:
            write_artistic(
                qr_code,
                target=buffer,
                background=background_path,
                kind=background_format,
                scale=data[scale],
                light=_to_hex_format(data[color_light]),
                dark=_to_hex_format(data[color_dark]),
                quiet_zone=cast(Any, "#FFFFFF"),
            )
            result = Image.open(buffer)
            # Image.open() is lazy: read the pixel data now, because the
            # buffer is closed as soon as this ``with`` block exits.
            result.load()
            return result

    def generate_palette(data: dict[Component, Any]):
        """Extract a two-colour palette from the background image.

        Args:
            data: Component-to-value mapping containing ``background``.

        Returns:
            A ``(dark, light)`` pair of colour values for the colour pickers,
            or ``(None, None)`` when there is no background image.
        """
        if data[background] is None:
            return None, None

        # Keep the file open while the clusters are computed, then close it.
        with Image.open(data[background]) as image:
            k_means = extract_color_clusters(image, n_clusters=2)
            primary, secondary = map(array_to_hex, sort_color_clusters(k_means))
        return primary, secondary

    # Disable the button while the background is generated, then re-enable it.
    gr.on(
        [button.click, prompt.submit],
        partial(gr.update, interactive=False),
        outputs=button,
    ).then(
        generate_background,
        inputs={prompt, model},
        outputs=[background, output],
    ).then(
        partial(gr.update, interactive=True),
        outputs=button,
    )

    # Re-render the QR code on page load and whenever one of its inputs changes.
    gr.on(
        [
            demo.load,
            text.submit,
            background.change,
            scale.change,
            error.change,
            color_light.change,
            color_dark.change,
        ],
        generate_output,
        inputs={
            text,
            background,
            scale,
            error,
            color_light,
            color_dark,
        },
        outputs=output,
    )

    gr.on(
        [extract_colors.click],
        generate_palette,
        inputs={background},
        outputs=[color_dark, color_light],
    )

# The palette playground is only exposed as an extra page in debug mode.
if debug:
    with demo.route("Palette", "/palette"):
        palette_demo.render()

if __name__ == "__main__":
    demo.launch(debug=debug)