About finetuning current SDXL weights with the EQ-SDXL-VAE
You said in the intro: "You can try to use this VAE to finetune your sdxl model and expect a better final result, but it may require lot of time to achieve it...". I am still very interested in utilizing the existing model weights. So my question is: how much is "a lot"? I have ~500k samples; how many iterations are required to align the UNet of SDXL with the new latent space?
A lot of training time.
ALTHOUGH some reported results say that "a few k steps with a small LoRA works well".
Your setup is definitely OK.
I just thought a dataset like LAION-400M was needed. In the end it turns out that something on the scale of a few thousand samples is said to work.
My thought was something like Danbooru (8M) or CC12M.
And yes, I'm also surprised that a few k, or just a few dozen k, is enough.
I spent a night on a quick try, finetuning a LoRA for about 48k iterations, and got a very poor result, so I suspect something is wrong in my finetuning process. Do I need to modify my training script with respect to the VAE? I ask because I notice there are some parameters in the config that are not used by the original VAE, such as:

```json
"shift_factor": 0.8640247167934477,
```
In my training script, the VAE encoding part goes like this:

```python
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
model_input = model_input.to(weight_dtype)
```
Should I change the code to:

```python
model_input = model_input * vae.config.scaling_factor + vae.config.shift_factor
```

?
By the way, I use StableDiffusionXLPipeline from diffusers for inference.
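Roughly like this (a minimal sketch; the paths are placeholders for my actual EQ-SDXL-VAE checkpoint and base model):

```python
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

# Placeholder paths; swap in the actual EQ-SDXL-VAE checkpoint and SDXL base.
vae = AutoencoderKL.from_pretrained("path/to/EQ-SDXL-VAE", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")
```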
I would strongly recommend the following:

Encode:
```python
latent = vae.encode(pixel).latent_dist.sample()
# Per-channel stats from the VAE config, moved to the latent's device/dtype.
latents_mean = torch.tensor(vae.config.latents_mean)[None, :, None, None].to(latent.device, latent.dtype)
latents_std = torch.tensor(vae.config.latents_std)[None, :, None, None].to(latent.device, latent.dtype)
std_latent = (latent - latents_mean) / latents_std
model_input = std_latent.to(weight_dtype)
```
Decode:
```python
# Undo the channel-wise normalization, then decode; the decoder output is in
# [-1, 1], so map it back to [0, 1].
latent = model_output * latents_std + latents_mean
pixel = vae.decode(latent).sample * 0.5 + 0.5
```
To utilize this in the SDXL pipeline you may need to modify the pipeline source code; if you don't want to do that, just finetune with "scale" + "shift" only, which follows the pipeline implementation.
Your "should I change the code into..." snippet is correct if you only want to modify the trainer code.
I checked the SDXL pipeline of diffusers v0.32.2 https://github.com/huggingface/diffusers/blob/560fb5f4d65b8593c13e4be50a59b1fd9c2d9992/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1268-L1281
and found out that this version of the pipeline has taken the mean and std into consideration:
```python
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
    latents_mean = (
        torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
    )
    latents_std = (
        torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
    )
    latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
else:
    latents = latents / self.vae.config.scaling_factor
```
But the calculation still differs noticeably from the code you provided. For example, the pipeline code still unscales the latents by `scaling_factor` regardless of whether the mean and std are used. So I trained with `model_input = model_input * vae.config.scaling_factor` but inferenced with `latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean`; no wonder it finally led to a poor result.
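So to stay consistent with that pipeline decode, I believe the trainer-side encode has to be its exact inverse. A sketch of what I should have trained with (using the names from my script above):

```python
import torch

latent = vae.encode(pixel_values).latent_dist.sample()
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1).to(latent.device, latent.dtype)
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1).to(latent.device, latent.dtype)
# Exact inverse of the pipeline's
#   latents = latents * latents_std / scaling_factor + latents_mean
model_input = (latent - latents_mean) * vae.config.scaling_factor / latents_std
model_input = model_input.to(weight_dtype)
```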
According to my understanding, the difference between `model_input = model_input * vae.config.scaling_factor` and `latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean` is that the former applies the same scale and shift to every channel of the latent, while the latter applies a channel-wise mean and std, which is possibly better for normalization. Is that the purpose, as I thought?
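To illustrate what I mean by global vs. channel-wise (a toy sketch with made-up statistics, just to show the shapes):

```python
import torch

latent = torch.randn(1, 4, 128, 128)  # dummy SDXL-shaped latent

# Global scaling: one scalar applied identically to all 4 channels.
global_norm = latent * 0.13025  # SDXL's original scaling_factor

# Channel-wise: each channel gets its own mean/std (made-up numbers here),
# so channels with different statistics all end up roughly standardized.
latents_mean = torch.tensor([0.1, -0.2, 0.3, 0.0]).view(1, 4, 1, 1)
latents_std = torch.tensor([1.5, 0.9, 1.1, 1.3]).view(1, 4, 1, 1)
channel_norm = (latent - latents_mean) / latents_std
```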