"""Gradio demo that renders a 3x3 gallery of FLUX.1-schnell images.

For a given prompt the app generates nine images one after another and
streams the growing gallery to the UI, captioning every image with the
wall-clock time it took to produce.
"""

from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline

from optimization import optimize_pipeline_

# The pipeline is built once at import time and shared by every request.
pipeline = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-schnell',
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to('cuda')

# NOTE(review): `optimize_pipeline_` comes from the `optimization` module,
# which is not visible here. Presumably it optimizes the pipeline in place
# (its return value is discarded) using "prompt" as a sample input -- confirm.
optimize_pipeline_(pipeline, "prompt")


@spaces.GPU
def generate_image(prompt: str):
    """Generate nine images for *prompt*, yielding the gallery after each one.

    The random generator is seeded with 42 on every call and shared across
    the nine pipeline invocations, so it is advanced from one image to the
    next rather than re-seeded.

    Args:
        prompt: Text prompt passed to the pipeline.

    Yields:
        The list of ``(image, caption)`` tuples produced so far. The same
        list object is yielded each time, with one more entry appended. Each
        caption is the time spent on that image, formatted like ``"1.23s"``.
    """
    rng = torch.Generator(device='cuda').manual_seed(42)
    gallery = []
    previous = datetime.now()

    for _ in range(9):
        result = pipeline(prompt, num_inference_steps=4, generator=rng)
        picture = result.images[0]

        # Time since the previous image finished (or since the start for the
        # first one); the reference point then moves forward to "now".
        now = datetime.now()
        elapsed = now - previous
        previous = now

        gallery.append((picture, f'{elapsed.total_seconds():.2f}s'))
        yield gallery


demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Text(label="Prompt"),
    outputs=gr.Gallery(rows=3, columns=3, height='60vh'),
    examples=["A cat playing with a ball of yarn"],
    cache_examples=False,
)
demo.launch()