import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxPipeline

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = FluxPipeline.from_pretrained("sayakpaul/FLUX.1-merged", torch_dtype=dtype).to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048


def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
          guidance_scale=3.5, num_inference_steps=8, output_format="png"):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Route the request to a short or long ZeroGPU allocation depending on the
    # amount of work (pixels x steps), with 1024x1024 at 8 steps as the cutoff.
    if width * height * num_inference_steps <= 1024 * 1024 * 8:
        return infer_in_1min(prompt=prompt, seed=seed, randomize_seed=randomize_seed,
                             width=width, height=height, guidance_scale=guidance_scale,
                             num_inference_steps=num_inference_steps, output_format=output_format)
    else:
        return infer_in_5min(prompt=prompt, seed=seed, randomize_seed=randomize_seed,
                             width=width, height=height, guidance_scale=guidance_scale,
                             num_inference_steps=num_inference_steps, output_format=output_format)


@spaces.GPU(duration=60)
def infer_in_1min(prompt, seed, randomize_seed, width, height, guidance_scale,
                  num_inference_steps, output_format):
    return infer_on_gpu(prompt=prompt, seed=seed, randomize_seed=randomize_seed,
                        width=width, height=height, guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps, output_format=output_format)


@spaces.GPU(duration=300)
def infer_in_5min(prompt, seed, randomize_seed, width, height, guidance_scale,
                  num_inference_steps, output_format):
    return infer_on_gpu(prompt=prompt, seed=seed, randomize_seed=randomize_seed,
                        width=width, height=height, guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps, output_format=output_format)


def infer_on_gpu(prompt, seed, randomize_seed, width, height, guidance_scale,
                 num_inference_steps, output_format, progress=gr.Progress(track_tqdm=True)):
    # `seed` is already resolved in infer(); `randomize_seed` is accepted only so
    # all three entry points share one signature.
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=guidance_scale,
    ).images[0]
    # Return the image together with the requested file format for the output component.
    return gr.update(format=output_format, value=image), seed


examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

with gr.Blocks(delete_cache=(4000, 4000)) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# [FLUX.1 [merged]](https://huggingface.co/sayakpaul/FLUX.1-merged)
A merge by [Sayak Paul](https://huggingface.co/sayakpaul) of the two 12B-parameter rectified flow transformers [FLUX.1 [dev]](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [FLUX.1 [schnell]](https://huggingface.co/black-forest-labs/FLUX.1-schnell) by [Black Forest Labs](https://blackforestlabs.ai/)
""")
        prompt = gr.Text(
            label="Prompt",
            show_label=False,
            lines=2,
            autofocus=True,
            placeholder="Enter your prompt",
            container=False,
        )
        output_format = gr.Radio(
            [["*.png", "png"], ["*.webp", "webp"], ["*.jpeg", "jpeg"], ["*.gif", "gif"], ["*.bmp", "bmp"]],
            label="Image format for result",
            info="File extension",
            value="png",
            interactive=True,
        )

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=4,
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1,
                maximum=15,
                step=0.1,
                value=3.5,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

        run_button = gr.Button(value="🚀 Generate", variant="primary")
        result = gr.Image(label="Result", show_label=False, format="png")

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, output_format],
        outputs=[result, seed],
    )

demo.queue(default_concurrency_limit=2).launch(show_error=True)
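
# A minimal sketch (an assumption, not part of the original Space) of how the
# pipeline above could be smoke-tested without the Gradio UI. It is kept
# commented out so it never runs alongside the app; it only uses calls that
# already appear in infer_on_gpu():
#
#   image = pipe(
#       prompt="a cat holding a sign that says hello world",
#       width=1024,
#       height=1024,
#       num_inference_steps=4,
#       guidance_scale=3.5,
#       generator=torch.Generator().manual_seed(42),
#   ).images[0]
#   image.save("smoke_test.png")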