cbensimon's picture
cbensimon HF Staff
Update app.py
7359b93 verified
raw
history blame
974 Bytes
import os
# Install/upgrade the `spaces` package at startup so the ZeroGPU decorator is current.
# NOTE(review): runtime pip via os.system is fragile — consider pinning in requirements.txt.
os.system('pip install --upgrade spaces')
from datetime import datetime
import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
import fa3  # presumably registers FlashAttention-3 attention kernels as an import side effect — TODO confirm
from aoti import aoti_load
# Load FLUX.1-dev in bfloat16 and move it straight onto the GPU at import time.
pipeline = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16).to('cuda')
# Fuse Q/K/V projections first — the AoT-compiled weights below expect the fused layout.
pipeline.transformer.fuse_qkv_projections()
# Swap in the ahead-of-time compiled transformer artifact from the 'zerogpu-aoti/FLUX.1' Hub repo.
aoti_load(pipeline.transformer, 'zerogpu-aoti/FLUX.1')
@spaces.GPU
def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True), seed: int = 42):
    """Generate one image from *prompt* with the FLUX.1-dev pipeline.

    Args:
        prompt: Text description of the image to generate.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            diffusers tqdm bars into the UI.
        seed: RNG seed for reproducible output. Defaults to 42, the value
            previously hard-coded; existing callers are unaffected.

    Returns:
        A one-element list of ``(image, caption)`` tuples suitable for
        ``gr.Gallery``, where the caption is the wall-clock generation
        time formatted as e.g. ``"3.14s"``.
    """
    # Seed a CUDA-side generator so repeated calls with the same prompt/seed match.
    generator = torch.Generator(device='cuda').manual_seed(seed)
    t0 = datetime.now()
    output = pipeline(
        prompt=prompt,
        num_inference_steps=28,
        generator=generator,
    )
    elapsed = (datetime.now() - t0).total_seconds()
    return [(output.images[0], f'{elapsed:.2f}s')]
# Build the UI: a single prompt textbox feeding the generator, results shown
# in a gallery (image + timing caption). Examples are not pre-cached because
# generation requires the GPU at request time.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Text(label="Prompt"),
    outputs=gr.Gallery(),
    examples=["A cat playing with a ball of yarn"],
    cache_examples=False,
)
demo.launch()