# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
#
# Modifications Copyright (c) Ensemble AI, 2025.
# Description of modifications: Using NdLinear in the model to
# make the model more compact yet with similar performance.

"""Diffusion Transformer (DiT) with optional NdLinear-based sub-modules.

NdLinear can replace the block MLPs (``use_ndmlp``), the timestep embedder
(``use_ndtse``) and the adaLN modulation heads (``ndadaln``).
"""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
from transformers import PretrainedConfig, PreTrainedModel

from mlp import NdMlp
from ndlinear import NdLinear


def modulate(x, shift, scale):
    """Apply an adaLN shift/scale to ``x``.

    ``shift`` and ``scale`` get a new axis at dim 1, so a per-sample
    ``(N, D)`` modulation broadcasts over the token axis of ``x``
    (assumed ``(N, T, D)`` -- confirm at call sites).
    """
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def _apply_adaln(modulation, c, use_ndlinear):
    """Run an adaLN modulation head on the conditioning vector ``c``.

    NdLinear layers in this module are fed ``(N, *input_dims)`` tensors and
    yield ``(N, *output_dims)`` (see ``NdTimestepEmbedder.forward``). The
    NdLinear adaLN heads use ``input_dims=(hidden_size, 1)``, so ``c`` gets a
    trailing singleton axis and the matching trailing axis is dropped from
    the output, giving the same ``(N, k * hidden_size)`` result as the
    ``nn.Linear`` head.
    """
    if use_ndlinear:
        return modulation(c.unsqueeze(-1)).squeeze(-1)
    return modulation(c)


def _zero_init_(module):
    """Zero ``module.weight`` and ``module.bias`` if they exist as tensors.

    Layers without these attributes (e.g. an NdLinear, whose parameter layout
    is defined in the ``ndlinear`` package and not visible here) are left
    untouched instead of raising ``AttributeError``.
    """
    for attr in ("weight", "bias"):
        tensor = getattr(module, attr, None)
        if isinstance(tensor, torch.Tensor):
            nn.init.constant_(tensor, 0)


class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps into vectors of size ``hidden_size``.

    Timesteps are expanded into sinusoidal features of size
    ``frequency_embedding_size`` and passed through a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Create sinusoidal timestep features.

        Args:
            t: 1-D tensor of N timesteps (may be fractional), one per sample.
            dim: Output feature size.
            max_period: Controls the lowest frequency of the features.

        Returns:
            An ``(N, dim)`` float32 tensor: ``dim // 2`` cosine features followed
            by ``dim // 2`` sine features, plus one zero column if ``dim`` is odd.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Map ``(N,)`` timesteps to ``(N, hidden_size)`` embeddings."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)


