""" This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs """ import numpy as np import torch from torch import nn from torch.nn import functional as F from src.models.utils import capture_init from src.models.spec import spectro, ispectro from src.models.modules import DConv, ScaledEmbedding, FTB import logging logger = logging.getLogger(__name__) def rescale_conv(conv, reference): std = conv.weight.std().detach() scale = (std / reference) ** 0.5 conv.weight.data /= scale if conv.bias is not None: conv.bias.data /= scale def rescale_module(module, reference): for sub in module.modules(): if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): rescale_conv(sub, reference) class HEncLayer(nn.Module): def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, is_first=False, freq_attn=False, freq_dim=None, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True): """Encoder layer. This used both by the time and the frequency branch. Args: chin: number of input channels. chout: number of output channels. norm_groups: number of groups for group norm. empty: used to make a layer with just the first conv. this is used before merging the time and freq. branches. freq: this is acting on frequencies. dconv: insert DConv residual branches. norm: use GroupNorm. context: context size for the 1x1 conv. dconv_kw: list of kwargs for the DConv class. pad: pad the input. Padding is done so that the output size is always the input size / stride. rewrite: add 1x1 conv at the end of the layer. """ super().__init__() norm_fn = lambda d: nn.Identity() # noqa if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa if stride == 1 and kernel_size % 2 == 0 and kernel_size > 1: kernel_size -= 1 if pad: pad = (kernel_size - stride) // 2 else: pad = 0 klass = nn.Conv2d self.chin = chin self.chout = chout self.freq = freq self.kernel_size = kernel_size self.stride = stride self.empty = empty self.freq_attn = freq_attn self.freq_dim = freq_dim self.norm = norm self.pad = pad if freq: kernel_size = [kernel_size, 1] stride = [stride, 1] if pad != 0: pad = [pad, 0] # klass = nn.Conv2d else: kernel_size = [1, kernel_size] stride = [1, stride] if pad != 0: pad = [0, pad] self.is_first = is_first if is_first: self.pre_conv = nn.Conv2d(chin, chout, [1, 1]) chin = chout if self.freq_attn: self.freq_attn_block = FTB(input_dim=freq_dim, in_channel=chin) self.conv = klass(chin, chout, kernel_size, stride, pad) if self.empty: return self.norm1 = norm_fn(chout) self.rewrite = None if rewrite: self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context) self.norm2 = norm_fn(2 * chout) self.dconv = None if dconv: self.dconv = DConv(chout, **dconv_kw) def forward(self, x, inject=None): """ `inject` is used to inject the result from the time branch into the frequency branch, when both have the same stride. """ if not self.freq: le = x.shape[-1] if not le % self.stride == 0: x = F.pad(x, (0, self.stride - (le % self.stride))) if self.is_first: x = self.pre_conv(x) if self.freq_attn: x = self.freq_attn_block(x) x = self.conv(x) x = F.gelu(self.norm1(x)) if self.dconv: x = self.dconv(x) if self.rewrite: x = self.norm2(self.rewrite(x)) x = F.glu(x, dim=1) return x class HDecLayer(nn.Module): def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True): """ Same as HEncLayer but for decoder. 
class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1,
                 empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={},
                 pad=True, context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if stride == 1 and kernel_size % 2 == 0 and kernel_size > 1:
            kernel_size -= 1
        if pad:
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv2d
        klass_tr = nn.ConvTranspose2d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
        else:
            kernel_size = [1, kernel_size]
            stride = [1, stride]
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = torch.cat([x, skip], dim=1)
            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                y = self.dconv(y)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)

        return z
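
# A matching sketch for `HDecLayer` (again with hypothetical sizes, mirroring
# the `HEncLayer` example above). The transposed conv multiplies the frequency
# axis back by `stride`; `length` is the target size used on the time branch,
# while the frequency branch trims the stored padding instead:
#
#     dec = HDecLayer(chin=96, chout=2, kernel_size=8, stride=4, dconv=False)
#     x = torch.randn(1, 48, 64, 100)     # upsampled features from below
#     skip = torch.randn(1, 48, 64, 100)  # skip connection from the encoder
#     out = dec(x, skip, length=256)      # -> [1, 2, 256, 100]
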
class Aero(nn.Module):
    """
    Deep model for Audio Super Resolution.
    """

    @capture_init
    def __init__(self,
                 # Channels
                 in_channels=1,
                 out_channels=1,
                 audio_channels=2,
                 channels=48,
                 growth=2,
                 # STFT
                 nfft=512,
                 hop_length=64,
                 end_iters=0,
                 cac=True,
                 # Main structure
                 rewrite=True,
                 hybrid=False,
                 hybrid_old=False,
                 # Frequency branch
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 strides=[4, 4, 2, 2],
                 context=1,
                 context_enc=0,
                 freq_ends=4,
                 enc_freq_attn=4,
                 # Normalization
                 norm_starts=2,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_time_attn=2,
                 dconv_lstm=2,
                 dconv_init=1e-3,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 lr_sr=4000,
                 hr_sr=16000,
                 spec_upsample=True,
                 act_func='snake',
                 debug=False):
        """
        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of FFT bins. Note that changing this requires careful computation
                of various shape parameters and will not work out of the box for hybrid models.
            hop_length: hop between STFT frames (scaled down by `hr_sr / lr_sr` for the input).
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            cac: use complex-as-channels, i.e. complex numbers are 2 channels each
                in input and output. No further processing is done before the ISTFT.
            rewrite (bool): add a 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain model, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate.
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            strides: strides for encoder and decoder layers; the depth of the
                model is `len(strides)`.
            context: context for the 1x1 conv in the decoder.
            context_enc: context for the 1x1 conv in the encoder.
            freq_ends: last layer index that acts on the frequency axis.
            enc_freq_attn: adds FTB frequency attention in encoder layers starting at this layer.
            norm_starts: layer at which group norm starts being used.
                Decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of the residual DConv branch.
            dconv_comp: compression of the DConv branch.
            dconv_time_attn: adds time attention layers in the DConv branch starting at this layer.
            dconv_lstm: adds an LSTM layer in the DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick.
            lr_sr: source low-resolution sample rate.
            hr_sr: target high-resolution sample rate.
            spec_upsample: if true, upsamples in the spectral domain, otherwise
                performs sinc interpolation beforehand.
            act_func: 'snake'/'relu'.
            debug: if true, prints out input dimensions throughout the model layers.
        """
        super().__init__()
        self.cac = cac
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.audio_channels = audio_channels
        self.kernel_size = kernel_size
        self.context = context
        self.strides = strides
        self.depth = len(strides)
        self.channels = channels
        self.lr_sr = lr_sr
        self.hr_sr = hr_sr
        self.spec_upsample = spec_upsample

        self.scale = hr_sr / lr_sr if self.spec_upsample else 1

        self.nfft = nfft
        self.hop_length = int(hop_length // self.scale)  # this is for the input signal
        self.win_length = int(self.nfft // self.scale)  # this is for the input signal
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        self.debug = debug

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        chin_z = self.in_channels
        if self.cac:
            chin_z *= 2
        chout_z = channels
        freqs = nfft // 2

        for index in range(self.depth):
            freq_attn = index >= enc_freq_attn
            lstm = index >= dconv_lstm
            time_attn = index >= dconv_time_attn
            norm = index >= norm_starts
            freq = index <= freq_ends
            stri = strides[index]
            ker = kernel_size

            pad = True
            if freq and freqs < kernel_size:
                ker = freqs

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'time_attn': time_attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'act_func': act_func,
                    'reshape': True,
                    'freq_dim': freqs // strides[index] if freq else freqs
                }
            }

            kw_dec = dict(kw)

            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1,
                            context=context_enc,
                            is_first=index == 0,
                            freq_attn=freq_attn,
                            freq_dim=freqs,
                            **kw)

            self.encoder.append(enc)
            if index == 0:
                chin = self.out_channels
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(2 * chout_z, chin_z,
                            dconv=dconv_mode & 2,
                            last=index == 0,
                            context=context,
                            **kw_dec)

            self.decoder.insert(0, dec)

            chin_z = chout_z
            chout_z = int(growth * chout_z)

            if freq:
                freqs //= strides[index]

            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(freqs, chin_z,
                                                smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)
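
    # A worked example of how spectral upsampling falls out of the STFT settings
    # (derived from the defaults above; illustrative, not normative): with
    # lr_sr=4000 and hr_sr=16000, scale = 4, so `_spec` analyzes the
    # low-resolution input with hop_length // 4 and win_length = nfft // 4,
    # while `_ispec` resynthesizes with the full hop_length and win_length.
    # The network therefore maps between spectrograms with the same number of
    # time frames, and the 4x larger synthesis hop stretches those frames over
    # 4x as many output samples.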
    def _spec(self, x, scale=False):
        if np.mod(x.shape[-1], self.hop_length):
            x = F.pad(x, (0, self.hop_length - np.mod(x.shape[-1], self.hop_length)))

        hl = self.hop_length
        nfft = self.nfft
        win_length = self.win_length

        if scale:
            hl = int(hl * self.scale)
            win_length = int(win_length * self.scale)

        z = spectro(x, nfft, hl, win_length=win_length)[..., :-1, :]
        return z

    def _ispec(self, z):
        hl = int(self.hop_length * self.scale)
        win_length = int(self.win_length * self.scale)
        z = F.pad(z, (0, 0, 0, 1))
        x = ispectro(z, hl, win_length=win_length)
        return x

    def _move_complex_to_channels_dim(self, z):
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _convert_to_complex(self, x):
        """
        :param x: signal of shape [Batch, Channels, 2, Freq, TimeFrames]
        :return: complex signal of shape [Batch, Channels, Freq, TimeFrames]
        """
        out = x.permute(0, 1, 3, 4, 2)
        out = torch.view_as_complex(out.contiguous())
        return out

    def forward(self, mix, return_spec=False, return_lr_spec=False):
        x = mix
        length = x.shape[-1]

        if self.debug:
            logger.info(f'hdemucs in shape: {x.shape}')

        z = self._spec(x)
        x = self._move_complex_to_channels_dim(z)

        if self.debug:
            logger.info(f'x spec shape: {x.shape}')

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        saved = []  # skip connections, freq.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            x = encode(x, inject)
            if self.debug:
                logger.info(f'encoder {idx} out shape: {x.shape}')
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non-equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)  # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x = decode(x, skip, lengths.pop(-1))
            if self.debug:
                logger.info(f'decoder {idx} out shape: {x.shape}')

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0

        x = x.view(B, self.out_channels, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        if self.debug:
            logger.info(f'post view shape: {x.shape}')

        x_spec_complex = self._convert_to_complex(x)

        if self.debug:
            logger.info(f'x_spec_complex shape: {x_spec_complex.shape}')

        x = self._ispec(x_spec_complex)

        if self.debug:
            logger.info(f'hdemucs out shape: {x.shape}')

        x = x[..., :int(length * self.scale)]

        if self.debug:
            logger.info(f'hdemucs out - trimmed shape: {x.shape}')

        if return_spec:
            if return_lr_spec:
                return x, x_spec_complex, z
            else:
                return x, x_spec_complex

        return x
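
# A minimal usage sketch (not part of the original module; the input length is
# an assumption for illustration). Builds the model with its constructor
# defaults and upsamples one second of 4 kHz audio to 16 kHz:
if __name__ == '__main__':
    model = Aero(in_channels=1, out_channels=1, lr_sr=4000, hr_sr=16000)
    lr_audio = torch.randn(1, 1, 4000)  # [batch, channels, time] at 4 kHz
    with torch.no_grad():
        hr_audio = model(lr_audio)
    # scale = hr_sr / lr_sr = 4, so the output should hold 4x as many samples.
    print(hr_audio.shape)  # expected: torch.Size([1, 1, 16000])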