Learning to Discretize Denoising Diffusion ODEs

πŸ† ICLR2025 Oral

Paper on OpenReview

Implementation of LD3, a lightweight framework that learns the time discretization (the sequence of sampling timesteps) used to sample from pre-trained Diffusion Probabilistic Models (DPMs). LD3 can be combined with various samplers and consistently improves generation quality without retraining the resource-intensive diffusion network.


🔥 Latest News

  • March 2025: We have successfully applied LD3 to the Flux-dev model and observed promising results.
  • We are releasing the trained time steps for the Flux model soon! Stay tuned for updates.

Setup Environment

We will set up the environment using Anaconda.

conda env create -f requirements.yml
conda activate ld3
pip install -e ./src/clip/
pip install -e ./src/taming-transformers/
pip install omegaconf PyYAML requests scipy torchmetrics
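After installation, a quick check that the Python dependencies are importable can save a failed run later. This is a hypothetical helper, not part of the repository. It checks import names, which can differ from pip names (PyYAML is imported as yaml), and the names clip and taming for the two editable installs are assumptions.

```python
import importlib.util

# Import names (not pip names) for the dependencies installed above.
# "clip" and "taming" are assumed import names for the two editable installs.
REQUIRED = ["clip", "taming", "omegaconf", "yaml", "requests", "scipy", "torchmetrics"]


def missing_packages(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("All dependencies found." if not missing else "Missing: " + ", ".join(missing))
```

Run it inside the activated ld3 environment; an empty result means every listed module was found.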

Download Pretrained Models and FID Reference Sets

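The script scripts/download_model.sh automatically downloads the pretrained models and FID reference sets; this may take some time. To skip a download, comment out the corresponding line in the script. The wget command fetches the COCO 2014 validation captions (captions_val2014.json), which serve as the prompt file for Stable Diffusion.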

bash scripts/download_model.sh
wget https://raw.githubusercontent.com/tylin/coco-caption/master/annotations/captions_val2014.json

🚀 Generating Training Data for LD3

Before training LD3, we first generate training data with the teacher solver using the gen_data.py script. The example below generates CIFAR-10 training data with the uni_pc solver, 20 sampling steps, and the EDM time discretization (--skip_type edm).
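Conceptually, each training example pairs an initial noise sample with the teacher's output for that noise. The toy sketch below illustrates this under loudly stated assumptions: the "model" is the exact denoiser for one-dimensional Gaussian data, the teacher is a 200-step Euler solver standing in for uni_pc, and the batching loop only mirrors --total_samples and --sampling_batch_size. All names are hypothetical; this is not how gen_data.py is implemented.

```python
import math
import random

SIGMA_MAX, SIGMA_MIN, DATA_STD = 80.0, 0.002, 0.5  # toy values (assumed)


def denoiser(x, sigma):
    # Exact E[x0 | x] for data x0 ~ N(0, DATA_STD^2) corrupted by noise N(0, sigma^2);
    # stands in for the pre-trained network in this toy.
    return DATA_STD**2 / (DATA_STD**2 + sigma**2) * x


def teacher_solve(x, steps=200):
    # Euler on the probability-flow ODE dx/dsigma = (x - D(x, sigma)) / sigma,
    # over a log-uniform sigma grid from SIGMA_MAX down to SIGMA_MIN.
    ratio = (SIGMA_MIN / SIGMA_MAX) ** (1.0 / steps)
    sigma = SIGMA_MAX
    for i in range(steps):
        sigma_next = SIGMA_MIN if i == steps - 1 else sigma * ratio
        x = x + (sigma_next - sigma) * (x - denoiser(x, sigma)) / sigma
        sigma = sigma_next
    return x


def generate_pairs(total_samples, sampling_batch_size, seed=0):
    """Return (noise, teacher_output) pairs, produced one batch at a time."""
    rng = random.Random(seed)
    pairs = []
    while len(pairs) < total_samples:
        batch = min(sampling_batch_size, total_samples - len(pairs))
        noises = [rng.gauss(0.0, SIGMA_MAX) for _ in range(batch)]
        pairs.extend((z, teacher_solve(z)) for z in noises)
    return pairs


if __name__ == "__main__":
    for noise, sample in generate_pairs(total_samples=4, sampling_batch_size=2):
        print(f"{noise:9.3f} -> {sample:.5f}")
```

Fixing the seed makes the pairs reproducible, which is what allows a student to be trained against the same teacher outputs later.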

📌 Example: Generating CIFAR-10 Training Data

CUDA_VISIBLE_DEVICES=0 python3 gen_data.py \
                    --all_config configs/cifar10.yml \
                    --total_samples 100 \
                    --sampling_batch_size 10 \
                    --steps 20 \
                    --solver_name uni_pc \
                    --skip_type edm \
                    --save_pt --save_png --data_dir train_data/train_data_cifar10 \
                    --low_gpu

📌 Key Arguments:

  • all_config: Path to the default configuration file (mandatory). If other arguments are not specified, their values will be taken from this file.
  • solver_name: Solver to use. Options include uni_pc, dpm_solver++, euler, and ipndm.
  • skip_type: Discretization method. Options include edm, time_uniform, and time_quadratic.
  • low_gpu: Enables the use of PyTorch's checkpoint feature to reduce GPU memory usage.
  • data_dir: Root directory for saving the generated data. The script will create a subdirectory within this path using the naming format ${solver_name}_NFE${steps}_${skip_type}.
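For reference, the three baseline discretizations can be sketched as below. These follow common conventions (time_uniform and time_quadratic as in DPM-Solver's time-step helper, edm as the Karras et al. rho-schedule over noise levels); the repository's own implementation, time ranges, and defaults may differ. Note that edm returns noise levels sigma, while the other two return times t.

```python
import math


def time_uniform(t_start, t_end, steps):
    """steps + 1 equally spaced times from t_start down to t_end."""
    return [t_start + (t_end - t_start) * i / steps for i in range(steps + 1)]


def time_quadratic(t_start, t_end, steps):
    """Uniform in sqrt(t), then squared, so points cluster near t_end (small t)."""
    a, b = math.sqrt(t_start), math.sqrt(t_end)
    return [(a + (b - a) * i / steps) ** 2 for i in range(steps + 1)]


def edm(sigma_max, sigma_min, steps, rho=7.0):
    """Karras et al. (2022) noise levels: uniform in sigma**(1/rho), then raised to rho."""
    a, b = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    return [(a + (b - a) * i / steps) ** rho for i in range(steps + 1)]


if __name__ == "__main__":
    print(time_uniform(1.0, 1e-3, 4))
    print(time_quadratic(1.0, 1e-3, 4))
    print(edm(80.0, 0.002, 10))  # EDM paper defaults: sigma_max=80, sigma_min=0.002, rho=7
```

Assuming the [1e-3, 1] time range that the custom timesteps in the FID example span, time_uniform(1.0, 1e-3, 4) gives [1.0, 0.75025, 0.5005, 0.25075, 0.001] up to floating-point rounding; LD3 replaces such a fixed grid with learned timesteps.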

📌 Example: Generating Stable Diffusion Training Data

For Stable Diffusion, you must additionally specify the prompt file and the number of prompts. Below is an example:

CUDA_VISIBLE_DEVICES=0 python3 gen_data.py \
                    --all_config configs/stable_diff_v1-4.yml \
                    --total_samples 100 \
                    --sampling_batch_size 2 \
                    --steps 6 \
                    --solver_name uni_pc \
                    --skip_type time_uniform \
                    --save_pt --save_png --data_dir train_data/train_data_stable_diff_v1-4 \
                    --low_gpu \
                    --num_prompts 5 --prompt_path captions_val2014.json
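The prompt file follows the COCO captions annotation format, where each entry of the top-level "annotations" list has a "caption" field. A minimal sketch of reading num_prompts prompts from it follows; load_prompts is a hypothetical helper, and taking the first N captions is an assumption (gen_data.py may select prompts differently).

```python
import json


def load_prompts(prompt_path, num_prompts):
    """Return the first `num_prompts` captions from a COCO-style captions JSON file."""
    with open(prompt_path, encoding="utf-8") as f:
        data = json.load(f)
    # Each annotation looks like {"image_id": ..., "id": ..., "caption": "..."}.
    captions = [ann["caption"].strip() for ann in data["annotations"]]
    if num_prompts > len(captions):
        raise ValueError(f"requested {num_prompts} prompts, file has {len(captions)}")
    return captions[:num_prompts]
```

For the command above, load_prompts("captions_val2014.json", 5) would return five caption strings.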

Training LD3

After generating the training data, you can train LD3 using the main.py script. Below is an example of training LD3 on CIFAR-10 with the following configurations:

  • Teacher: 20 sampling steps, uni_pc solver, and time-edm discretization.
  • Student: 10 sampling steps, dpm_solver++ solver, using the time discretization learned by LD3.

CUDA_VISIBLE_DEVICES=0 python3 main.py \
                    --all_config configs/cifar10.yml \
                    --data_dir train_data/train_data_cifar10/uni_pc_NFE20_edm \
                    --num_train 50 --num_valid 50 \
                    --main_train_batch_size 1 \
                    --main_valid_batch_size 10 \
                    --solver_name dpm_solver++ \
                    --training_rounds_v1 2 \
                    --training_rounds_v2 5 \
                    --steps 10 \
                    --log_path logs/logs_cifar10
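The core idea of this training step can be illustrated with a toy, stdlib-only sketch: treat the student's timesteps as learnable parameters and minimize the gap between the student's output and the teacher's output for the same initial noise. Everything here is an assumption for illustration: the "model" is the exact denoiser for one-dimensional Gaussian data (so the teacher's output is known in closed form), the student is a 4-step Euler solver, timesteps are parametrized by a softmax over logits in log-sigma space, and gradients are finite differences. It is not LD3's implementation; main.py trains on real teacher samples with a perceptual loss (the loss_type in the log directory name, e.g. LPIPS) and a bound parameter that this toy omits.

```python
import math

SIGMA_MAX, SIGMA_MIN, DATA_STD = 80.0, 0.002, 0.5  # toy values (assumed)


def teacher_ratio():
    # Exact ODE solution for data N(0, DATA_STD^2): x(sigma) ~ sqrt(DATA_STD^2 + sigma^2).
    return math.sqrt(DATA_STD**2 + SIGMA_MIN**2) / math.sqrt(DATA_STD**2 + SIGMA_MAX**2)


def sigmas_from_logits(logits):
    # Softmax weights split the log-sigma range into positive increments, so the
    # schedule is strictly decreasing with fixed endpoints SIGMA_MAX and SIGMA_MIN.
    m = max(logits)
    weights = [math.exp(z - m) for z in logits]
    total = sum(weights)
    span = math.log(SIGMA_MAX) - math.log(SIGMA_MIN)
    log_sigma, sigmas = math.log(SIGMA_MAX), [SIGMA_MAX]
    for w in weights:
        log_sigma -= span * w / total
        sigmas.append(math.exp(log_sigma))
    sigmas[-1] = SIGMA_MIN  # remove floating-point drift at the final endpoint
    return sigmas


def student_ratio(sigmas):
    # Euler steps on dx/dsigma = (x - D(x, sigma)) / sigma with the exact denoiser
    # D(x, sigma) = DATA_STD^2 / (DATA_STD^2 + sigma^2) * x; the map is linear in x.
    ratio = 1.0
    for s, s_next in zip(sigmas[:-1], sigmas[1:]):
        ratio *= 1.0 + (s_next - s) * s / (DATA_STD**2 + s**2)
    return ratio


def loss(logits):
    # Squared log-error between student and teacher outputs for the same noise.
    student = student_ratio(sigmas_from_logits(logits))
    return (math.log(student) - math.log(teacher_ratio())) ** 2


def learn_timesteps(steps=4, iters=200, eps=1e-6):
    logits = [0.0] * steps  # equal logits = uniform spacing in log-sigma
    for _ in range(iters):
        base = loss(logits)
        grad = []  # forward-difference gradient with respect to each logit
        for k in range(steps):
            bumped = list(logits)
            bumped[k] += eps
            grad.append((loss(bumped) - base) / eps)
        lr = 1.0  # backtracking: halve the step until the loss actually decreases
        while lr > 1e-8:
            trial = [z - lr * g for z, g in zip(logits, grad)]
            if loss(trial) < base:
                logits = trial
                break
            lr *= 0.5
    return logits


if __name__ == "__main__":
    learned = learn_timesteps()
    print("uniform log-sigma loss:", loss([0.0] * 4))
    print("learned loss:", loss(learned))
    print("learned sigmas:", [round(s, 4) for s in sigmas_from_logits(learned)])
```

Because the toy solver is linear in x, every noise sample yields the same loss; with a real network, the loss would instead be averaged over the training pairs produced by gen_data.py.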

Trained timesteps are available here and are still being updated.

📌 Key Arguments:

  • data_dir: The full path to the training data directory, including the ${solver_name}_NFE${steps}_${skip_type} subdirectory created by gen_data.py (unlike the root directory passed during data generation).
  • log_path: The root directory for saving logs and models. The script will create a subdirectory within this path using the naming format: ${solver_name}-N${steps}-b${bound}-${loss_type}-lr2${lr2}rv1${rv1}-rv2${rv2}, for example, uni_pc-N10-b0.03072-LPIPS-lr20.01rv12-rv25
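The two naming conventions can be reproduced with small helpers, which makes it easy to derive the --data_dir for main.py from the gen_data.py arguments. These helper names are hypothetical, and the mapping of bound, lr2, rv1, and rv2 to command-line options is not documented here, so the helper simply takes them as plain values.

```python
import os


def data_subdir(root, solver_name, steps, skip_type):
    # Directory created by gen_data.py: ${solver_name}_NFE${steps}_${skip_type}
    return os.path.join(root, f"{solver_name}_NFE{steps}_{skip_type}")


def log_subdir(root, solver_name, steps, bound, loss_type, lr2, rv1, rv2):
    # Directory created by main.py:
    # ${solver_name}-N${steps}-b${bound}-${loss_type}-lr2${lr2}rv1${rv1}-rv2${rv2}
    name = f"{solver_name}-N{steps}-b{bound}-{loss_type}-lr2{lr2}rv1{rv1}-rv2{rv2}"
    return os.path.join(root, name)


if __name__ == "__main__":
    print(data_subdir("train_data/train_data_cifar10", "uni_pc", 20, "edm"))
    print(log_subdir("logs/logs_cifar10", "uni_pc", 10, 0.03072, "LPIPS", 0.01, 2, 5))
```

On Linux, the first call yields train_data/train_data_cifar10/uni_pc_NFE20_edm, the --data_dir used in the training example, and the second reproduces the example log directory name.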

FID Evaluation

⚠️ Different FID Scores

It is important to note that FID (Fréchet Inception Distance) scores can vary significantly depending on the evaluation pipeline used. To ensure transparency and reproducibility, our framework provides a script, compute_fid.py, that supports FID evaluation for both EDM and Latent-Diffusion models.
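For reference, the standard FID definition fits a Gaussian to Inception features of the reference images (mean μ_r, covariance Σ_r) and of the generated images (mean μ_g, covariance Σ_g):

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Both statistics depend on how images are resized and quantized before feature extraction and on which Inception implementation is used, which is why different pipelines report different scores for the same samples.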

📌 How FID Evaluation Works

The compute_fid.py script is a streamlined version of gen_data.py with a few differences:

  • The --save_dir, --save_pt, and --save_png arguments are ignored, because the generated samples are passed directly to the FID computation without being saved.
  • The generated samples are automatically forwarded to the FID computation module for feature extraction.
  • Optionally, you can pass your own timesteps via --custom_ts_1 and --custom_ts_2 together with --skip_type custom. If --custom_ts_2 is not specified, it defaults to the value of --custom_ts_1.
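Handling of the timestep arguments can be sketched as follows. The function names are hypothetical, not the repository's; the fallback of custom_ts_2 to custom_ts_1 follows the description above, and the rule that N steps take N + 1 timesteps is inferred from the ipndm example, where --steps 4 comes with five values.

```python
import ast


def parse_custom_ts(text):
    """Parse a bracketed list such as "[1.0,0.5,1e-3]" into strictly decreasing floats."""
    values = [float(v) for v in ast.literal_eval(text)]
    if any(a <= b for a, b in zip(values, values[1:])):
        raise ValueError("timesteps must be strictly decreasing")
    return values


def resolve_custom_ts(custom_ts_1, custom_ts_2=None, steps=None):
    """Parse both timestep lists; custom_ts_2 falls back to custom_ts_1 when omitted."""
    ts1 = parse_custom_ts(custom_ts_1)
    ts2 = parse_custom_ts(custom_ts_2) if custom_ts_2 is not None else list(ts1)
    if steps is not None and not (len(ts1) == len(ts2) == steps + 1):
        raise ValueError(f"expected {steps + 1} timesteps for {steps} steps")
    return ts1, ts2
```

ast.literal_eval accepts only Python literals, so scientific notation such as 7.6668e-01 parses safely without evaluating arbitrary code.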

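📌 Example: Computing FID for Stable Diffusion

The first command evaluates the time_uniform baseline with the uni_pc solver; the second evaluates custom timesteps with the ipndm solver, selected via --skip_type custom.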

CUDA_VISIBLE_DEVICES=0 python3 compute_fid.py \
                    --all_config configs/stable_diff_v1-5.yml \
                    --total_samples 100 \
                    --sampling_batch_size 2 \
                    --steps 6 \
                    --solver_name uni_pc \
                    --skip_type time_uniform \
                    --low_gpu \
                    --num_prompts 5 --prompt_path captions_val2014.json

CUDA_VISIBLE_DEVICES=0 python3 compute_fid.py \
                    --all_config configs/stable_diff_v1-5.yml \
                    --total_samples 100 \
                    --sampling_batch_size 2 \
                    --steps 4 \
                    --solver_name ipndm \
                    --skip_type custom \
                    --custom_ts_1 "[1.0000e+00,7.6668e-01,4.8113e-01,1.8417e-01,1.0000e-03]" \
                    --custom_ts_2 "[1.0000e+00,7.6706e-01,4.8103e-01,1.8396e-01,1.0000e-03]" \
                    --low_gpu \
                    --num_prompts 5 --prompt_path captions_val2014.json

Citation

@inproceedings{tong2024learning,
  title     = {Learning to Discretize Denoising Diffusion ODEs},
  author    = {Tong, Vinh and Hoang, Trung-Dung and Liu, Anji and Van den Broeck, Guy and Niepert, Mathias},
  booktitle = {Proceedings of the 13th International Conference on Learning Representations},
  year      = {2025}
}
@article{tong2024learning,
  title={Learning to Discretize Denoising Diffusion ODEs},
  author={Tong, Vinh and Hoang, Trung-Dung and Liu, Anji and Broeck, Guy Van den and Niepert, Mathias},
  journal={arXiv preprint arXiv:2405.15506},
  year={2024}
}

License

MIT