class NdTimestepEmbedder(nn.Module):
    """Timestep embedder whose MLP is replaced by two NdLinear layers.

    Args:
        hidden_size: Output embedding size.
        frequency_embedding_size: Size of the sinusoidal timestep features.
        use_num_transforms: Layout of the sinusoidal features fed to NdLinear:
            ``2`` -> ``(frequency_embedding_size // 16, 16)`` (requires a
            multiple of 16); ``20`` -> ``(frequency_embedding_size, 1)``.
        tse_scale_factor: Divisor applied to ``hidden_size`` to size the
            intermediate NdLinear representation.
        knowledge_transfer: When true, ``src_layers`` must be provided.
            Neither value is otherwise used in this class.
        src_layers: See ``knowledge_transfer``.

    Raises:
        ValueError: ``knowledge_transfer`` without ``src_layers``; an unsupported
            ``use_num_transforms``; or, for layout ``2``, a
            ``frequency_embedding_size`` that is not a multiple of 16.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, use_num_transforms=2,
                 tse_scale_factor=1, knowledge_transfer=False, src_layers=None):
        super().__init__()
        self.activation = nn.SiLU()
        self.frequency_embedding_size = frequency_embedding_size
        self.use_num_transforms = use_num_transforms

        if knowledge_transfer and not src_layers:
            raise ValueError("Source layers must be provided for knowledge transfer.")

        if use_num_transforms == 2:
            if frequency_embedding_size % 16:
                raise ValueError(
                    "frequency_embedding_size must be a multiple of 16 when "
                    f"use_num_transforms == 2, got {frequency_embedding_size}."
                )
            in_dims = (frequency_embedding_size // 16, 16)
            mid_dims = (int(hidden_size // tse_scale_factor // 2), 2)
        elif use_num_transforms == 20:
            in_dims = (frequency_embedding_size, 1)
            mid_dims = (int(hidden_size // tse_scale_factor), 1)
        else:
            raise ValueError(
                f"Unsupported use_num_transforms={use_num_transforms!r}; expected 2 or 20."
            )

        # forward() reshapes the sinusoidal features to exactly these dims, so
        # the input layout always matches what ndlinear_1 was built for.
        self._in_dims = in_dims
        self.ndlinear_1 = NdLinear(in_dims, mid_dims)
        self.ndlinear_2 = NdLinear(mid_dims, (hidden_size, 1))

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal timestep features; identical to ``TimestepEmbedder.timestep_embedding``."""
        return TimestepEmbedder.timestep_embedding(t, dim, max_period)

    def forward(self, t):
        """Map ``(N,)`` timesteps to ``(N, hidden_size)`` embeddings."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_freq = t_freq.reshape(t_freq.shape[0], *self._in_dims)
        t_emb = self.ndlinear_1(t_freq)
        t_emb = self.activation(t_emb)
        t_emb = self.ndlinear_2(t_emb)
        # ndlinear_2 outputs (N, hidden_size, 1); drop only that trailing axis
        # so a batch of size 1 keeps its batch dimension.
        return t_emb.squeeze(-1)


class LabelEmbedder(nn.Module):
    """Embed class labels, with optional label dropout for classifier-free guidance.

    When ``dropout_prob > 0`` the table gets one extra row (index
    ``num_classes``) to which dropped labels are mapped.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """Replace some labels with the null index ``num_classes``.

        Labels are dropped at random with probability ``dropout_prob``, or,
        when ``force_drop_ids`` is given, exactly where ``force_drop_ids == 1``.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        """Embed ``labels``; dropout applies in training or when ``force_drop_ids`` is given."""
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)


class DiTBlock(nn.Module):
    """Transformer block conditioned through adaLN shift/scale/gate modulation.

    Args:
        hidden_size: Token feature size.
        num_heads: Attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        use_ndmlp: Use ``NdMlp`` instead of timm's ``Mlp``.
        use_variant: Forwarded to ``NdMlp``.
        use_ndadaln: Use an NdLinear adaLN head instead of ``nn.Linear``.
        **block_kwargs: Forwarded to timm's ``Attention``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, use_ndmlp=False,
                 use_variant=4, use_ndadaln=False, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")  # noqa: E731
        if use_ndmlp:
            self.mlp = NdMlp(in_features=hidden_size, hidden_features=mlp_hidden_dim,
                             act_layer=approx_gelu, drop=0, use_variant=use_variant)
        else:
            self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim,
                           act_layer=approx_gelu, drop=0)

        # Produces the six per-sample modulation vectors consumed in forward().
        self.use_ndadaln = use_ndadaln
        if use_ndadaln:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                NdLinear((hidden_size, 1), (6 * hidden_size, 1)),
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True),
            )

    def forward(self, x, c):
        """Update tokens ``x`` given the per-sample conditioning vector ``c``."""
        modulation = _apply_adaln(self.adaLN_modulation, c, self.use_ndadaln)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """Final adaLN-modulated projection from tokens to per-patch outputs."""

    def __init__(self, hidden_size, patch_size, out_channels, use_ndadaln=False):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.use_ndadaln = use_ndadaln
        if self.use_ndadaln:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                NdLinear((hidden_size, 1), (2 * hidden_size, 1)),
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 2 * hidden_size, bias=True),
            )

    def forward(self, x, c):
        """Return ``(N, T, patch_size**2 * out_channels)`` per-token outputs."""
        shift, scale = _apply_adaln(self.adaLN_modulation, c, self.use_ndadaln).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)


class DiTConfig(PretrainedConfig):
    """Hugging Face config for ``DiT``.

    ``out_channels`` is derived: ``2 * in_channels`` when ``learn_sigma`` is
    true, otherwise ``in_channels``. ``ndadaln`` selects NdLinear adaLN heads.
    """

    model_type = "ndlinear_dit"

    def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28,
                 num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.1, num_classes=1000,
                 learn_sigma=True, use_ndmlp=False, use_ndtse=False, use_variant=4,
                 tse_scale_factor=2, use_num_transforms=2, ndadaln=False, **kwargs):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.class_dropout_prob = class_dropout_prob
        self.num_classes = num_classes
        self.learn_sigma = learn_sigma
        self.use_ndmlp = use_ndmlp
        self.use_ndtse = use_ndtse
        self.use_variant = use_variant
        self.tse_scale_factor = tse_scale_factor
        self.use_num_transforms = use_num_transforms
        self.ndadaln = ndadaln


