# Hugging Face Space file by cbensimon (HF Staff).
# Last commit: 48a26b4 "gr.Gallery()" (file size at that commit: 867 bytes).
import time
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline

from optimization import optimize_pipeline_
# Load the FLUX.1-dev weights in bfloat16 and move the pipeline to the GPU once,
# at import time, so every call to generate_image reuses this same object.
pipeline = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to('cuda')

# Project-local helper from optimization.py. The trailing underscore suggests it
# modifies `pipeline` in place; "prompt" appears to be a sample prompt used as an
# example input — NOTE(review): confirm both against optimization.py.
optimize_pipeline_(pipeline, "prompt")
@spaces.GPU
def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for *prompt* with the module-level pipeline and time it.

    Args:
        prompt: Text prompt forwarded to the diffusion pipeline.
        progress: Gradio progress tracker. ``track_tqdm=True`` mirrors the
            pipeline's tqdm progress bar in the UI; Gradio supplies it, so
            callers do not pass it.

    Returns:
        A one-element list containing an ``(image, caption)`` tuple, the
        item format used by the ``gr.Gallery`` output. The caption is the
        pipeline's elapsed time formatted like ``"12.34s"``.
    """
    # Fixed seed: the same prompt reproduces the same image run to run.
    generator = torch.Generator(device='cuda').manual_seed(42)

    # perf_counter is monotonic and high-resolution. datetime.now() reads the
    # wall clock, which can jump (e.g. NTP adjustment) and skew the duration.
    start = time.perf_counter()
    output = pipeline(
        prompt=prompt,
        num_inference_steps=28,
        generator=generator,
    )
    elapsed = time.perf_counter() - start

    return [(output.images[0], f'{elapsed:.2f}s')]
# Single-input UI: a prompt textbox in, a gallery holding the generated image
# (captioned with its generation time) out.
prompt_box = gr.Text(label="Prompt")
gallery = gr.Gallery()

demo = gr.Interface(
    fn=generate_image,
    inputs=prompt_box,
    outputs=gallery,
    examples=["A cat playing with a ball of yarn"],
    # Example outputs are not precomputed/cached when the app starts.
    cache_examples=False,
)
demo.launch()