DEIS
Fast Sampling of Diffusion Models with Exponential Integrator.
Overview
The original paper can be found at https://arxiv.org/abs/2204.13902. The original implementation can be found at https://github.com/qsh-zh/deis.
DEISMultistepScheduler
class diffusers.DEISMultistepScheduler
( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = 'linear', trained_betas: typing.Optional[numpy.ndarray] = None, solver_order: int = 2, prediction_type: str = 'epsilon', thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = 'deis', solver_type: str = 'logrho', lower_order_final: bool = True )
Parameters
- num_train_timesteps (int): the number of diffusion steps used to train the model.
- beta_start (float): the starting beta value of inference.
- beta_end (float): the final beta value.
- beta_schedule (str): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from linear, scaled_linear, or squaredcos_cap_v2.
- trained_betas (np.ndarray, optional): option to pass an array of betas directly to the constructor to bypass beta_start, beta_end, etc.
- solver_order (int, default 2): the order of DEIS; can be 1, 2, or 3. We recommend solver_order=2 for guided sampling and solver_order=3 for unconditional sampling.
- prediction_type (str, default epsilon): indicates whether the model predicts the noise (epsilon), the data / x0 (sample), or the v-parameterization (v_prediction). One of epsilon, sample, or v_prediction.
- thresholding (bool, default False): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as Stable Diffusion).
- dynamic_thresholding_ratio (float, default 0.995): the ratio for the dynamic thresholding method. The default, 0.995, is the same as in Imagen (https://arxiv.org/abs/2205.11487).
- sample_max_value (float, default 1.0): the threshold value for dynamic thresholding. Valid only when thresholding=True.
- algorithm_type (str, default deis): the algorithm type for the solver. Currently only multistep DEIS is supported; other variants of DEIS may be added in the future.
- solver_type (str, default logrho): the solver type for DEIS. Currently the log-rho multistep DEIS is supported.
- lower_order_final (bool, default True): whether to use lower-order solvers in the final steps. Only takes effect for fewer than 15 inference steps. We empirically find that this trick stabilizes DEIS sampling for fewer than 15 steps, especially for 10 or fewer.
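Because a multistep solver of order k reuses the k most recent model outputs, the first steps cannot run at full order. The sketch below is our reading of how solver_order and lower_order_final combine into a per-step order; solver_orders is a hypothetical helper, not part of diffusers, and the 15-step cutoff is taken from the lower_order_final description above.

```python
from typing import List


def solver_orders(num_steps: int, solver_order: int = 2, lower_order_final: bool = True) -> List[int]:
    """Hypothetical helper: the DEIS order used at each of num_steps sampling steps."""
    orders: List[int] = []
    history = 0  # number of earlier model outputs already stored
    # Lowering the order near the end only applies to short runs (< 15 steps)
    short_run = lower_order_final and num_steps < 15
    for i in range(num_steps):
        is_last = short_run and i == num_steps - 1
        is_second_last = short_run and i == num_steps - 2
        if solver_order == 1 or history < 1 or is_last:
            order = 1  # first-order step (DDIM-equivalent)
        elif solver_order == 2 or history < 2 or is_second_last:
            order = 2
        else:
            order = 3
        orders.append(order)
        history = min(history + 1, solver_order)
    return orders


print(solver_orders(10, solver_order=3))  # [1, 2, 3, 3, 3, 3, 3, 3, 2, 1]
```

With 20 steps the final lowering is skipped, so only the warm-up steps at the start run below the requested order.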
DEIS (https://arxiv.org/abs/2204.13902) is a fast high-order solver for diffusion ODEs. We slightly modify the polynomial fitting formula to work in log-rho space instead of the linear t space used in the original DEIS paper. The modification yields closed-form coefficients for the exponential multistep update instead of relying on numerical integration. More variants of DEIS can be found at https://github.com/qsh-zh/deis.
Currently, we support the log-rho multistep DEIS. We recommend solver_order=2 or 3; solver_order=1 reduces to DDIM.
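Written out, the log-rho formulation looks roughly as follows. This is a sketch under assumed notation that this page does not define: the variance-preserving quantities alpha_t = sqrt(abar_t), sigma_t = sqrt(1 - abar_t), and rho_t = sigma_t / alpha_t, with m_j the noise prediction stored at timestep s_j (s_0 is the current timestep, t the one being stepped to).

```latex
% Probability-flow ODE in the variables x / \alpha and \rho = \sigma / \alpha
\frac{\mathrm{d}}{\mathrm{d}\rho}\left(\frac{x}{\alpha}\right) = \epsilon_\theta(x, \rho)
\quad\Longrightarrow\quad
\frac{x_t}{\alpha_t} = \frac{x_{s_0}}{\alpha_{s_0}} + \int_{\rho_{s_0}}^{\rho_t} \epsilon_\theta(x, \rho)\,\mathrm{d}\rho

% Fit \epsilon_\theta with a Lagrange polynomial in \log\rho through the last k outputs m_0, \dots, m_{k-1}
\epsilon_\theta(x, \rho) \approx \sum_{j=0}^{k-1} m_j\,\ell_j(\log\rho),
\qquad
\ell_j(u) = \prod_{i \ne j} \frac{u - \log\rho_{s_i}}{\log\rho_{s_j} - \log\rho_{s_i}}

% Each C_j only integrates powers of \log\rho, so it has a closed form
x_t = \alpha_t\left(\frac{x_{s_0}}{\alpha_{s_0}} + \sum_{j=0}^{k-1} C_j\, m_j\right),
\qquad
C_j = \int_{\rho_{s_0}}^{\rho_t} \ell_j(\log\rho)\,\mathrm{d}\rho

% k = 1: \ell_0 = 1 and C_0 = \rho_t - \rho_{s_0}, which recovers the DDIM step
x_t = \frac{\alpha_t}{\alpha_{s_0}}\,x_{s_0} + \left(\sigma_t - \frac{\alpha_t\,\sigma_{s_0}}{\alpha_{s_0}}\right) m_0
```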
We also support the "dynamic thresholding" method from Imagen (https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set thresholding=True to use dynamic thresholding.
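A minimal NumPy sketch of dynamic thresholding as the Imagen paper describes it: take a per-sample percentile s of |x0|, floor it at 1, clip x0 to [-s, s], and divide by s. The helper name is hypothetical, and the sketch deliberately does not model how sample_max_value enters the scheduler's own computation.

```python
import numpy as np


def dynamic_threshold(x0: np.ndarray, ratio: float = 0.995) -> np.ndarray:
    """Hypothetical helper: Imagen-style dynamic thresholding of a predicted x0 batch."""
    batch = x0.shape[0]
    # Per-sample percentile of absolute values over all non-batch dimensions
    s = np.quantile(np.abs(x0.reshape(batch, -1)), ratio, axis=1)
    # Never shrink the range below [-1, 1]
    s = np.maximum(s, 1.0)
    # Reshape to broadcast against x0, e.g. (batch, 1, 1, 1) for images
    s = s.reshape((batch,) + (1,) * (x0.ndim - 1))
    return np.clip(x0, -s, s) / s
```

Predictions already inside [-1, 1] keep s = 1 and pass through unchanged; saturated predictions are rescaled back into [-1, 1], which is why this only makes sense for pixel-space models whose data lives in that range.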
ConfigMixin takes care of storing all config attributes that are passed in the scheduler's __init__ function, such as num_train_timesteps. They can be accessed via scheduler.config.num_train_timesteps.
SchedulerMixin provides general loading and saving functionality via the SchedulerMixin.save_pretrained() and from_pretrained() functions.
convert_model_output
( model_output: FloatTensor, timestep: int, sample: FloatTensor ) -> torch.FloatTensor
Parameters
- model_output (torch.FloatTensor): direct output from the learned diffusion model.
- timestep (int): current discrete timestep in the diffusion chain.
- sample (torch.FloatTensor): current instance of a sample being created by the diffusion process.
Returns
torch.FloatTensor: the converted model output.
Convert the model output to the corresponding type that the DEIS algorithm needs.
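As we understand it, for algorithm_type='deis' the conversion goes through a data (x0) prediction and returns an equivalent noise prediction, which is what the multistep updates integrate. A NumPy sketch under that assumption, with alpha_t = sqrt(abar_t) and sigma_t = sqrt(1 - abar_t); to_deis_noise is a hypothetical name, and thresholding is left out.

```python
import numpy as np


def to_deis_noise(model_output, sample, alpha_t, sigma_t, prediction_type="epsilon"):
    """Hypothetical helper: map a model output to the noise prediction DEIS integrates."""
    # Step 1: recover the predicted clean data x0 from the model output
    if prediction_type == "epsilon":
        x0_pred = (sample - sigma_t * model_output) / alpha_t
    elif prediction_type == "sample":
        x0_pred = model_output
    elif prediction_type == "v_prediction":
        x0_pred = alpha_t * sample - sigma_t * model_output
    else:
        raise ValueError(f"unknown prediction_type: {prediction_type}")
    # (Dynamic thresholding of x0_pred would be applied here when enabled.)
    # Step 2: express x0_pred as the equivalent noise prediction
    return (sample - alpha_t * x0_pred) / sigma_t
```

Since alpha_t**2 + sigma_t**2 == 1, the three parameterizations of the same noisy sample all map back to the same noise.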
deis_first_order_update
( model_output: FloatTensor, timestep: int, prev_timestep: int, sample: FloatTensor ) -> torch.FloatTensor
Parameters
- model_output (torch.FloatTensor): direct output from the learned diffusion model.
- timestep (int): current discrete timestep in the diffusion chain.
- prev_timestep (int): previous discrete timestep in the diffusion chain.
- sample (torch.FloatTensor): current instance of a sample being created by the diffusion process.
Returns
torch.FloatTensor: the sample tensor at the previous timestep.
One step for the first-order DEIS (equivalent to DDIM).
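In the log-SNR variable lambda = log(alpha / sigma), the first-order step can be written in exponential-integrator form. This NumPy sketch (helper name hypothetical; the 'linear' schedule is assumed to be betas = linspace(beta_start, beta_end, num_train_timesteps)) checks it against the DDIM update x_t = alpha_t * x0 + sigma_t * eps.

```python
import numpy as np


def deis_first_order(sample, eps, alpha_s, sigma_s, alpha_t, sigma_t):
    """Hypothetical helper: one first-order DEIS step from timestep s to timestep t."""
    # h is the increase in log-SNR lambda = log(alpha / sigma) over the step
    h = np.log(alpha_t / sigma_t) - np.log(alpha_s / sigma_s)
    return (alpha_t / alpha_s) * sample - sigma_t * np.expm1(h) * eps


# Assumed linear schedule with the constructor defaults (0.0001 to 0.02 over 1000 steps)
betas = np.linspace(0.0001, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)
alpha, sigma = np.sqrt(alphas_cumprod), np.sqrt(1.0 - alphas_cumprod)

s, t = 999, 899
rng = np.random.default_rng(0)
x_s, eps = rng.normal(size=8), rng.normal(size=8)

x_t = deis_first_order(x_s, eps, alpha[s], sigma[s], alpha[t], sigma[t])
# DDIM: predict x0, then re-noise it to level t with the same eps
x0 = (x_s - sigma[s] * eps) / alpha[s]
x_t_ddim = alpha[t] * x0 + sigma[t] * eps
```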
multistep_deis_second_order_update
( model_output_list: typing.List[torch.FloatTensor], timestep_list: typing.List[int], prev_timestep: int, sample: FloatTensor ) -> torch.FloatTensor
Parameters
- model_output_list (List[torch.FloatTensor]): direct outputs from the learned diffusion model at the current and later timesteps.
- timestep_list (List[int]): the current and later discrete timesteps in the diffusion chain.
- prev_timestep (int): previous discrete timestep in the diffusion chain.
- sample (torch.FloatTensor): current instance of a sample being created by the diffusion process.
Returns
torch.FloatTensor: the sample tensor at the previous timestep.
One step for the second-order multistep DEIS.
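With the noise prediction fitted by a polynomial in log rho, as described in the overview, the second-order coefficients have a closed form. A NumPy sketch, assuming rho = sigma / alpha, m0 the newest noise prediction at rho_s0 and m1 the one before it at rho_s1; all helper names are hypothetical.

```python
import numpy as np


def _antideriv2(r, b, c):
    # Antiderivative in r of (log r - log c) / (log b - log c)
    return r * (np.log(r) - np.log(c) - 1.0) / (np.log(b) - np.log(c))


def deis2_coefficients(rho_t, rho_s0, rho_s1):
    """Integrals over [rho_s0, rho_t] of the two Lagrange basis polynomials in log rho."""
    c0 = _antideriv2(rho_t, rho_s0, rho_s1) - _antideriv2(rho_s0, rho_s0, rho_s1)
    c1 = _antideriv2(rho_t, rho_s1, rho_s0) - _antideriv2(rho_s0, rho_s1, rho_s0)
    return c0, c1


def deis2_update(sample, m0, m1, alpha_t, alpha_s0, rho_t, rho_s0, rho_s1):
    """Hypothetical helper: second-order multistep DEIS step from s0 to t."""
    c0, c1 = deis2_coefficients(rho_t, rho_s0, rho_s1)
    return alpha_t * (sample / alpha_s0 + c0 * m0 + c1 * m1)
```

Two sanity checks follow from the construction: the coefficients sum to rho_t - rho_s0, so m0 == m1 reduces the step to the first-order (DDIM) update, and a noise prediction that is exactly linear in log rho is integrated without error.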
multistep_deis_third_order_update
( model_output_list: typing.List[torch.FloatTensor], timestep_list: typing.List[int], prev_timestep: int, sample: FloatTensor ) -> torch.FloatTensor
Parameters
- model_output_list (List[torch.FloatTensor]): direct outputs from the learned diffusion model at the current and later timesteps.
- timestep_list (List[int]): the current and later discrete timesteps in the diffusion chain.
- prev_timestep (int): previous discrete timestep in the diffusion chain.
- sample (torch.FloatTensor): current instance of a sample being created by the diffusion process.
Returns
torch.FloatTensor: the sample tensor at the previous timestep.
One step for the third-order multistep DEIS.
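The third-order step uses the same log-rho fit through three stored outputs, so each coefficient integrates a quadratic in log rho, which is still elementary. A NumPy sketch under the same assumptions (rho = sigma / alpha; m0, m1, m2 from newest to oldest at rho_s0, rho_s1, rho_s2; helper names hypothetical):

```python
import numpy as np


def _antideriv3(r, b, c, d):
    # Antiderivative in r of (log r - log c)(log r - log d) / ((log b - log c)(log b - log d))
    lr, lb, lc, ld = np.log(r), np.log(b), np.log(c), np.log(d)
    numerator = r * (lr**2 - 2.0 * lr + 2.0 - (lc + ld) * (lr - 1.0) + lc * ld)
    return numerator / ((lb - lc) * (lb - ld))


def deis3_coefficients(rho_t, rho_s0, rho_s1, rho_s2):
    """Integrals over [rho_s0, rho_t] of the three Lagrange basis polynomials in log rho."""
    nodes = (rho_s0, rho_s1, rho_s2)
    coefs = []
    for j, b in enumerate(nodes):
        # The other two nodes define the zeros of basis polynomial j
        c, d = [n for i, n in enumerate(nodes) if i != j]
        coefs.append(_antideriv3(rho_t, b, c, d) - _antideriv3(rho_s0, b, c, d))
    return tuple(coefs)


def deis3_update(sample, m0, m1, m2, alpha_t, alpha_s0, rho_t, rho_s0, rho_s1, rho_s2):
    """Hypothetical helper: third-order multistep DEIS step from s0 to t."""
    c0, c1, c2 = deis3_coefficients(rho_t, rho_s0, rho_s1, rho_s2)
    return alpha_t * (sample / alpha_s0 + c0 * m0 + c1 * m1 + c2 * m2)
```

Because three points determine a quadratic, a noise prediction that is quadratic in log rho is integrated exactly; this is the property that makes the closed form preferable to numerical quadrature.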
scale_model_input
( sample: FloatTensor, *args, **kwargs ) -> torch.FloatTensor
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep.
set_timesteps
( num_inference_steps: int, device: typing.Union[str, torch.device] = None )
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
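One plausible spacing, sketched with NumPy: evenly spaced, rounded integer timesteps in descending order, with the trailing 0 dropped on the assumption that the final update steps down to timestep 0 on its own. This is an assumption about what set_timesteps produces, not a copy of it; the helper name is hypothetical, and the spacing may differ between diffusers versions.

```python
import numpy as np


def descending_timesteps(num_inference_steps: int, num_train_timesteps: int = 1000) -> np.ndarray:
    """Hypothetical helper: evenly spaced integer timesteps, largest (noisiest) first."""
    # num_inference_steps + 1 points covering [0, num_train_timesteps - 1]
    grid = np.linspace(0, num_train_timesteps - 1, num_inference_steps + 1).round()
    # Reverse to run from noisy to clean, and drop the trailing 0
    return grid[::-1][:-1].astype(np.int64)


timesteps = descending_timesteps(10)
```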
step
( model_output: FloatTensor, timestep: int, sample: FloatTensor, return_dict: bool = True ) -> scheduling_utils.SchedulerOutput or tuple
Parameters
- model_output (torch.FloatTensor): direct output from the learned diffusion model.
- timestep (int): current discrete timestep in the diffusion chain.
- sample (torch.FloatTensor): current instance of a sample being created by the diffusion process.
- return_dict (bool): whether to return a SchedulerOutput (True) or a plain tuple (False).
Returns
scheduling_utils.SchedulerOutput or tuple: a SchedulerOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is the sample tensor.
Step function propagating the sample with the multistep DEIS.