# Initial commit of Radio Upscaling UI (minus models)
# commit f113387 — 5.95 kB
import torch.nn as nn
import math
from src.models.utils import capture_init, weights_init
from src.models.modules import WNConv1d, WNConvTranspose1d
from torchaudio.functional import resample
from torch.nn import functional as F
class ResnetBlock(nn.Module):
    """Residual block: a dilated 3-tap conv followed by a 1x1 conv,
    added to a learned 1x1 projection of the input (the shortcut)."""

    def __init__(self, dim, dilation=1):
        super().__init__()
        # Main path: reflection-pad by the dilation so the dilated
        # kernel_size=3 conv preserves the temporal length.
        layers = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(dilation),
            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
            nn.LeakyReLU(0.2),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)
        # Skip path: learned 1x1 conv rather than an identity shortcut.
        self.shortcut = WNConv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        residual = self.block(x)
        return self.shortcut(x) + residual
class Seanet(nn.Module):
    """SEANet-style convolutional encoder/decoder for audio super-resolution.

    Each of ``len(ratios)`` encoder stages downsamples time by its ratio
    while doubling the channel count, ending in a ``latent_space_size``
    bottleneck; the decoder mirrors the encoder with transposed
    convolutions.  ``forward`` adds U-Net-style skip connections between
    matching encoder/decoder stages.
    """

    @capture_init
    def __init__(self,
                 latent_space_size=128,
                 ngf=32, n_residual_layers=3,
                 resample=1,
                 normalize=True,
                 floor=1e-3,
                 ratios=(8, 8, 2, 2),
                 in_channels=1,
                 out_channels=1,
                 lr_sr=16000,
                 hr_sr=16000,
                 upsample=True):
        """
        Args:
            latent_space_size (int): channel width of the bottleneck.
            ngf (int): base filter count; doubled at every encoder stage.
            n_residual_layers (int): ResnetBlocks per stage (dilations 3**j).
            resample (int): stored on the instance; not used elsewhere in
                this module.  NOTE: this parameter shadows the imported
                ``torchaudio.functional.resample`` inside ``__init__`` only.
            normalize (bool): divide the input by its mono std in forward().
            floor (float): added to the std to keep the division stable.
            ratios (sequence of int): per-stage stride/downsampling factors.
                The default is a tuple (immutable) instead of the previous
                mutable-list default; lists are still accepted and copied.
            in_channels (int): input audio channels.
            out_channels (int): output audio channels.
            lr_sr (int): input (low-resolution) sample rate.
            hr_sr (int): output (high-resolution) sample rate; assumed to be
                an integer multiple of ``lr_sr`` when ``upsample`` is True.
            upsample (bool): if True, resample lr_sr -> hr_sr in forward().
        """
        super().__init__()
        self.resample = resample
        self.normalize = normalize
        self.floor = floor
        self.lr_sr = lr_sr
        self.hr_sr = hr_sr
        # Integer upsampling factor; floor division avoids the float
        # round-trip of int(hr_sr / lr_sr) and gives the same result for
        # positive integer sample rates.
        self.scale_factor = self.hr_sr // self.lr_sr
        self.upsample = upsample
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Private copy: keeps self.ratios independent of the caller's object
        # and of the shared default.
        self.ratios = list(ratios)
        mult = int(2 ** len(self.ratios))

        # Innermost wrappers: 7-tap convs mapping mult*ngf <-> latent space.
        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(latent_space_size, mult * ngf, kernel_size=7, padding=0),
        ]
        encoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(mult * ngf, latent_space_size, kernel_size=7, padding=0),
        ]
        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        # Build mirrored stages from the innermost outwards.  The encoder is
        # assembled with insert(0, ...) so it ends up ordered
        # outermost -> innermost; the decoder is appended in the reverse
        # (innermost -> outermost) order.
        for i, r in enumerate(self.ratios):
            encoder_block = [
                nn.LeakyReLU(0.2),
                WNConv1d(mult * ngf // 2,
                         mult * ngf,
                         kernel_size=r * 2,
                         stride=r,
                         padding=r // 2 + r % 2,
                         ),
            ]
            decoder_block = [
                nn.LeakyReLU(0.2),
                WNConvTranspose1d(
                    mult * ngf,
                    mult * ngf // 2,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2,
                    output_padding=r % 2,
                ),
            ]
            # Residual stacks: the encoder prepends blocks with dilations
            # 3**(n-1) .. 3**0 before its strided conv; the decoder appends
            # dilations 3**0 .. 3**(n-1) after its transposed conv, keeping
            # the two paths mirror images of each other.
            for j in range(n_residual_layers - 1, -1, -1):
                encoder_block = [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] + encoder_block
            for j in range(n_residual_layers):
                decoder_block += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]
            mult //= 2

            self.encoder.insert(0, nn.Sequential(*encoder_block))
            self.decoder.append(nn.Sequential(*decoder_block))

        # Outermost wrappers operating directly on waveform channels.
        encoder_wrapper_conv_layer = [
            nn.ReflectionPad1d(3),
            WNConv1d(in_channels, ngf, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))

        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(ngf, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        self.apply(weights_init)

    def estimate_output_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in the convolutions, e.g. for
        all layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        depth = len(self.ratios)
        # Walk down the encoder (innermost ratio last), rounding each
        # strided conv's output length *up* so no samples are dropped.
        for idx in range(depth - 1, -1, -1):
            stride = self.ratios[idx]
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            length = math.ceil((length - kernel_size + 2 * padding) / stride) + 1
            length = max(length, 1)
        # Walk back up through the transposed convs to get the exact
        # length the decoder will produce from that latent length.
        for idx in range(depth):
            stride = self.ratios[idx]
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            output_padding = stride % 2
            length = (length - 1) * stride + kernel_size - 2 * padding + output_padding
        return int(length)

    def pad_to_valid_length(self, signal):
        """Right-pad ``signal`` along its last dim to the nearest valid length.

        Returns:
            (padded_signal, padding_len): the padded tensor and how many
            samples were appended.
        """
        valid_length = self.estimate_output_length(signal.shape[-1])
        padding_len = valid_length - signal.shape[-1]
        signal = F.pad(signal, (0, padding_len))
        return signal, padding_len

    def forward(self, signal):
        """Run the full encode/decode pass.

        Args:
            signal: tensor shaped (batch, channels, time) at ``lr_sr``.

        Returns:
            Tensor whose last dimension is the input length times
            ``scale_factor`` when ``upsample`` is set (input length
            otherwise), rescaled back by the input std when ``normalize``
            is set.
        """
        target_len = signal.shape[-1]
        if self.upsample:
            target_len *= self.scale_factor
        if self.normalize:
            # Normalise by the std of the mono mix; the floor keeps the
            # division stable for near-silent inputs.
            mono = signal.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            signal = signal / (self.floor + std)
        else:
            std = 1
        x = signal
        if self.upsample:
            # Module-level torchaudio.functional.resample (not self.resample).
            x = resample(x, self.lr_sr, self.hr_sr)
        # Pad so every strided conv divides the length exactly.
        x, padding_len = self.pad_to_valid_length(x)
        skips = []
        for i, encode in enumerate(self.encoder):
            skips.append(x)
            x = encode(x)
        for j, decode in enumerate(self.decoder):
            x = decode(x)
            # U-Net skip: pop in reverse so decoder stage j sees the input
            # of its mirrored encoder stage.
            skip = skips.pop(-1)
            x = x + skip
        # Trim the validity padding back to the requested output length.
        if target_len < x.shape[-1]:
            x = x[..., :target_len]
        # Undo the input normalisation.
        return std * x