This checkpoint is compiled from ByteDance/SDXL-Lightning (the 4-step UNet) for AWS Inferentia2 (Inf2) instances using Optimum Neuron.

Compilation

Download the 4-step UNet checkpoint from ByteDance/SDXL-Lightning. Then use it to replace the UNet weights in a local copy of the original SDXL checkpoint:

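```python
from huggingface_hub import hf_hub_download

repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"
# Downloads into the local Hugging Face cache and returns the cached file path
unet_path = hf_hub_download(repo, ckpt)
print(unet_path)
```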

Replace the unet:

```shell
cp /home/ubuntu/.cache/huggingface/hub/models--ByteDance--SDXL-Lightning/snapshots/xxxxxx/sdxl_lightning_4step_unet.safetensors stable-diffusion-xl-lightning/unet/diffusion_pytorch_model.safetensors
```
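The download and copy steps can also be scripted in Python, so you do not have to locate the snapshot hash in the cache by hand. This is a minimal sketch with two assumptions not stated in this card:

- The "original sdxl checkpoint" is `stabilityai/stable-diffusion-xl-base-1.0`.
- It is checked out into the `stable-diffusion-xl-lightning` directory used above.

`replace_unet` and `prepare_lightning_checkpoint` are hypothetical helper names.

```python
import shutil
from pathlib import Path

# Relative location of the UNet weights inside a diffusers SDXL checkout
UNET_WEIGHTS = Path("unet") / "diffusion_pytorch_model.safetensors"


def replace_unet(base_dir: str, lightning_unet: str) -> Path:
    """Overwrite the UNet weights of a local SDXL checkout (hypothetical helper)."""
    target = Path(base_dir) / UNET_WEIGHTS
    if not target.parent.is_dir():
        raise FileNotFoundError(f"no unet/ folder under {base_dir}")
    # Remove the old file first: if it is a symlink into the HF cache,
    # unlinking drops only the link instead of overwriting the cached blob
    target.unlink(missing_ok=True)
    shutil.copyfile(lightning_unet, target)
    return target


def prepare_lightning_checkpoint(local_dir: str = "stable-diffusion-xl-lightning") -> Path:
    """Download base SDXL and swap in the 4-step Lightning UNet (hypothetical helper)."""
    # Imported lazily so replace_unet works without huggingface_hub installed
    from huggingface_hub import hf_hub_download, snapshot_download

    # ASSUMPTION: the original SDXL checkpoint is SDXL base 1.0
    base_dir = snapshot_download("stabilityai/stable-diffusion-xl-base-1.0", local_dir=local_dir)
    unet_path = hf_hub_download("ByteDance/SDXL-Lightning", "sdxl_lightning_4step_unet.safetensors")
    return replace_unet(base_dir, unet_path)
```

`prepare_lightning_checkpoint()` downloads the full base repository, which is many gigabytes, before copying. If the base checkpoint is already on disk, calling `replace_unet` alone is enough.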

Compile the whole pipeline:

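```python
from optimum.neuron import NeuronStableDiffusionXLPipeline

# Local SDXL checkpoint whose UNet was replaced in the previous step
model_id = "stable-diffusion-xl-lightning"
num_images_per_prompt = 1
# Neuron compilation uses static shapes: the compiled pipeline only accepts these values
input_shapes = {"batch_size": 1, "height": 1024, "width": 1024, "num_images_per_prompt": num_images_per_prompt}
# Let the compiler cast matrix multiplications to bfloat16
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}

# export=True traces and compiles the pipeline's components for Neuron
stable_diffusion = NeuronStableDiffusionXLPipeline.from_pretrained(
    model_id, export=True, **compiler_args, **input_shapes
)
save_directory = "sdxl_lightning_4_steps_neuronx/"
stable_diffusion.save_pretrained(save_directory)
# Optionally push save_directory to the Hugging Face Hub
```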
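The push-to-Hub step is not spelled out above. One way to do it is with `huggingface_hub`'s `HfApi.create_repo` and `HfApi.upload_folder`. This is a sketch and may not match how the published `aws-neuron/SDXL-Lightning-4steps-neuronx` repo was actually created. A few notes on the names used:

- `push_compiled_pipeline` is a hypothetical helper.
- The repo id in the example comment is a placeholder.
- The `api` parameter exists only so the function can be exercised without network access.

```python
from typing import Any, Optional


def push_compiled_pipeline(save_directory: str, repo_id: str, api: Optional[Any] = None) -> Any:
    """Upload a saved Neuron pipeline folder to the Hub (hypothetical helper)."""
    if api is None:
        # Requires huggingface_hub and a stored token (e.g. via `huggingface-cli login`)
        from huggingface_hub import HfApi

        api = HfApi()
    # Create the model repo if it does not exist yet
    api.create_repo(repo_id, exist_ok=True)
    # Upload every file in the compiled pipeline directory to the repo root
    return api.upload_folder(folder_path=save_directory, repo_id=repo_id)


# Example (placeholder repo id):
# push_compiled_pipeline("sdxl_lightning_4_steps_neuronx/", "your-username/sdxl-lightning-4steps-neuronx")
```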

Inference

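```python
from optimum.neuron import NeuronStableDiffusionXLPipeline
from diffusers import EulerDiscreteScheduler

# Load the precompiled Neuron artifacts from the Hub (no recompilation needed)
pipe = NeuronStableDiffusionXLPipeline.from_pretrained("aws-neuron/SDXL-Lightning-4steps-neuronx")
# SDXL-Lightning is sampled with "trailing" timestep spacing
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# 4 steps to match the 4-step UNet; guidance_scale=0 disables classifier-free guidance
pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")
```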
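Swapping the scheduler matters because of where sampling starts. With only 4 steps, "leading" spacing never visits the noisiest timesteps: its first step is 751, not 999, even though sampling starts from pure noise. "Trailing" spacing starts at 999.

The sketch below mirrors, in plain Python, how diffusers' `EulerDiscreteScheduler` computes both spacings. It assumes 1000 training timesteps and the `steps_offset=1` used in the SDXL base scheduler config. The function names are hypothetical, and this is not the library's code.

```python
def trailing_timesteps(num_inference_steps: int, num_train_timesteps: int = 1000) -> list:
    """Count down from the last training timestep in equal strides."""
    step_ratio = num_train_timesteps / num_inference_steps
    return [round(num_train_timesteps - i * step_ratio) - 1 for i in range(num_inference_steps)]


def leading_timesteps(num_inference_steps: int, num_train_timesteps: int = 1000,
                      steps_offset: int = 1) -> list:
    """Count up from 0 in integer strides, reverse, then add the offset."""
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio + steps_offset for i in reversed(range(num_inference_steps))]


if __name__ == "__main__":
    print("trailing:", trailing_timesteps(4))  # [999, 749, 499, 249]
    print("leading: ", leading_timesteps(4))   # [751, 501, 251, 1]
```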