import torch
import torch.nn.functional as F
from torch import nn


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: a gated MLP with a SiLU-activated gate.

    The hidden width is d_model * ffn_expansion_factor // 3, and a single
    input projection produces both the gate and the value branch.
    """

    def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
        super().__init__()
        hidden_dim = d_model * ffn_expansion_factor // 3
        # One projection for both branches; chunked into (gate, value) in forward.
        self.p_in = nn.Linear(d_model, hidden_dim * 2)
        self.p_out = nn.Linear(hidden_dim, d_model)

    def forward(self, x):
        gate, x = self.p_in(x).chunk(2, dim=-1)
        return self.p_out(F.silu(gate) * x)


def modulate(x, shift, scale):
    # Adaptive layer-norm modulation: scale and shift the normalized activations.
    return x * (1 + scale) + shift


class GaussianFourierTimeEmbedding(nn.Module):
    """Embeds a scalar timestep with fixed random Fourier features.

    The random frequencies are stored as a non-trainable parameter; the
    output dimension is 2 * dim (sin and cos features concatenated).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch,) scalar timesteps -> (batch, 2 * dim) Fourier features.
        x = x[:, None] * self.weight[None, :] * 2 * torch.pi
        return torch.cat((torch.sin(x), torch.cos(x)), dim=1)


class AdaLNFinalLayer(nn.Module):
    """Final projection layer with adaptive layer norm (DiT-style).

    The conditioning vector c produces a per-sample shift and scale that
    modulate the normalized hidden state before the output projection.
    """

    def __init__(self, hidden_dim, feature_dim):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_dim, feature_dim, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)


class AdaLNMLP(nn.Module):
    """Residual MLP block conditioned via adaptive layer norm.

    The conditioning vector y produces a shift, scale, and gate; the gated
    MLP output is added back to the input as a residual.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 3 * hidden_dim, bias=True)
        )

    def forward(self, x, y):
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h
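

# Minimal usage sketch (an assumption, not part of the original module): it wires the
# pieces together the way a DiT-style denoiser typically would, embedding the timestep,
# conditioning the MLP block on it, and projecting back to feature space. The shapes
# (batch of 8, 256-dim features) and the residual use of SwiGLU are illustrative only.
if __name__ == "__main__":
    batch, d_model = 8, 256

    time_embed = GaussianFourierTimeEmbedding(d_model // 2)  # output dim = 2 * (d_model // 2) = d_model
    block = AdaLNMLP(d_model)
    ffn = SwiGLU(d_model)
    head = AdaLNFinalLayer(d_model, feature_dim=d_model)

    x = torch.randn(batch, d_model)  # noisy features
    t = torch.rand(batch)            # scalar diffusion timesteps in [0, 1)
    c = time_embed(t)                # conditioning vector, shape (batch, d_model)

    h = block(x, c)                  # adaLN-conditioned residual MLP block
    h = h + ffn(h)                   # SwiGLU feed-forward, applied residually here
    out = head(h, c)                 # adaLN final projection

    print(out.shape)                 # torch.Size([8, 256])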