Explaining the SDXL latent space
TL;DR
or check out the interactive demonstration
Table of Contents
A short background story
The 4 channels of the SDXL latents
The 8-bit pixel space has 3 channels
The SDXL latent representation of an image has 4 channels
Direct conversion of SDXL latents to RGB with a linear approximation
A probable reason why the SDXL color range is biased towards yellow
What needs correcting?
Let's take an example output from SDXL
A complete demonstration
Increasing color range / removing color bias
Long prompts at high guidance scales becoming possible
Introductory note
I've gotten some strange questions after this article was scraped by other sites and reworded. If you're reading this anywhere other than Hugging Face, here is the original article: https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space
A short background story
Special thanks to: Ollin Boer Bohan, Haoming, Cristina Segalin and Birchlabs for helping with information, discussion and knowledge!
I was creating correction filters for the SDXL inference process, for a UI I'm building for diffusion models.
After many years of experience with image correction, I wanted the fundamental capability to improve the actual output from SDXL. There were many techniques I wanted available in the UX, so I set out to build them myself. I noticed that SDXL output is almost always either noisy in regular patterns or overly smooth. The colors always needed white balancing, with a biased and restricted color range, simply because of how SD models work.
Making corrections in a post process after the image is generated and converted to 8-bit RGB made very little sense, if it was possible to improve the information and color range before the actual output.
The most important thing to know in order to create filters and correction tools is to understand the data you are working with.
This led me to an experimental exploration of the SDXL latents with the intention of understanding them.
The tensor, which the diffusion models based on the SDXL architecture work with, looks like this:
[batch_size, 4 channels, height (y), width (x)]
My first question was simply "What exactly are these 4 channels?". Most answers I received were along the lines of "It's not something that a human can understand."
But it is most definitely understandable. It's even very easy to understand and useful to know.
The 4 channels of the SDXL latents
For a 1024×1024px image generated by SDXL, the latents tensor is 128×128px, where every pixel in the latent space represents 64 (8×8) pixels in the pixel space. If we generate and decode the latents into a standard 8-bit jpg image, then...
The 8-bit pixel space has 3 channels
Red (R), Green (G) and Blue (B), each with 256 possible values ranging between 0-255. So, to store the full information of 64 pixels, we need to be able to store 64×256 = 16,384 values, per channel, in every latent pixel.
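The arithmetic above can be checked with a small sketch. The helper name `latent_shape` and its `downscale` parameter are illustrative assumptions, not part of any library; the factor of 8 and the 4 channels come from the text.

```python
def latent_shape(batch_size, height, width, channels=4, downscale=8):
    """Return the latent tensor shape for an image of the given pixel size."""
    if height % downscale or width % downscale:
        raise ValueError("height and width must be multiples of the downscale factor")
    return (batch_size, channels, height // downscale, width // downscale)


shape = latent_shape(1, 1024, 1024)
pixels_per_latent = 8 * 8                     # image pixels covered by one latent pixel
values_per_channel = pixels_per_latent * 256  # the 64×256 figure used above

print(shape)               # (1, 4, 128, 128)
print(pixels_per_latent)   # 64
print(values_per_channel)  # 16384
```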
The SDXL latent representation of an image has 4 channels
Click the heading for an interactive demo!
0: Luminance
1: Cyan/Red => equivalent to rgb(0, 255, 255)/rgb(255, 0, 0)
2: Lime/Medium Purple => equivalent to rgb(127, 255, 0)/rgb(127, 0, 255)
3: Pattern/structure.
If each value can range between -4 and 4 at the point of decoding, then in a 16-bit floating point format with half precision, each latent pixel can contain 16,384 distinct values for each of the 4 channels.
Direct conversion of SDXL latents to RGB with a linear approximation
With this understanding, we can create an approximation function which directly converts the latents to RGB:
import torch
from PIL import Image

def latents_to_rgb(latents):
    # One row per output channel (R, G, B), one column per latent channel (0-3)
    weights = (
        (60, -60, 25, -70),
        (60,  -5, 15, -50),
        (60,  10, -5, -35)
    )

    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions: (C, H, W) -> (H, W, C)

    return Image.fromarray(image_array)
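To see what each latent channel contributes, it may help to apply the same weights and biases to a single latent pixel by hand. The following is a minimal pure-Python sketch; the helper name `latent_pixel_to_rgb` is an illustrative assumption, while the numbers are copied from latents_to_rgb.

```python
# Weights and biases copied from latents_to_rgb:
# one row per output channel (R, G, B), one column per latent channel (0-3).
WEIGHTS = (
    (60, -60, 25, -70),
    (60, -5, 15, -50),
    (60, 10, -5, -35),
)
BIASES = (150, 140, 130)


def latent_pixel_to_rgb(latent_pixel):
    """Apply the linear approximation to one latent pixel (4 numbers)."""
    rgb = []
    for row, bias in zip(WEIGHTS, BIASES):
        value = sum(w * x for w, x in zip(row, latent_pixel)) + bias
        rgb.append(int(min(255, max(0, value))))  # clamp to the 8-bit range
    return tuple(rgb)


print(latent_pixel_to_rgb((0, 0, 0, 0)))   # (150, 140, 130): the bias alone
print(latent_pixel_to_rgb((1, 0, 0, 0)))   # (210, 200, 190): brighter
print(latent_pixel_to_rgb((0, 1, 0, 0)))   # (90, 135, 140): shifted towards cyan
print(latent_pixel_to_rgb((0, -1, 0, 0)))  # (210, 145, 120): shifted towards red
```

Raising channel 0 brightens all three outputs equally, while channel 1 mostly trades red against the other two, which matches the luminance and cyan/red descriptions given earlier.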
Here we have the latents_to_rgb result and a regular decoded output, resized for comparison:
A probable reason why the SDXL color range is biased towards yellow
Relatively few things in nature are blue, or white. These colors are most prominent in the sky, during enjoyable conditions. So, the model, knowing reality through images, thinks in luminance (channel 0), cyan/red (channel 1) and lime/medium purple (channel 2), where red and green are primary and blue is secondary. This is why SDXL generations are very often biased towards yellow (red + green).
During inference, the values in the tensor will begin at min < -30 and max > 30, and the min/max boundary at time of decoding is around -4 to 4. At higher guidance_scale the values will have a higher difference between min and max.
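One way to observe these ranges yourself is to log per-channel statistics at every denoising step. This is a minimal sketch: `latent_stats` and `log_stats_callback` are hypothetical helper names, and the callback signature is assumed to be the same `callback_on_step_end` form used in the callback example in this article.

```python
import torch


def latent_stats(latents):
    """Return per-channel min, max and mean of a [batch, channels, h, w] tensor."""
    # Bring channels to the front and flatten everything else: [channels, batch*h*w]
    per_channel = latents.float().transpose(0, 1).flatten(start_dim=1)
    return {
        "min": per_channel.min(dim=1).values.tolist(),
        "max": per_channel.max(dim=1).values.tolist(),
        "mean": per_channel.mean(dim=1).tolist(),
    }


def log_stats_callback(pipe, step_index, timestep, callback_kwargs):
    """Hypothetical callback: print the statistics after every denoising step."""
    stats = latent_stats(callback_kwargs["latents"])
    print(f"step {step_index:3d}  t={int(timestep):4d}  {stats}")
    return callback_kwargs
```

Passing `log_stats_callback` as `callback_on_step_end` with `callback_on_step_end_inputs=["latents"]` should print one line per step, which makes it easy to see how the range narrows as denoising proceeds.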
One key in understanding the boundary is to look at what happens in the decoding process:
decoded = vae.decode(latents / vae.config.scaling_factor).sample # (SDXL scaling_factor = 0.13025)
decoded = decoded.div(2).add(0.5).clamp(0, 1) # The dynamics outside of 0 to 1 at this point will be lost
If the values at this point are outside of the range 0 to 1, some information will be lost in the clamp. So if we can make corrections during denoising to serve the VAE what it expects, we may get better results.
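To get a feel for how much the clamp discards, one could measure the fraction of values that fall outside 0 to 1 after the rescaling. This is a minimal sketch; `clipped_fraction` is a hypothetical helper, and the sample tensor is a stand-in for real VAE output.

```python
import torch


def clipped_fraction(decoded):
    """Fraction of values that the final clamp(0, 1) would alter."""
    scaled = decoded.float().div(2).add(0.5)  # same rescaling as in the decoding step
    outside = (scaled < 0) | (scaled > 1)
    return outside.float().mean().item()


sample = torch.tensor([-2.0, 0.0, 1.0, 2.0])  # stand-in for VAE output
print(clipped_fraction(sample))  # 0.5
```

Here -2.0 and 2.0 map to -0.5 and 1.5, so half of the values would be altered by the clamp.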
What needs correcting?
How do you sharpen a blurry image, white balance, improve detail, increase contrast or increase the color range? The best way is to begin with a sharp image, which is correctly white balanced with great contrast, crisp details and a high range.
It's far easier to blur a sharp image, shift the color balance, reduce contrast, get nonsensical details and limit the color range than to improve it.
SDXL has a very prominent tendency to bias the colors and to put values outside of the actual boundaries (left image). This is easily solved by centering the values and bringing them within the boundaries (right image):
def center_tensor(input_tensor, per_channel_shift=1, full_tensor_shift=1, channels=[0, 1, 2, 3]):
    for channel in channels:
        input_tensor[0, channel] -= input_tensor[0, channel].mean() * per_channel_shift
    return input_tensor - input_tensor.mean() * full_tensor_shift
Let's take an example output from SDXL
seed: 77777777
guidance_scale: 20 # A high guidance scale can be fixed too
steps with base: 23
steps with refiner: 10
prompt: Cinematic.Beautiful smile action woman in detailed white mecha gundam armor with red details,green details,blue details,colorful,star wars universe,lush garden,flowers,volumetric lighting,perfect eyes,perfect teeth,blue sky,bright,intricate details,extreme detail of environment,infinite focus,well lit,interesting clothes,radial gradient fade,directional particle lighting,wow
negative_prompt: helmet, bokeh, painting, artwork, blocky, blur, ugly, old, boring, photoshopped, tired, wrinkles, scar, gray hair, big forehead, crosseyed, dumb, stupid, cockeyed, disfigured, crooked, blurry, unrealistic, grayscale, bad anatomy, unnatural irises, no pupils, blurry eyes, dark eyes, extra limbs, deformed, disfigured eyes, out of frame, no irises, assymetrical face, broken fingers, extra fingers, disfigured hands
Notice that I've purposely chosen a high guidance scale.
How can we fix this image? It's half painting, half photograph. The color range is biased towards yellow. To the right is a fixed generation with the exact same settings.
Even with a sensible guidance_scale of 7.5, we can still conclude that the fixed output is better, without nonsensical details and with correct white balance.
There are many things we can do in the latent space to generally improve a generation and there are some very simple things which we can do to target specific errors in a generation:
Outlier removal
This will control the amount of nonsensical details, by pruning values that are the farthest from the mean of the distribution. It also helps in generating at higher guidance_scale.
# Shrinking towards the mean (will also remove outliers)
def soft_clamp_tensor(input_tensor, threshold=3.5, boundary=4):
    if max(abs(input_tensor.max()), abs(input_tensor.min())) < 4:
        return input_tensor
    channel_dim = 1

    max_vals = input_tensor.max(channel_dim, keepdim=True)[0]
    max_replace = ((input_tensor - threshold) / (max_vals - threshold)) * (boundary - threshold) + threshold
    over_mask = (input_tensor > threshold)

    min_vals = input_tensor.min(channel_dim, keepdim=True)[0]
    min_replace = ((input_tensor + threshold) / (min_vals + threshold)) * (-boundary + threshold) - threshold
    under_mask = (input_tensor < -threshold)

    return torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, input_tensor))
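The mapping inside soft_clamp_tensor is easier to follow on single numbers. Below is a scalar sketch of the same formulas; `soft_clamp_value` is a hypothetical helper, `max_val` and `min_val` stand in for the maximum and minimum that the tensor version computes along dimension 1, and the early return for tensors already inside the range is left out.

```python
def soft_clamp_value(value, max_val, min_val, threshold=3.5, boundary=4.0):
    """Scalar version of the mapping used in soft_clamp_tensor."""
    if value > threshold:
        # Rescale [threshold, max_val] onto [threshold, boundary]
        return (value - threshold) / (max_val - threshold) * (boundary - threshold) + threshold
    if value < -threshold:
        # Rescale [min_val, -threshold] onto [-boundary, -threshold]
        return (value + threshold) / (min_val + threshold) * (-boundary + threshold) - threshold
    return value  # values inside the threshold are left unchanged


print(soft_clamp_value(6.0, max_val=6.0, min_val=-6.0))   # 4.0
print(soft_clamp_value(4.75, max_val=6.0, min_val=-6.0))  # 3.75
print(soft_clamp_value(-6.0, max_val=6.0, min_val=-6.0))  # -4.0
print(soft_clamp_value(1.0, max_val=6.0, min_val=-6.0))   # 1.0
```

The extreme values land exactly on the boundary, values between the threshold and the extreme are compressed proportionally, and everything else passes through untouched.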
Color balancing and increased range
I have two main methods of achieving this. The first is to shrink towards the mean while normalizing the values (which will also remove outliers), and the second is to correct the values when they get biased towards some color. This also helps in generating at higher guidance_scale.
# Center tensor (balance colors)
def center_tensor(input_tensor, channel_shift=1, full_shift=1, channels=[0, 1, 2, 3]):
    for channel in channels:
        input_tensor[0, channel] -= input_tensor[0, channel].mean() * channel_shift
    return input_tensor - input_tensor.mean() * full_shift
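Note that center_tensor modifies its input in place and only touches the first image in the batch. If that is not what you want, a non-mutating variant could look like the sketch below. The name `centered` is hypothetical, it always processes all channels, and it computes the full-tensor mean per image rather than across the whole batch, which is an assumption about the desired behavior.

```python
import torch


def centered(latents, channel_shift=1.0, full_shift=1.0):
    """Return a centered copy of a [batch, channels, h, w] tensor."""
    channel_means = latents.mean(dim=(2, 3), keepdim=True)  # one mean per image and channel
    out = latents - channel_means * channel_shift
    return out - out.mean(dim=(1, 2, 3), keepdim=True) * full_shift


# Random latents with a deliberate per-channel offset (a color bias)
offsets = torch.tensor([0.5, -1.0, 2.0, 0.0]).view(1, 4, 1, 1)
latents = torch.randn(2, 4, 16, 16) + offsets
balanced = centered(latents)
print(balanced.mean(dim=(2, 3)))  # every entry is close to 0
```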
Tensor maximizing
This is basically done by multiplying the tensors by a very small amount like 1e-5 for a few steps and making sure that the final tensor is using the full possible range (closer to -4/4) before converting to RGB. Remember, in the pixel space, it's easier to reduce contrast, saturation and sharpness with intact dynamics than to increase them.
# Maximize/normalize tensor
def maximize_tensor(input_tensor, boundary=4, channels=[0, 1, 2]):
    min_val = input_tensor.min()
    max_val = input_tensor.max()

    normalization_factor = boundary / max(abs(min_val), abs(max_val))
    input_tensor[0, channels] *= normalization_factor

    return input_tensor
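As a quick check of what the normalization does, here is a non-mutating sketch of the same idea applied to every image in the batch. The name `scaled_to_boundary` is hypothetical. As in maximize_tensor, the peak is measured over all channels while only the selected channels are scaled.

```python
import torch


def scaled_to_boundary(latents, boundary=4.0, channels=(0, 1, 2)):
    """Return a copy whose selected channels are scaled by boundary / max(|value|)."""
    peak = latents.abs().max()
    out = latents.clone()
    if peak > 0:  # avoid dividing by zero for an all-zero tensor
        out[:, list(channels)] *= boundary / peak
    return out


latents = torch.zeros(1, 4, 2, 2)
latents[0, 0, 0, 0] = 2.0   # largest magnitude, so the factor is 4 / 2 = 2
latents[0, 1, 0, 0] = -1.0
latents[0, 3, 0, 0] = 1.0   # channel 3 is not in `channels`, so it is left alone

result = scaled_to_boundary(latents)
print(result[0, :, 0, 0])   # tensor([ 4., -2.,  0.,  1.])
```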
Callback implementation example
def callback(pipe, step_index, timestep, cbk):
    if timestep > 950:
        threshold = max(cbk["latents"].max(), abs(cbk["latents"].min())) * 0.998
        cbk["latents"] = soft_clamp_tensor(cbk["latents"], threshold*0.998, threshold)
    if timestep > 700:
        cbk["latents"] = center_tensor(cbk["latents"], 0.8, 0.8)
    if timestep > 1 and timestep < 100:
        cbk["latents"] = center_tensor(cbk["latents"], 0.6, 1.0)
        cbk["latents"] = maximize_tensor(cbk["latents"])
    return cbk

image = base(
    prompt,
    guidance_scale = guidance_scale,
    callback_on_step_end=callback,
    callback_on_step_end_inputs=["latents"]
).images[0]
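The timestep conditions in the callback define a simple schedule. The sketch below only lists which corrections would run at a given timestep, assuming each `if` block contains the statements that follow it up to the next `if`; the function name `techniques_for` is hypothetical.

```python
def techniques_for(timestep):
    """List the corrections the callback applies at a given timestep."""
    applied = []
    if timestep > 950:
        applied.append("soft_clamp")
    if timestep > 700:
        applied.append("center(0.8, 0.8)")
    if 1 < timestep < 100:
        applied.extend(["center(0.6, 1.0)", "maximize"])
    return applied


for t in (999, 800, 500, 50, 1):
    print(t, techniques_for(t))
# 999 ['soft_clamp', 'center(0.8, 0.8)']
# 800 ['center(0.8, 0.8)']
# 500 []
# 50 ['center(0.6, 1.0)', 'maximize']
# 1 []
```

Under that reading, outliers are only softened during the first, noisiest steps, centering runs through the early part of denoising, and the final centering and maximizing happen just before the latents are decoded.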
This simple implementation of the three methods is used in the last set of images, with the women in the garden.
A complete demonstration
Click the heading or this link for an interactive demo!
This demonstration uses a more advanced implementation of the techniques by detecting outliers using Z-score, by shifting towards mean dynamically and by applying strength to each technique.
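The demo's code is not reproduced here, so the following is only a sketch of what Z-score outlier detection on latents could look like, not the implementation used in the demonstration. The function name `zscore_outlier_mask`, the per-channel statistics and the threshold of 3 are all assumptions.

```python
import torch


def zscore_outlier_mask(latents, z_threshold=3.0):
    """Mask of values more than z_threshold standard deviations from their channel mean."""
    mean = latents.mean(dim=(2, 3), keepdim=True)
    std = latents.std(dim=(2, 3), keepdim=True).clamp_min(1e-8)  # guard against flat channels
    return ((latents - mean) / std).abs() > z_threshold


latents = torch.zeros(1, 1, 10, 10)
latents[0, 0, 0, 0] = 10.0  # a single extreme value
mask = zscore_outlier_mask(latents)
print(int(mask.sum()))      # 1
```

A mask like this could then be used to shrink only the flagged values towards the mean, instead of rescaling everything beyond a fixed threshold.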
Original SDXL (too yellow) and slight modification (white balanced)
Medium modification and hard modification (both with all 3 techniques applied)
Increasing color range / removing color bias
In the example below, SDXL has limited the color range to red and green in the regular output, because nothing in the prompt suggests that there is such a thing as blue. This is a rather good generation, but the color range has become restricted.
If you give someone a palette of black, red, green and yellow and then tell them to paint a clear blue sky, the natural response is to ask you to supply blue and white.
To include blue in the generation, we can simply realign the color space when it gets restricted and SDXL will appropriately include the full color spectrum in the generation.
Long prompts at high guidance scales becoming possible
Here is a typical scenario, where the increased color range makes the whole prompt possible.
This example applies the simple, hard modification shown earlier, to illustrate the difference more clearly.
prompt: Photograph of woman in red dress in a luxury garden surrounded with blue, yellow, purple and flowers in many colors, high class, award-winning photography, Portra 400, full format. blue sky, intricate details even to the smallest particle, extreme detail of the environment, sharp portrait, well lit, interesting outfit, beautiful shadows, bright, photoquality, ultra realistic, masterpiece
Here are some more comparisons on the same concept
Keep in mind that these all just use the same static modifications.