"""Initial demo of Lina-speech (pardi-speech): SwiGLU, Gaussian Fourier time embedding, and AdaLN blocks."""
import torch
import torch.nn.functional as F
from torch import nn

class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: a gated MLP with a SiLU (swish) activated gate."""

    def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
        super().__init__()
        # Single input projection producing two chunks (gate and value) of the reduced hidden size.
        self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2)
        self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model)

    def forward(self, x):
        gate, x = self.p_in(x).chunk(2, dim=-1)
        return self.p_out(F.silu(gate) * x)

def modulate(x, shift, scale):
    """Apply AdaLN-style shift/scale modulation to normalized features."""
    return x * (1 + scale) + shift

class GaussianFourierTimeEmbedding(nn.Module):
    """Embed scalar timesteps with fixed random Fourier features (output size is 2 * dim)."""

    def __init__(self, dim: int):
        super().__init__()
        # Frozen random frequencies: sampled once, not updated during training.
        self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x[:, None] * self.weight[None, :] * 2 * torch.pi
        x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
        return x

class AdaLNFinalLayer(nn.Module):
    """Final projection layer with adaptive LayerNorm (AdaLN) conditioning."""

    def __init__(self, hidden_dim, feature_dim):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_dim, feature_dim, bias=True)
        # Conditioning vector -> per-channel shift and scale.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class AdaLNMLP(nn.Module):
    """Residual MLP block with adaptive LayerNorm conditioning and a learned residual gate."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )
        # Conditioning vector -> shift, scale, and residual gate.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 3 * hidden_dim, bias=True)
        )

    def forward(self, x, y):
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h
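
# Minimal smoke-test sketch (not part of the original file): it only checks that the
# modules above compose and broadcast as expected. The dimensions used here
# (batch=2, seq_len=16, hidden_dim=256, feature_dim=80) are illustrative assumptions,
# not values taken from the Lina-speech / pardi-speech configuration.
if __name__ == "__main__":
    batch, seq_len, hidden_dim, feature_dim = 2, 16, 256, 80

    x = torch.randn(batch, seq_len, hidden_dim)
    t = torch.rand(batch)  # scalar timesteps, e.g. diffusion/flow times in [0, 1)

    # Fourier embedding of size hidden_dim (sin/cos halves of hidden_dim // 2 each),
    # broadcast over the sequence dimension as the conditioning signal.
    time_emb = GaussianFourierTimeEmbedding(hidden_dim // 2)
    c = time_emb(t)[:, None, :]  # (batch, 1, hidden_dim)

    ffn = SwiGLU(hidden_dim)
    block = AdaLNMLP(hidden_dim)
    head = AdaLNFinalLayer(hidden_dim, feature_dim)

    h = ffn(x)          # (batch, seq_len, hidden_dim)
    h = block(h, c)     # (batch, seq_len, hidden_dim), conditioned on time
    out = head(h, c)    # (batch, seq_len, feature_dim)
    print(h.shape, out.shape)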