from collections.abc import Sequence
from typing import NamedTuple

import torch
import torch.nn.functional as F

# Default per-channel shift directions (one entry per latent channel).
# Kept as an immutable tuple so it is safe to use as a default argument.
_DEFAULT_CHANNELS: tuple[int, ...] = (0, 1, 1, 0)


def get_mean_shifted_latents(
    latents: torch.Tensor,
    shift: float = 0.11,
    delta_shift: float = 0.1,
    channels: Sequence[int] = _DEFAULT_CHANNELS,  # sequence of {-1, 0, 1}
) -> torch.Tensor:
    """Return a copy of *latents* with selected channels shifted by a constant.

    For every channel whose entry in *channels* is non-zero, a constant offset
    is added in steps of ``delta_shift * sign`` until the fraction of strictly
    positive elements in that channel (measured over the whole batch) has
    moved by at least ``shift`` in the direction of ``sign``. At least one
    step is always applied to a selected channel.

    Args:
        latents: Tensor indexed as ``[batch, channel, height, width]``.
            Values are expected to be finite.
        shift: Desired change of the positive-element ratio.
        delta_shift: Size of one offset step; must be positive.
        channels: One direction per channel index: ``1`` raises the ratio,
            ``-1`` lowers it, ``0`` leaves the channel untouched.

    Returns:
        A new tensor; *latents* itself is not modified.

    Raises:
        ValueError: If ``delta_shift`` is not positive (the search could
            never make progress).
    """
    if delta_shift <= 0:
        raise ValueError(f"delta_shift must be positive, got {delta_shift}")

    shifted_latents = latents.clone()

    for idx, sign in enumerate(channels):
        if sign == 0:  # skip
            continue

        # A view into the clone: in-place updates below land in
        # ``shifted_latents`` directly, so no write-back is needed.
        latent_channel = shifted_latents[:, idx, :, :]
        if latent_channel.numel() == 0:
            # The mean of an empty tensor is NaN and would never satisfy the
            # stop condition.
            continue

        positive_ratio = (latent_channel > 0).float().mean()
        # A ratio can only live in [0, 1]; clamping keeps the target
        # reachable so the loop below always terminates.
        target_ratio = (positive_ratio + shift * sign).clamp(0.0, 1.0)
        step = delta_shift * sign

        # gradually shift latent_channel
        while True:
            latent_channel += step
            new_positive_ratio = (latent_channel > 0).float().mean()
            # The comparison must follow the shift direction: going up we
            # wait until the ratio is high enough, going down until it is
            # low enough.
            if sign > 0:
                reached = new_positive_ratio >= target_ratio
            else:
                reached = new_positive_ratio <= target_ratio
            if reached:
                break

    return shifted_latents


def get_2d_gaussian(
    latent_height: int,
    latent_width: int,
    std_dev: float,
    device: torch.device,
    center_x: float = 0.0,
    center_y: float = 0.0,
    factor: int = 8,
) -> torch.Tensor:
    """Build an unnormalised 2D Gaussian on a coarse ``[-1, 1]`` grid.

    The grid has ``latent_height // factor`` rows and
    ``latent_width // factor`` columns, so the result is *factor* times
    smaller than the latent size along each axis.

    Args:
        latent_height: Height of the latent the mask is meant for.
        latent_width: Width of the latent the mask is meant for.
        std_dev: Standard deviation in normalised grid units.
        device: Device on which the grid is created.
        center_x: Horizontal centre of the Gaussian in ``[-1, 1]`` coordinates.
        center_y: Vertical centre of the Gaussian in ``[-1, 1]`` coordinates.
        factor: Downsampling factor of the grid relative to the latent size.
            NOTE(review): the rationale for the default of 8 is not recorded.

    Returns:
        Tensor of shape ``(1, 1, latent_height // factor,
        latent_width // factor)`` with values in ``(0, 1]``.
    """
    y = torch.linspace(-1, 1, steps=latent_height // factor, device=device)
    x = torch.linspace(-1, 1, steps=latent_width // factor, device=device)
    y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")

    x_grid = x_grid - center_x
    y_grid = y_grid - center_y

    gauss = torch.exp(-((x_grid**2 + y_grid**2) / (2 * std_dev**2)))

    # add batch and channel dimensions
    return gauss[None, None, :, :]


def apply_tkg_noise(
    latents: torch.Tensor,
    shift: float = 0.11,
    delta_shift: float = 0.1,
    std_dev: float = 0.5,
    factor: int = 8,
    channels: Sequence[int] = _DEFAULT_CHANNELS,
) -> torch.Tensor:
    """Blend channel-shifted latents with the originals using a Gaussian mask.

    The mask is largest at the spatial centre, so the centre keeps the
    original latents while the periphery takes the shifted ones.

    Args:
        latents: Tensor of shape ``(batch, channels, height, width)``.
        shift: Forwarded to :func:`get_mean_shifted_latents`.
        delta_shift: Forwarded to :func:`get_mean_shifted_latents`.
        std_dev: Standard deviation of the Gaussian mask.
        factor: Downsampling factor of the coarse mask grid before it is
            bilinearly resized to the latent size.
        channels: Per-channel shift directions (``-1``, ``0`` or ``1``).

    Returns:
        Tensor with the same shape as *latents*.
    """
    batch_size, num_channels, latent_height, latent_width = latents.shape

    shifted_latents = get_mean_shifted_latents(
        latents,
        shift=shift,
        delta_shift=delta_shift,
        channels=channels,
    )

    gauss_mask = get_2d_gaussian(
        latent_height=latent_height,
        latent_width=latent_width,
        std_dev=std_dev,
        center_x=0.0,
        center_y=0.0,
        factor=factor,
        device=latents.device,
    )
    # Resize the coarse mask to the full latent resolution.
    gauss_mask = F.interpolate(
        gauss_mask,
        size=(latent_height, latent_width),
        mode="bilinear",
        align_corners=False,
    )
    gauss_mask = gauss_mask.expand(batch_size, num_channels, -1, -1)

    # mask == 1 keeps the original latent, mask == 0 takes the shifted one.
    noised_latents = shifted_latents * (1 - gauss_mask) + latents * gauss_mask

    return noised_latents


class ColorSet(NamedTuple):
    """A named background colour and its per-channel shift directions."""

    name: str
    channels: list[int]  # one of {-1, 0, 1} per latent channel


# ref: Figure 28. Additional Result in various color Background with SD
COLOR_SETS: list[ColorSet] = [
    ColorSet("green", [0, 1, 1, 0]),
    ColorSet("cyan", [0, 1, 0, 0]),
    ColorSet("magenta", [0, -1, -1, -1]),
    ColorSet("purple", [0, 0, -1, -1]),
    ColorSet("black", [-1, 0, 0, 1]),
    ColorSet("orange", [-1, -1, 1, 0]),
    ColorSet("white", [0, 0, 0, -1]),
    ColorSet("yellow", [0, -1, 1, -1]),
]

# Lookup of colour sets by name.
COLOR_SET_MAP: dict[str, ColorSet] = {c.name: c for c in COLOR_SETS}