import torch from torch import nn from torch.nn.utils.parametrizations import weight_norm import math from src.models.snake import Snake from src.models.utils import unfold import typing as tp def WNConv1d(*args, **kwargs): return weight_norm(nn.Conv1d(*args, **kwargs)) def WNConvTranspose1d(*args, **kwargs): return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) class BLSTM(nn.Module): """ BiLSTM with same hidden units as input dim. If `max_steps` is not None, input will be splitting in overlapping chunks and the LSTM applied separately on each chunk. """ def __init__(self, dim, layers=1, max_steps=None, skip=False): super().__init__() assert max_steps is None or max_steps % 4 == 0 self.max_steps = max_steps self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) self.linear = nn.Linear(2 * dim, dim) self.skip = skip def forward(self, x): B, C, T = x.shape y = x framed = False if self.max_steps is not None and T > self.max_steps: width = self.max_steps stride = width // 2 frames = unfold(x, width, stride) nframes = frames.shape[2] framed = True x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) x = x.permute(2, 0, 1) x = self.lstm(x)[0] x = self.linear(x) x = x.permute(1, 2, 0) if framed: out = [] frames = x.reshape(B, -1, C, width) limit = stride // 2 for k in range(nframes): if k == 0: out.append(frames[:, k, :, :-limit]) elif k == nframes - 1: out.append(frames[:, k, :, limit:]) else: out.append(frames[:, k, :, limit:-limit]) out = torch.cat(out, -1) out = out[..., :T] x = out if self.skip: x = x + y return x class LocalState(nn.Module): """Local state allows to have attention based only on data (no positional embedding), but while setting a constraint on the time window (e.g. decaying penalty term). Also a failed experiments with trying to provide some frequency based attention. """ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): super().__init__() assert channels % heads == 0, (channels, heads) self.heads = heads self.nfreqs = nfreqs self.ndecay = ndecay self.content = nn.Conv1d(channels, channels, 1) self.query = nn.Conv1d(channels, channels, 1) self.key = nn.Conv1d(channels, channels, 1) if nfreqs: self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) if ndecay: self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) # Initialize decay close to zero (there is a sigmoid), for maximum initial window. self.query_decay.weight.data *= 0.01 assert self.query_decay.bias is not None # stupid type checker self.query_decay.bias.data[:] = -2 # self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) self.proj = nn.Conv1d(channels, channels, 1) def forward(self, x): B, C, T = x.shape heads = self.heads indexes = torch.arange(T, device=x.device, dtype=x.dtype) # left index are keys, right index are queries delta = indexes[:, None] - indexes[None, :] queries = self.query(x).view(B, heads, -1, T) keys = self.key(x).view(B, heads, -1, T) # t are keys, s are queries dots = torch.einsum("bhct,bhcs->bhts", keys, queries) dots /= keys.shape[2] ** 0.5 if self.nfreqs: periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5 tmp = torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) dots += tmp if self.ndecay: decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) decay_q = self.query_decay(x).view(B, heads, -1, T) decay_q = torch.sigmoid(decay_q) / 2 decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay ** 0.5 dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) # Kill self reference. dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) weights = torch.softmax(dots, dim=2) content = self.content(x).view(B, heads, -1, T) result = torch.einsum("bhts,bhct->bhcs", weights, content) result = result.reshape(B, -1, T) return x + self.proj(result) class LayerScale(nn.Module): """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). This rescales diagonaly residual outputs close to 0 initially, then learnt. """ def __init__(self, channels: int, init: float = 0): super().__init__() self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) self.scale.data[:] = init def forward(self, x): return self.scale[:, None] * x class DConv(nn.Module): """ New residual branches in each encoder layer. This alternates dilated convolutions, potentially with LSTMs and attention. Also before entering each residual branch, dimension is projected on a smaller subspace, e.g. of dim `channels // compress`. """ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, norm=True, time_attn=False, heads=4, ndecay=4, lstm=False, act_func='gelu', freq_dim=None, reshape=False, kernel=3, dilate=True): """ Args: channels: input/output channels for residual branch. compress: amount of channel compression inside the branch. depth: number of layers in the residual branch. Each layer has its own projection, and potentially LSTM and attention. init: initial scale for LayerNorm. norm: use GroupNorm. time_attn: use LocalAttention. heads: number of heads for the LocalAttention. ndecay: number of decay controls in the LocalAttention. lstm: use LSTM. gelu: Use GELU activation. kernel: kernel size for the (dilated) convolutions. dilate: if true, use dilation, increasing with the depth. """ super().__init__() assert kernel % 2 == 1 self.channels = channels self.compress = compress self.depth = abs(depth) dilate = depth > 0 self.time_attn = time_attn self.lstm = lstm self.reshape = reshape self.act_func = act_func self.freq_dim = freq_dim norm_fn: tp.Callable[[int], nn.Module] norm_fn = lambda d: nn.Identity() # noqa if norm: norm_fn = lambda d: nn.GroupNorm(1, d) # noqa self.hidden = int(channels / compress) act: tp.Type[nn.Module] if act_func == 'gelu': act = nn.GELU elif act_func == 'snake': act = Snake else: act = nn.ReLU self.layers = nn.ModuleList([]) for d in range(self.depth): layer = nn.ModuleDict() dilation = 2 ** d if dilate else 1 padding = dilation * (kernel // 2) conv1 = nn.ModuleList([nn.Conv1d(channels, self.hidden, kernel, dilation=dilation, padding=padding), norm_fn(self.hidden)]) act_layer = act(freq_dim) if act_func == 'snake' else act() conv2 = nn.ModuleList([nn.Conv1d(self.hidden, 2 * channels, 1), norm_fn(2 * channels), nn.GLU(1), LayerScale(channels, init)]) layer.update({'conv1': nn.Sequential(*conv1), 'act': act_layer, 'conv2': nn.Sequential(*conv2)}) if lstm: layer.update({'lstm': BLSTM(self.hidden, layers=2, max_steps=200, skip=True)}) if time_attn: layer.update({'time_attn': LocalState(self.hidden, heads=heads, ndecay=ndecay)}) self.layers.append(layer) def forward(self, x): if self.reshape: B, C, Fr, T = x.shape x = x.permute(0, 2, 1, 3).reshape(-1, C, T) for layer in self.layers: skip = x x = layer['conv1'](x) if self.act_func == 'snake' and self.reshape: x = x.view(B, Fr, self.hidden, T).permute(0, 2, 3, 1) x = layer['act'](x) if self.act_func == 'snake' and self.reshape: x = x.permute(0, 3, 1, 2).reshape(-1, self.hidden, T) if self.lstm: x = layer['lstm'](x) if self.time_attn: x = layer['time_attn'](x) x = layer['conv2'](x) x = skip + x if self.reshape: x = x.view(B, Fr, C, T).permute(0, 2, 1, 3) return x class ScaledEmbedding(nn.Module): """ Boost learning rate for embeddings (with `scale`). Also, can make embeddings continuous with `smooth`. """ def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10., smooth=False): super().__init__() self.embedding = nn.Embedding(num_embeddings, embedding_dim) if smooth: weight = torch.cumsum(self.embedding.weight.data, dim=0) # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that. weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None] self.embedding.weight.data[:] = weight self.embedding.weight.data /= scale self.scale = scale @property def weight(self): return self.embedding.weight * self.scale def forward(self, x): out = self.embedding(x) * self.scale return out class FTB(nn.Module): def __init__(self, input_dim=257, in_channel=9, r_channel=5): super(FTB, self).__init__() self.input_dim = input_dim self.in_channel = in_channel self.conv1 = nn.Sequential( nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]), nn.BatchNorm2d(r_channel), nn.ReLU() ) self.conv1d = nn.Sequential( nn.Conv1d(r_channel * input_dim, in_channel, kernel_size=9, padding=4), nn.BatchNorm1d(in_channel), nn.ReLU() ) self.freq_fc = nn.Linear(input_dim, input_dim, bias=False) self.conv2 = nn.Sequential( nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]), nn.BatchNorm2d(in_channel), nn.ReLU() ) def forward(self, inputs): ''' inputs should be [Batch, Ca, Dim, Time] ''' # T-F attention conv1_out = self.conv1(inputs) B, C, D, T = conv1_out.size() reshape1_out = torch.reshape(conv1_out, [B, C * D, T]) conv1d_out = self.conv1d(reshape1_out) conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T]) # now is also [B,C,D,T] att_out = conv1d_out * inputs # tranpose to [B,C,T,D] att_out = torch.transpose(att_out, 2, 3) freqfc_out = self.freq_fc(att_out) att_out = torch.transpose(freqfc_out, 2, 3) cat_out = torch.cat([att_out, inputs], 1) outputs = self.conv2(cat_out) return outputs