import torch
import torch.nn.functional as F
from torch import nn


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: a SiLU-gated MLP."""

    def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
        super().__init__()
        # p_in projects to twice the per-branch hidden width; one half becomes
        # the gate, the other half the value branch.
        self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2)
        self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model)

    def forward(self, x):
        gate, x = self.p_in(x).chunk(2, dim=-1)
        return self.p_out(F.silu(gate) * x)
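
# Usage note (illustrative sketch; the sizes below are assumptions, not taken
# from this file): SwiGLU preserves the trailing dimension, so it can drop into
# a residual stream directly.
#   ffn = SwiGLU(d_model=256)
#   ffn(torch.randn(8, 256)).shape  # -> torch.Size([8, 256])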


def modulate(x, shift, scale):
    # Shift-and-scale modulation used by the adaLN layers below.
    return x * (1 + scale) + shift


class GaussianFourierTimeEmbedding(nn.Module):
    """Embeds a scalar timestep with fixed (untrained) random Fourier features."""

    def __init__(self, dim: int):
        super().__init__()
        # Frequencies are drawn once at initialization and kept frozen.
        self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B,) -> (B, dim) projections, then concatenate sin and cos -> (B, 2 * dim).
        x = x[:, None] * self.weight[None, :] * 2 * torch.pi
        x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
        return x
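
# Usage note (illustrative sketch; the sizes below are assumptions, not taken
# from this file): the output width is 2 * dim because sin and cos are stacked.
#   emb = GaussianFourierTimeEmbedding(dim=128)
#   emb(torch.rand(16)).shape  # -> torch.Size([16, 256])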


class AdaLNFinalLayer(nn.Module):
    """Final linear projection with adaptive LayerNorm (adaLN) conditioning."""

    def __init__(self, hidden_dim, feature_dim):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_dim, feature_dim, bias=True)
        # Conditioning vector -> per-feature shift and scale.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
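
# Usage note (illustrative sketch; the wiring is an assumption, not stated in
# this file): `c` would typically be a conditioning/time embedding of width
# hidden_dim, and the layer maps the hidden state to the output feature width.
#   head = AdaLNFinalLayer(hidden_dim=256, feature_dim=32)
#   head(torch.randn(4, 256), torch.randn(4, 256)).shape  # -> torch.Size([4, 32])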


class AdaLNMLP(nn.Module):
    """Residual MLP block with adaptive LayerNorm conditioning and a learned gate."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )
        # Conditioning vector -> shift, scale, and residual gate.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 3 * hidden_dim, bias=True)
        )

    def forward(self, x, y):
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h
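
# Minimal smoke test (illustrative sketch; the dimensions and the way the
# modules are chained below are assumptions for demonstration, not taken from
# the rest of this Space).
if __name__ == "__main__":
    batch, hidden_dim, feature_dim = 4, 256, 32

    # Scalar times in [0, 1) -> (batch, hidden_dim) Fourier embedding.
    t = torch.rand(batch)
    time_emb = GaussianFourierTimeEmbedding(hidden_dim // 2)(t)

    x = torch.randn(batch, hidden_dim)
    block = AdaLNMLP(hidden_dim)
    ffn = SwiGLU(d_model=hidden_dim)
    head = AdaLNFinalLayer(hidden_dim, feature_dim)

    h = block(x, time_emb)   # (4, 256): conditioned residual MLP update
    h = h + ffn(h)           # (4, 256): SwiGLU feed-forward, residual added here for demo
    out = head(h, time_emb)  # (4, 32): adaLN-conditioned output projection
    print(out.shape)         # torch.Size([4, 32])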