"""Gradio demo app: artistic QR-code generator plus a few example routes.

The main page generates a background image from a prompt (or accepts an
uploaded one), renders the QR code for the entered text over it with
``qrcode_artistic.write_artistic``, and can fill the dark/light color pickers
from the background's two dominant color clusters. Extra routes
(``/counter``, ``/chat``, ``/render``) are small standalone demos.
"""

import os
import re
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, cast

import gradio as gr
import segno
from gradio.components import Component
from huggingface_hub import InferenceClient
from PIL import Image
from qrcode_artistic import write_artistic
from segno.consts import ERROR_MAPPING

from myapp.colorutils import array_to_hex
from myapp.palette import extract_color_clusters, sort_color_clusters
from myapp.palette_demo import demo as palette_demo

try:
    import dotenv

    dotenv.load_dotenv()
except ImportError:
    # python-dotenv is optional; fall back to the real environment.
    pass


def _env_flag(name: str) -> bool:
    """Return True when env var *name* is set to anything but an "off" value.

    Unlike ``bool(os.getenv(name))``, values such as ``0``/``false``/``no``/
    ``off`` (case-insensitive) are treated as disabled.
    """
    value = (os.getenv(name) or "").strip().lower()
    return value not in {"", "0", "false", "no", "off"}


def _static_dir() -> Path:
    """Locate the ``static`` directory that sits next to this file.

    A CWD-relative path is preferred, but when the script lives outside the
    current working directory ``Path.relative_to`` raises ``ValueError``; in
    that case the absolute path is used instead of crashing at import time.
    """
    here = Path(__file__).resolve().parent
    try:
        return here.relative_to(Path.cwd()) / "static"
    except ValueError:
        return here / "static"


# Matches the numeric channel values inside e.g. "rgba(12.5, 200, 3, 1)".
_NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")


def to_hex_format(value: str | None) -> str | None:
    """Normalize a ColorPicker value to ``#RRGGBB``.

    ``None``/empty values are returned as ``None``. Values that already start
    with ``#`` are passed through unchanged. Otherwise the first three numbers
    in the string (e.g. ``rgba(r, g, b, a)``) are rounded to the nearest
    integer, capped at 255, and formatted as uppercase hex.

    Raises:
        ValueError: if fewer than three numeric channels can be found.
    """
    if not value:
        return None
    if value.startswith("#"):
        return value
    channels = _NUMBER_RE.findall(value)[:3]
    if len(channels) < 3:
        raise ValueError(f"Unrecognized color value: {value!r}")
    r, g, b = (min(255, round(float(channel))) for channel in channels)
    return f"#{r:02X}{g:02X}{b:02X}"


debug = _env_flag("DEBUG")

client = InferenceClient(model="black-forest-labs/FLUX.1-schnell")

static_path = _static_dir()

MODELS = [
    "stabilityai/stable-diffusion-3.5-large",
    "black-forest-labs/FLUX.1-schnell",
]

