import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

# source: https://github.com/lucidrains/conformer/blob/master/conformer/conformer.py

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
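# Sketch (not part of the original file): a quick check of calc_same_padding. The helper
# name _check_same_padding is hypothetical. For an odd kernel (e.g. 31) it returns
# symmetric padding (15, 15); for an even kernel the padding is asymmetric, so a
# stride-1 Conv1d still preserves the sequence length.
def _check_same_padding(kernel_size = 31, length = 100):
    x = torch.randn(1, 4, length)              # (batch, channels, seq)
    pad = calc_same_padding(kernel_size)
    conv = nn.Conv1d(4, 4, kernel_size)
    out = conv(F.pad(x, pad))
    assert out.shape[-1] == length             # "same" output length
    return out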
class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim = self.dim)
        return out * gate.sigmoid()

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)
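# Sketch (not part of the original file; _check_glu_and_depthwise is a hypothetical
# helper): GLU splits the channel dimension in half and gates one half with the sigmoid
# of the other, so 2 * inner_dim channels in -> inner_dim out. DepthWiseConv1d then
# convolves each channel independently (groups = chan_in), and the explicit F.pad keeps
# the sequence length unchanged.
def _check_glu_and_depthwise(inner_dim = 8, length = 50):
    x = torch.randn(1, inner_dim * 2, length)  # (batch, channels, seq)
    x = GLU(dim = 1)(x)                        # -> (1, inner_dim, length)
    conv = DepthWiseConv1d(inner_dim, inner_dim, kernel_size = 31, padding = calc_same_padding(31))
    out = conv(x)
    assert out.shape == (1, inner_dim, length)
    return out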
# attention, feedforward, and conv module

class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)
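# Sketch (not part of the original file; _check_wrappers is a hypothetical helper):
# Scale and PreNorm are thin wrappers used by ConformerBlock below to build the macaron
# layout, e.g. Scale(0.5, PreNorm(dim, FeedForward(dim))) applies LayerNorm, then the
# wrapped module, then multiplies the result by 0.5.
def _check_wrappers(dim = 16):
    wrapped = Scale(0.5, PreNorm(dim, nn.Linear(dim, dim)))
    x = torch.randn(2, 10, dim)
    assert wrapped(x).shape == x.shape
    return wrapped(x)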
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        max_pos_emb = 512
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.max_pos_emb = max_pos_emb
        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)

        self.dropout = nn.Dropout(dropout)
    def forward(self, x, context = None, mask = None, context_mask = None):
        n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # shaw's relative positional embedding
        seq = torch.arange(n, device = device)
        dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
        dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
        rel_pos_emb = self.rel_pos_emb(dist).to(q)
        pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
        dots = dots + pos_attn
        if exists(mask) or exists(context_mask):
            # masks are expected to be boolean; the defaults are built as bool tensors so
            # the bitwise-not below does not fail on a float tensor
            mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device, dtype = torch.bool))
            context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device, dtype = torch.bool))
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return self.dropout(out)
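# Sketch (not part of the original file; _check_attention is a hypothetical helper):
# the attention module maps (batch, seq, dim) to the same shape, and an optional boolean
# mask of shape (batch, seq) excludes padded positions from the softmax.
def _check_attention(dim = 64, seq = 32):
    attn = Attention(dim = dim, heads = 4, dim_head = 16)
    x = torch.randn(2, seq, dim)
    mask = torch.ones(2, seq, dtype = torch.bool)
    assert attn(x, mask = mask).shape == (2, seq, dim)
    return attn(x)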
class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim = 1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
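# Sketch (not part of the original file; _check_conv_module is a hypothetical helper):
# the convolution module is shape preserving. It expands dim -> 2 * expansion_factor * dim
# with a pointwise conv, halves the channels back with GLU, runs a depthwise conv along
# the time axis, and projects back to dim.
def _check_conv_module(dim = 64, seq = 40):
    conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = 2, kernel_size = 31)
    x = torch.randn(2, seq, dim)
    assert conv(x).shape == (2, seq, dim)
    return conv(x)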
# Conformer Block

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.
    ):
        super().__init__()
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
        self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        x = self.attn(x, mask = mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x
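# Sketch (not part of the original file): end-to-end usage in the style of the upstream
# repository's README. A ConformerBlock is shape preserving, so blocks can be stacked
# directly on (batch, seq, dim) inputs.
if __name__ == '__main__':
    block = ConformerBlock(
        dim = 512,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.
    )
    x = torch.randn(1, 1024, 512)
    out = block(x)                 # (1, 1024, 512)
    print(out.shape)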