import torch
import torch.nn.functional as F
from torch import nn

class SwiGLU(nn.Module):
    # Gated feed-forward block: a SiLU-activated gate multiplies a linear "value" path.
    # The 2/3 factor keeps the parameter count close to a plain expansion-factor MLP.
    def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
        super().__init__()
        hidden_dim = d_model * ffn_expansion_factor // 3
        self.p_in = nn.Linear(d_model, hidden_dim * 2)  # gate and value halves
        self.p_out = nn.Linear(hidden_dim, d_model)

    def forward(self, x):
        gate, x = self.p_in(x).chunk(2, dim=-1)
        return self.p_out(F.silu(gate) * x)
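
# Illustrative shape check for SwiGLU (not part of the original file; the helper name
# `_demo_swiglu` is made up here). The block preserves the model dimension.
def _demo_swiglu():
    m = SwiGLU(d_model=64, ffn_expansion_factor=4)
    x = torch.randn(2, 10, 64)      # (batch, tokens, d_model)
    y = m(x)
    assert y.shape == x.shape       # gate and value halves are each 64 * 4 // 3 = 85 wide
    return y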

def modulate(x, shift, scale):
    # Adaptive-LayerNorm modulation: condition-dependent scale and shift of normalized activations.
    return x * (1 + scale) + shift

class GaussianFourierTimeEmbedding(nn.Module):
    # Embeds a scalar timestep with fixed random Fourier features; output has 2 * dim channels.
    def __init__(self, dim: int):
        super().__init__()
        # Frozen random frequencies: kept as a parameter so they move with the module, but never trained.
        self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x[:, None] * self.weight[None, :] * 2 * torch.pi
        return torch.cat((torch.sin(x), torch.cos(x)), dim=1)
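
# Illustrative usage of the time embedding (not in the original file; `_demo_time_embedding`
# is a made-up helper name). Each scalar timestep maps to 2 * dim Fourier features.
def _demo_time_embedding():
    emb = GaussianFourierTimeEmbedding(dim=128)
    t = torch.rand(16)              # one scalar timestep per batch element
    e = emb(t)
    assert e.shape == (16, 256)     # sin and cos halves concatenated
    return e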

class AdaLNFinalLayer(nn.Module):
    # DiT-style final layer: the conditioning vector c produces the shift and scale
    # applied after a parameter-free LayerNorm, followed by a linear projection.
    def __init__(self, hidden_dim, feature_dim):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_dim, feature_dim, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)

class AdaLNMLP(nn.Module):
    # Residual MLP block with adaptive LayerNorm conditioning: the conditioning vector y
    # yields a shift, a scale, and a gate on the block's residual branch.
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 3 * hidden_dim, bias=True)
        )

    def forward(self, x, y):
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h
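
# Minimal end-to-end sketch (an assumption about how these pieces fit together, not code from
# the original Space): condition a small stack of AdaLNMLP blocks and an AdaLNFinalLayer on a
# Gaussian Fourier time embedding. Dimensions and depth below are arbitrary.
if __name__ == "__main__":
    hidden_dim, feature_dim, batch = 256, 32, 4
    time_embed = GaussianFourierTimeEmbedding(hidden_dim // 2)  # 2 * (hidden_dim // 2) = hidden_dim features
    blocks = nn.ModuleList([AdaLNMLP(hidden_dim) for _ in range(2)])
    head = AdaLNFinalLayer(hidden_dim, feature_dim)

    x = torch.randn(batch, hidden_dim)   # input features
    t = torch.rand(batch)                # scalar timesteps in [0, 1]
    c = time_embed(t)                    # conditioning vector, shape (batch, hidden_dim)
    for block in blocks:
        x = block(x, c)
    out = head(x, c)
    print(out.shape)                     # torch.Size([4, 32])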