import math

import torch.nn as nn
from torch.nn import functional as F
from torchaudio.functional import resample

from src.models.utils import capture_init, weights_init
from src.models.modules import WNConv1d, WNConvTranspose1d


class ResnetBlock(nn.Module):
    """Dilated residual unit computing ``shortcut(x) + block(x)``.

    The residual branch is LeakyReLU -> reflection pad -> dilated kernel-3
    conv -> LeakyReLU -> kernel-1 conv; the shortcut is a kernel-1 conv.
    Reflection-padding ``dilation`` samples per side offsets the ``2 * dilation``
    samples consumed by the unpadded dilated kernel-3 conv, so the time length
    is preserved (assumes WNConv1d defaults to stride=1 / padding=0 like
    ``nn.Conv1d`` — TODO confirm).

    Args:
        dim: number of input and output channels.
        dilation: dilation of the kernel-3 convolution.
    """

    def __init__(self, dim, dilation=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(dilation),
            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
            nn.LeakyReLU(0.2),
            WNConv1d(dim, dim, kernel_size=1),
        )
        self.shortcut = WNConv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Return ``shortcut(x) + block(x)``."""
        return self.shortcut(x) + self.block(x)


class Seanet(nn.Module):
    """Convolutional encoder/decoder with additive skip connections.

    The encoder is: input conv (``in_channels -> ngf``), one strided
    downsampling stage per entry of ``ratios`` (applied in reverse order,
    doubling channels each time), and a bottleneck conv to
    ``latent_space_size`` channels. The decoder mirrors it with transposed
    convolutions (applied in ``ratios`` order). Before every encoder layer the
    activation is saved; after every decoder layer the matching saved tensor is
    added back (U-Net style, by addition rather than concatenation).
    """

    @capture_init
    def __init__(self,
                 latent_space_size=128,
                 ngf=32,
                 n_residual_layers=3,
                 resample=1,
                 normalize=True,
                 floor=1e-3,
                 ratios=(8, 8, 2, 2),
                 in_channels=1,
                 out_channels=1,
                 lr_sr=16000,
                 hr_sr=16000,
                 upsample=True):
        """Build the network.

        Args:
            latent_space_size: channel count of the bottleneck representation.
            ngf: base channel count; stage widths are ``ngf * 2**k``.
            n_residual_layers: ResnetBlocks per stage, with dilations
                ``1, 3, 9, ...``.
            resample: stored as ``self.resample``; not read by this class.
            normalize: if True, ``forward`` divides the input by its std
                (plus ``floor``) and rescales the output by the same std.
            floor: small constant added to the std before dividing.
            ratios: stride of each up/downsampling stage. Input lengths are
                padded so they are compatible with these strides.
            in_channels: channels of the input signal.
            out_channels: channels produced by the last decoder layer. That
                output is added to the (padded) network input, so it must match
                or broadcast against ``in_channels``.
            lr_sr: sample rate of the input signal.
            hr_sr: sample rate the network operates at / outputs.
            upsample: if True, ``forward`` resamples the input from ``lr_sr``
                to ``hr_sr`` before running the network.
        """
        super().__init__()

        # NOTE: the ``resample`` argument shadows the module-level torchaudio
        # ``resample`` function only inside ``__init__``; ``forward`` still
        # resolves the module-level function.
        self.resample = resample
        self.normalize = normalize
        self.floor = floor
        self.lr_sr = lr_sr
        self.hr_sr = hr_sr
        # Truncated integer rate ratio, kept for backward compatibility only;
        # forward() computes the exact output length instead.
        self.scale_factor = int(self.hr_sr / self.lr_sr)
        self.upsample = upsample

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Copy: the default is an immutable tuple and a caller's list is not
        # aliased, but self.ratios remains a list as before.
        self.ratios = list(ratios)
        mult = int(2 ** len(self.ratios))

        # Module construction order below intentionally matches the original
        # implementation so parameter initialisation (under a fixed seed) and
        # state_dict keys are unchanged.

        # Bottleneck layers: widest stage <-> latent space.
        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(latent_space_size, mult * ngf, kernel_size=7, padding=0),
        ]
        encoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(mult * ngf, latent_space_size, kernel_size=7, padding=0),
        ]
        # The encoder is built from the inside out, hence insert(0, ...).
        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        for r in self.ratios:
            narrow, wide = mult * ngf // 2, mult * ngf
            # kernel = 2 * stride with this padding makes each stage scale the
            # length by exactly r (see estimate_output_length).
            encoder_block = [
                nn.LeakyReLU(0.2),
                WNConv1d(narrow,
                         wide,
                         kernel_size=r * 2,
                         stride=r,
                         padding=r // 2 + r % 2,
                         ),
            ]
            decoder_block = [
                nn.LeakyReLU(0.2),
                WNConvTranspose1d(
                    wide,
                    narrow,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2,
                    output_padding=r % 2,
                ),
            ]

            # Encoder residual units precede the downsampling conv; they are
            # created from the largest dilation down and prepended, yielding
            # dilations 1, 3, 9, ... in forward order.
            for dil_exp in range(n_residual_layers - 1, -1, -1):
                encoder_block = [ResnetBlock(narrow, dilation=3 ** dil_exp)] + encoder_block
            # Decoder residual units follow the upsampling conv.
            for dil_exp in range(n_residual_layers):
                decoder_block += [ResnetBlock(narrow, dilation=3 ** dil_exp)]

            mult //= 2

            self.encoder.insert(0, nn.Sequential(*encoder_block))
            self.decoder.append(nn.Sequential(*decoder_block))

        # Outermost layers: signal channels <-> ngf.
        encoder_wrapper_conv_layer = [
            nn.ReflectionPad1d(3),
            WNConv1d(in_channels, ngf, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))

        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(ngf, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        self.apply(weights_init)

    def estimate_output_length(self, length):
        """Return the padded input length the network can process losslessly.

        Walks the strided encoder convolutions in the order the encoder applies
        them (``reversed(self.ratios)``) using ceil division, so trailing
        samples are padded rather than dropped. It then walks the transposed
        decoder convolutions back up. The result is >= ``length``; for even
        ratios it is the next multiple of ``prod(ratios)``. At this length,
        every decoder stage reproduces the length of the matching encoder
        stage, which the additive skips in ``forward`` require.

        Args:
            length: number of input time steps.

        Returns:
            int: the valid length.
        """
        for stride in reversed(self.ratios):
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            length = math.ceil((length - kernel_size + 2 * padding) / stride) + 1
            length = max(length, 1)
        for stride in self.ratios:
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            output_padding = stride % 2
            length = (length - 1) * stride + kernel_size - 2 * padding + output_padding
        return int(length)

    def pad_to_valid_length(self, signal):
        """Right-pad ``signal`` with zeros along its last dim to a valid length.

        Returns:
            (padded_signal, padding_len) where ``padding_len`` is the number of
            samples appended.
        """
        valid_length = self.estimate_output_length(signal.shape[-1])
        padding_len = valid_length - signal.shape[-1]
        signal = F.pad(signal, (0, padding_len))
        return signal, padding_len

    def forward(self, signal):
        """Run the encoder/decoder on ``signal``.

        ``signal`` is presumably ``(batch, channels, time)``: the std is taken
        over ``dim=-1`` of the channel mean (``dim=1``) — TODO confirm for
        other layouts.

        Returns:
            Tensor cropped to the expected output length: the input length, or
            ``ceil(L * hr_sr / lr_sr)`` when ``upsample`` is True. If
            ``normalize`` is True, the output is rescaled by the input std.
        """
        in_len = signal.shape[-1]
        if self.upsample:
            # Match torchaudio's resample output length, ceil(L * new / orig).
            # The old ``in_len * int(hr_sr / lr_sr)`` truncated non-integer
            # ratios (e.g. 16k -> 24k) and cropped away real output. For
            # integer ratios both give the same value.
            target_len = math.ceil(in_len * self.hr_sr / self.lr_sr)
        else:
            target_len = in_len

        if self.normalize:
            mono = signal.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            signal = signal / (self.floor + std)
        else:
            std = 1

        x = signal
        if self.upsample:
            x = resample(x, self.lr_sr, self.hr_sr)

        x, _ = self.pad_to_valid_length(x)

        # Save the activation entering each encoder layer...
        skips = []
        for encode in self.encoder:
            skips.append(x)
            x = encode(x)

        # ...and add them back in reverse order after each decoder layer.
        for decode, skip in zip(self.decoder, reversed(skips)):
            x = decode(x) + skip

        # Drop the samples introduced by pad_to_valid_length.
        if target_len < x.shape[-1]:
            x = x[..., :target_len]
        return std * x