# Serve the same directory the example image is loaded from, instead of a
# CWD-relative "static" that only matches when run from the script's folder.
gr.set_static_paths(static_path)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                "https://wheelingvultures.bandcamp.com/album/ep", label="Text"
            )
            prompt = gr.TextArea("A psychedelic vulture", label="Prompt")
            model = gr.Radio(MODELS, value=MODELS[0], label="Model")
            generate_button = gr.Button("Generate")
            upload_button = gr.UploadButton(
                "Upload", file_types=["image"], type="filepath"
            )

        with gr.Column():
            output = gr.Image()
            # Hidden holder for the current background; consumed as a path.
            background = gr.Image(
                str(static_path / "example.webp"), visible=False, type="filepath"
            )
            scale = gr.Slider(3, 15, 9, step=1, label="Scale")
            error = gr.Radio(list(ERROR_MAPPING), value="H", label="Error")

            with gr.Row():
                color_dark = gr.ColorPicker("#000000", label="Dark")
                color_light = gr.ColorPicker("#FFFFFF", label="Light")

            with gr.Row():
                extract_colors = gr.Button("Extract")
                gr.ClearButton([color_dark, color_light], value="Reset")

            share_link = gr.DeepLinkButton()

    def generate_background(data: dict[Component, Any]):
        """Generate a background image from the prompt with the chosen model.

        Returns ``(background, output)``: the new image and ``None`` to clear
        the stale QR output. Both updates are skipped when the prompt is
        empty or whitespace-only.
        """
        if not (data.get(prompt) or "").strip():
            return gr.skip(), gr.skip()
        return client.text_to_image(data[prompt], model=data[model]), None

    def upload_background(data: dict[Component, Any]):
        """Use the uploaded file (a filepath) as the new background image."""
        return Image.open(data[upload_button])

    def generate_output(data: dict[Component, Any]):
        """Render the artistic QR code for the text over the background.

        Returns ``None`` (clearing the output) when there is no background
        or no text to encode.
        """
        if data.get(background) is None or not data.get(text):
            return None

        # Only the path and format are needed; close the file handle at once.
        with Image.open(data[background]) as image:
            background_file = image.filename
            kind = image.format

        qr_code = segno.make(data[text], error=data[error])
        buffer = BytesIO()
        write_artistic(
            qr_code,
            target=buffer,
            background=background_file,
            kind=kind,
            scale=data[scale],
            light=to_hex_format(data[color_light]),
            dark=to_hex_format(data[color_dark]),
            quiet_zone=cast(Any, "#FFFFFF"),
        )
        # PIL decodes lazily, so the buffer must stay open for as long as the
        # returned image is in use; closing it here would break later loads.
        buffer.seek(0)
        return Image.open(buffer)

    def generate_palette(data: dict[Component, Any]):
        """Extract two dominant colors from the background.

        Returns two hex colors in the order produced by
        ``sort_color_clusters``. They are wired to the (dark, light) pickers
        below. Returns ``(None, None)`` when there is no background.
        """
        if data.get(background) is None:
            return None, None
        with Image.open(data[background]) as image:
            k_means = extract_color_clusters(image, n_clusters=2)
            primary, secondary = map(array_to_hex, sort_color_clusters(k_means))
        return primary, secondary

    gr.on(
        [upload_button.upload],
        upload_background,
        inputs={upload_button},
        outputs=background,
        show_progress_on=output,
    )

    # Disable the button while generating; `.then` re-enables it even if
    # generation fails.
    gr.on(
        [generate_button.click, prompt.submit],
        partial(gr.update, interactive=False),
        outputs=generate_button,
    ).then(
        generate_background,
        inputs={prompt, model},
        outputs=[background, output],
    ).then(
        partial(gr.update, interactive=True),
        outputs=generate_button,
    )

    gr.on(
        [
            demo.load,
            text.submit,
            background.change,
            scale.change,
            error.change,
            color_light.change,
            color_dark.change,
        ],
        generate_output,
        inputs={
            text,
            background,
            scale,
            error,
            color_light,
            color_dark,
        },
        outputs=output,
    )

    gr.on(
        [extract_colors.click],
        generate_palette,
        inputs={background},
        outputs=[color_dark, color_light],
    )

with demo.route("Counter", "/counter"):
    number = gr.Number(0, label="Number")
    increment = gr.Button("Increment")
    decrement = gr.Button("Decrement")
    button = gr.DeepLinkButton()

    def increment_number(data):
        """Return the value plus one; a cleared (None) field counts as 0."""
        return (data or 0) + 1

    def decrement_number(data):
        """Return the value minus one; a cleared (None) field counts as 0."""
        return (data or 0) - 1

    gr.on(increment.click, increment_number, inputs=[number], outputs=[number])
    gr.on(decrement.click, decrement_number, inputs=[number], outputs=[number])

with demo.route("Chat", "/chat"):

    def slow_echo(message, history):
        """Stream the typed text back one character at a time."""
        typed = message["text"]
        for end in range(1, len(typed) + 1):
            yield "You typed: " + typed[:end]

    chat = gr.ChatInterface(slow_echo, multimodal=True, type="messages")
    deep_link = gr.DeepLinkButton()

with demo.route("Render", "/render"):
    input_text = gr.Textbox(label="input")

    @gr.render(inputs=input_text)
    def show_split(value):
        """Render one Textbox per character of the input, or a placeholder."""
        if not value:
            gr.Markdown("## No Input Provided")
        else:
            for letter in value:
                gr.Textbox(letter)

    deep_link = gr.DeepLinkButton()

if __name__ == "__main__":
    demo.launch(debug=debug)