class DiT(PreTrainedModel):
    """Diffusion model with a Transformer backbone, configured by ``DiTConfig``."""

    config_class = DiTConfig

    def __init__(self, config):
        super().__init__(config)
        self.input_size = config.input_size
        self.patch_size = config.patch_size
        self.in_channels = config.in_channels
        self.hidden_size = config.hidden_size
        self.depth = config.depth
        self.num_heads = config.num_heads
        self.mlp_ratio = config.mlp_ratio
        self.class_dropout_prob = config.class_dropout_prob
        self.num_classes = config.num_classes
        self.learn_sigma = config.learn_sigma
        self.use_ndmlp = config.use_ndmlp
        self.use_ndtse = config.use_ndtse
        self.use_variant = config.use_variant
        self.tse_scale_factor = config.tse_scale_factor
        self.use_num_transforms = config.use_num_transforms
        self.out_channels = config.out_channels
        # getattr keeps configs serialized before ``ndadaln`` was a declared field loadable.
        self.ndadaln = getattr(config, "ndadaln", False)

        self.x_embedder = PatchEmbed(self.input_size, self.patch_size, self.in_channels,
                                     self.hidden_size, bias=True)
        if self.use_ndtse:
            # NOTE(review): tse_scale_factor is hard-coded to 1 rather than
            # config.tse_scale_factor; switching would change parameter shapes
            # of existing checkpoints, so it is left as-is -- confirm intent.
            self.t_embedder = NdTimestepEmbedder(
                hidden_size=self.hidden_size,
                frequency_embedding_size=256,
                use_num_transforms=self.use_num_transforms,
                tse_scale_factor=1,
                knowledge_transfer=False,
                src_layers=None,
            )
        else:
            self.t_embedder = TimestepEmbedder(self.hidden_size)
        self.y_embedder = LabelEmbedder(self.num_classes, self.hidden_size, self.class_dropout_prob)

        num_patches = self.x_embedder.num_patches
        # Non-trainable; filled with a 2-D sin-cos embedding in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.hidden_size),
                                      requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(self.hidden_size, self.num_heads, mlp_ratio=self.mlp_ratio,
                     use_ndmlp=self.use_ndmlp, use_variant=self.use_variant,
                     use_ndadaln=self.ndadaln)
            for _ in range(self.depth)
        ])
        if self.use_ndmlp:
            # NOTE(review): every block already received an NdMlp above; this
            # rebuilds the MLP of even-indexed blocks with hidden_features fixed
            # at 4 * hidden_size (ignoring mlp_ratio). Kept for checkpoint
            # compatibility -- confirm whether alternating layers was intended.
            approx_gelu = lambda: nn.GELU(approximate="tanh")  # noqa: E731
            for idx, layer in enumerate(self.blocks):
                if idx % 2 == 0:
                    layer.mlp = NdMlp(
                        in_features=self.hidden_size,
                        hidden_features=self.hidden_size * 4,
                        act_layer=approx_gelu,
                        drop=0,
                        use_variant=self.use_variant,
                    )

        self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels,
                                      use_ndadaln=self.ndadaln)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise parameters; adaLN heads and the output projection start at zero."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Fixed 2-D sin-cos positional embedding over the square patch grid.
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialise the patch projection like an nn.Linear (xavier on the flattened kernel).
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        if not self.use_ndtse:
            nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero adaLN heads => all gates are zero, so every block initially
        # returns its input unchanged (x + 0 * ...).
        for block in self.blocks:
            _zero_init_(block.adaLN_modulation[-1])
        _zero_init_(self.final_layer.adaLN_modulation[-1])
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """Fold patch tokens back into images.

        Args:
            x: ``(N, T, patch_size**2 * out_channels)`` with T a perfect square.

        Returns:
            ``(N, out_channels, H, W)`` with ``H = W = sqrt(T) * patch_size``.
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(shape=(x.shape[0], c, h * p, w * p))

    def ckpt_wrapper(self, module):
        """Wrap ``module`` in a plain function for ``torch.utils.checkpoint``."""
        def ckpt_forward(*inputs):
            return module(*inputs)
        return ckpt_forward

    def forward(self, x, t, y):
        """Run the denoiser.

        Args:
            x: Image/latent batch accepted by ``PatchEmbed``
                (assumed ``(N, in_channels, input_size, input_size)``).
            t: ``(N,)`` diffusion timesteps.
            y: ``(N,)`` integer class labels.

        Returns:
            ``(N, out_channels, H, W)`` predictions.
        """
        x = self.x_embedder(x) + self.pos_embed
        t = self.t_embedder(t)
        y = self.y_embedder(y, self.training)
        c = t + y
        # Activation checkpointing only saves memory when a backward pass will follow.
        use_checkpoint = torch.is_grad_enabled()
        for block in self.blocks:
            if use_checkpoint:
                x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block), x, c, use_reentrant=False
                )
            else:
                x = block(x, c)
        x = self.final_layer(x, c)
        return self.unpatchify(x)

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """Forward pass with classifier-free guidance.

        Only the first half of ``x`` is used and is duplicated. The first half
        of the output is treated as conditional and the second half as
        unconditional (presumably ``y`` carries the null label there -- confirm
        at the caller). Guidance is applied to the first three output channels
        only; the remaining channels pass through unchanged.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # NOTE(review): standard CFG would split at self.in_channels instead of 3.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Return a ``(grid_size**2, embed_dim)`` 2-D sin-cos positional embedding.

    When ``cls_token`` is true and ``extra_tokens > 0``, ``extra_tokens`` zero
    rows are prepended.
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1-D embeddings of both grid coordinates (``embed_dim // 2`` each)."""
    assert embed_dim % 2 == 0
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    return np.concatenate([emb_h, emb_w], axis=1)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Return ``(M, embed_dim)`` sin features then cos features for the M flattened positions."""
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)


def _build_dit(depth, hidden_size, patch_size, num_heads, **kwargs):
    """Build a ``DiT`` from architecture hyper-parameters plus extra ``DiTConfig`` kwargs."""
    config = DiTConfig(depth=depth, hidden_size=hidden_size, patch_size=patch_size,
                       num_heads=num_heads, **kwargs)
    return DiT(config)


# Named model factories: DiT-<size>/<patch_size>.
def DiT_XL_2(**kwargs):
    return _build_dit(28, 1152, 2, 16, **kwargs)


def DiT_XL_4(**kwargs):
    return _build_dit(28, 1152, 4, 16, **kwargs)


def DiT_XL_8(**kwargs):
    return _build_dit(28, 1152, 8, 16, **kwargs)


def DiT_L_2(**kwargs):
    return _build_dit(24, 1024, 2, 16, **kwargs)


def DiT_L_4(**kwargs):
    return _build_dit(24, 1024, 4, 16, **kwargs)


def DiT_L_8(**kwargs):
    return _build_dit(24, 1024, 8, 16, **kwargs)


def DiT_B_2(**kwargs):
    return _build_dit(12, 768, 2, 12, **kwargs)


def DiT_B_4(**kwargs):
    return _build_dit(12, 768, 4, 12, **kwargs)


def DiT_B_8(**kwargs):
    return _build_dit(12, 768, 8, 12, **kwargs)


def DiT_S_2(**kwargs):
    return _build_dit(12, 384, 2, 6, **kwargs)


def DiT_S_4(**kwargs):
    return _build_dit(12, 384, 4, 6, **kwargs)


def DiT_S_8(**kwargs):
    return _build_dit(12, 384, 8, 6, **kwargs)


def DiT_MS_2(**kwargs):
    return _build_dit(6, 384, 2, 6, **kwargs)


def DiT_MS_4(**kwargs):
    return _build_dit(6, 384, 4, 6, **kwargs)


def DiT_MS_8(**kwargs):
    return _build_dit(6, 384, 8, 6, **kwargs)


def DiT_XS_2(**kwargs):
    return _build_dit(1, 384, 2, 6, **kwargs)


def DiT_XS_4(**kwargs):
    return _build_dit(1, 384, 4, 6, **kwargs)


def DiT_XS_8(**kwargs):
    return _build_dit(1, 384, 8, 6, **kwargs)


DiT_models = {
    'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
    'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
    'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
    'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
    'DiT-XS/2': DiT_XS_2, 'DiT-XS/4': DiT_XS_4, 'DiT-XS/8': DiT_XS_8,
    'DiT-MS/2': DiT_MS_2, 'DiT-MS/4': DiT_MS_4, 'DiT-MS/8': DiT_MS_8,
}