Initial commit of Radio Upscaling UI (minus models)
- app.py +54 -0
- conf/experiment/aero_441-441_512_256.yaml +77 -0
- conf/experiment/seanet_441-441.yaml +46 -0
- html/directions.html +6 -0
- html/information.html +16 -0
- processAudio.py +87 -0
- requirements.txt +4 -0
- src/models/aero.py +524 -0
- src/models/modelFactory.py +15 -0
- src/models/modules.py +327 -0
- src/models/seanet.py +177 -0
- src/models/snake.py +66 -0
- src/models/spec.py +39 -0
- src/models/utils.py +50 -0
- src/utils.py +52 -0
app.py
ADDED
@@ -0,0 +1,54 @@
import gradio as gr
import torch
import numpy as np
from torchaudio.functional import resample

from processAudio import upscaleAudio

class Object(object):
    pass

def processAudio(model: gr.Dropdown, audioData: gr.Audio):
    if len(audioData[1].shape) == 1: # Convert mono to stereo
        lrAudio = torch.tensor(np.array([
            audioData[1].copy().astype(np.float32) / 32768,
            audioData[1].copy().astype(np.float32) / 32768
        ]))
    elif audioData[1].shape[1] > 2:
        raise gr.Error("Audio with more than 2 channels is not supported.")
    else: # re-order channel data from [samples, 2] to [2, samples]
        lrAudio = torch.tensor(audioData[1].copy().astype(np.float32) / 32768).transpose(0, 1)
    if audioData[0] != 44100:
        lrAudio = resample(lrAudio, audioData[0], 44100)
    hrAudio = upscaleAudio(lrAudio, "models/" + model)
    hrAudio = hrAudio / max(hrAudio.abs().max().item(), 1)
    outAudio = (hrAudio * 32767).numpy().astype(np.int16).transpose(1, 0)
    return (44100, outAudio)

with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as layout:
    with gr.Row():
        gr.Markdown("<h2>Broadcast Audio Upscaler</h2>")
    with gr.Row():
        with open("html/directions.html", "r") as directionsHtml:
            gr.Markdown(directionsHtml.read())
    with gr.Row():
        modelSelect = gr.Dropdown(
            [
                ["FM Radio Super Resolution", "FM_Radio_SR.th"],
                ["AM Radio Super Resolution (Beta)", "AM_Radio_SR.th"]
            ],
            label="Select Model:",
            value="FM_Radio_SR.th",
        )
    with gr.Row():
        with gr.Column():
            audioFileSelect = gr.Audio(label="Audio File (Mono or Stereo):", sources="upload")
        with gr.Column():
            audioOutput = gr.Audio(show_download_button=True, label="Restored Audio:", sources=[])
    with gr.Row():
        submit = gr.Button("Process Audio", variant="primary").click(fn=processAudio, inputs=[modelSelect, audioFileSelect], outputs=audioOutput)
    with gr.Row():
        with gr.Accordion("More Information:", open=False):
            with open("html/information.html", "r") as informationHtml:
                gr.Markdown(informationHtml.read())
layout.launch()
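
For reference, a minimal smoke-test sketch of the same pipeline without the UI. scipy and the input.wav path are assumptions (scipy is not in requirements.txt), the checkpoint under models/ must exist, and the input is assumed to already be 44.1 kHz stereo, so the resample step from processAudio above is skipped:

# Hypothetical smoke test mirroring processAudio's int16 <-> float conversion.
import numpy as np
import torch
from scipy.io import wavfile

from processAudio import upscaleAudio

sr, samples = wavfile.read("input.wav")                 # int16, shaped [samples, 2], 44.1 kHz assumed
lr = torch.tensor(samples.astype(np.float32) / 32768).transpose(0, 1)
hr = upscaleAudio(lr, "models/FM_Radio_SR.th")          # float tensor, [2, samples]
out = (hr / max(hr.abs().max().item(), 1) * 32767).numpy().astype(np.int16)
wavfile.write("restored.wav", 44100, out.transpose(1, 0))
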
conf/experiment/aero_441-441_512_256.yaml
ADDED
@@ -0,0 +1,77 @@
# @package experiment
name: aero-nfft=${experiment.nfft}-hl=${experiment.hop_length}

# Dataset related
lr_sr: 44100 # low resolution sample rate, added to support BWE. Should be included in training cfg
hr_sr: 44100 # high resolution sample rate. Should be included in training cfg
segment: 2
stride: 4 # in seconds, how much to stride between training examples
pad: true # if training sample is too short, pad it
upsample: false
batch_size: 6
nfft: 512
hop_length: 256

# models related
model: aero
aero: # see aero.py for a detailed description
  in_channels: 2
  out_channels: 2
  # Channels
  channels: 64 # Small: 48, Large: 64
  growth: 2
  # STFT
  nfft: 512
  hop_length: 256
  end_iters: 0
  cac: true
  # Main structure
  rewrite: true
  hybrid: false
  hybrid_old: false
  # Frequency Branch
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  strides: [4, 4, 2, 2]
  context: 1
  context_enc: 0
  freq_ends: 4
  enc_freq_attn: 0
  # normalization
  norm_starts: 2
  norm_groups: 4
  # DConv residual branch
  dconv_mode: 1
  dconv_depth: 2
  dconv_comp: 4
  dconv_time_attn: 2
  dconv_lstm: 2
  dconv_init: 0.001
  # Weight init
  rescale: 0.1
  lr_sr: 44100
  hr_sr: 44100
  spec_upsample: true
  act_func: snake
  debug: false

adversarial: True
features_loss_lambda: 100
only_features_loss: False
only_adversarial_loss: False
discriminator_models: [msd, mpd] # msd_melgan/msd_hifi/mpd/hifi
melgan_discriminator:
  n_layers: 5
  num_D: 3
  downsampling_factor: 1
  ndf: 16
  channels: 2
msd:
  hidden: 32 # Small: 16, Large: 32
  channels: 2
mpd:
  hidden: 32 # Small: 16, Large: 32
  channels: 2
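
For reference, a minimal sketch of how this config is consumed (see src/models/modelFactory.py below): only the aero subtree is read with plain safe_load, so the ${experiment.*} interpolations in name stay unresolved, which is harmless here.

# Sketch: read just the "aero" subtree, as modelFactory.get_model does.
from yaml import safe_load

with open("conf/experiment/aero_441-441_512_256.yaml") as f:
    cfg = safe_load(f)

print(cfg["aero"]["nfft"], cfg["aero"]["hop_length"])  # 512 256
print(cfg["name"])  # the ${...} placeholders survive as a plain string
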
conf/experiment/seanet_441-441.yaml
ADDED
@@ -0,0 +1,46 @@
# @package experiment
name: seanet-nfft=${experiment.nfft}-hl=${experiment.hop_length}

# Dataset related
lr_sr: 44100 # low resolution sample rate, added to support BWE. Should be included in training cfg
hr_sr: 44100 # high resolution sample rate. Should be included in training cfg
segment: 2
stride: 2 # in seconds, how much to stride between training examples
pad: true # if training sample is too short, pad it
upsample: false
batch_size: 5
nfft: 1024
hop_length: 512

# models related
model: seanet
seanet:
  latent_space_size: 256
  ngf: 32
  n_residual_layers: 4
  resample: 1
  normalize: False
  floor: 0.0005
  ratios: [8, 8, 2, 2]
  lr_sr: 44100
  hr_sr: 44100
  in_channels: 2
  out_channels: 2

adversarial: True
features_loss_lambda: 100
discriminator_models: [msd, mpd] # msd_melgan/msd_hifi/mpd/hifi
only_features_loss: False
only_adversarial_loss: False
msd:
  hidden: 16
  channels: 2
mpd:
  hidden: 16
  channels: 2
melgan_discriminator:
  n_layers: 5
  num_D: 3
  downsampling_factor: 1
  ndf: 16
  channels: 2
html/directions.html
ADDED
@@ -0,0 +1,6 @@
<h3>Directions:</h3>
<ol>
<li>Select a model from the model drop-down.</li>
<li>Select an audio file.</li>
<li>After the file loads, press "Process Audio" to process the file.</li>
</ol>
html/information.html
ADDED
@@ -0,0 +1,16 @@
<p>This tool is based on the "aero" audio upscaling project (<a href="https://github.com/slp-rl/aero">original version</a>, <a href="https://github.com/pokepress/aero">modified version used here</a>). It takes digitized radio recordings and attempts to reduce noise/interference while recreating lost frequencies.</p>

<B>The FM model should be able to:</B>
<ul>
<li>Reduce stereo hiss</li>
<li>Reduce certain kinds of static/interference</li>
<li>Reduce the 19 kHz stereo pilot tone</li>
<li>Reduce overly warm bass frequencies (boost them 3 or so dB if you miss them)</li>
<li>Restore lost high &amp; low frequencies (generally only up to 16 kHz, since most training data is MP3-based)</li>
</ul>

<B>The AM model should be able to:</B>
<ul>
<li>Reduce certain kinds of static/interference</li>
<li>Restore lost high &amp; low frequencies (capacity tends to diminish above 10 kHz)</li>
</ul>
processAudio.py
ADDED
@@ -0,0 +1,87 @@
import math
import os

import time

import torch
import logging
from pathlib import Path

import torchaudio
from torchaudio.functional import resample
from gradio import Progress

from src.models import modelFactory
from src.utils import bold

logger = logging.getLogger(__name__)

SEGMENT_DURATION_SEC = 5
SEGMENT_OVERLAP_SAMPLES = 2048
SERIALIZE_KEY_STATE = 'state'

def _load_model(checkpoint_file="models/FM_Radio_SR.th", model_name="aero"):
    checkpoint_file = Path(checkpoint_file)
    model = modelFactory.get_model(model_name)['generator']
    package = torch.load(checkpoint_file, 'cpu')
    if SERIALIZE_KEY_STATE in package: # raw model file
        logger.info(bold(f'Loading model {model_name} from file.'))
        model.load_state_dict(package[SERIALIZE_KEY_STATE])

    return model

def crossfade_and_blend(out_clip, in_clip):
    fade_out = torchaudio.transforms.Fade(0, SEGMENT_OVERLAP_SAMPLES)
    fade_in = torchaudio.transforms.Fade(SEGMENT_OVERLAP_SAMPLES, 0)
    return fade_out(out_clip) + fade_in(in_clip)

def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):

    model = _load_model(checkpoint_file, model_name)
    device = torch.device('cpu')
    #model.cuda()

    logger.info(f'lr wav shape: {lr_sig.shape}')

    segment_duration_samples = sr * SEGMENT_DURATION_SEC
    n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
    logger.info(f'number of chunks: {n_chunks}')

    lr_chunks = []
    for i in range(n_chunks):
        start = i * segment_duration_samples
        end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
        lr_chunks.append(lr_sig[:, start:end])

    pr_chunks = []

    lr_segment_overlap_samples = int((sr / hr_sr) * SEGMENT_OVERLAP_SAMPLES)

    model.eval()
    pred_start = time.time()
    with torch.no_grad():
        previous_chunk = None
        progress(0)
        for i, lr_chunk in enumerate(lr_chunks):
            progress(i / n_chunks)
            pr_chunk = None
            if previous_chunk is not None:
                combined_chunk = torch.cat((previous_chunk[..., -lr_segment_overlap_samples:], lr_chunk), 1)
                pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
                pr_chunk = pr_combined_chunk[..., SEGMENT_OVERLAP_SAMPLES:]
                pr_chunks[-1][..., -SEGMENT_OVERLAP_SAMPLES:] = crossfade_and_blend(pr_chunks[-1][..., -SEGMENT_OVERLAP_SAMPLES:], pr_combined_chunk.cpu()[..., :SEGMENT_OVERLAP_SAMPLES])
            else:
                pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
            logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
            logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
            pr_chunks.append(pr_chunk.cpu())
            previous_chunk = lr_chunk

    pred_duration = time.time() - pred_start
    logger.info(f'prediction duration: {pred_duration}')

    pr = torch.concat(pr_chunks, dim=-1)

    logger.info(f'pr wav shape: {pr.shape}')

    return pr
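
The chunk stitching relies on the two linear fades in crossfade_and_blend summing to unity gain across the overlap region, so the spliced audio has no level dip at chunk boundaries. A minimal sketch checking that property:

# Sketch: a linear fade-out plus a linear fade-in over the same span keep constant gain.
import torch
import torchaudio

OVERLAP = 2048
fade_out = torchaudio.transforms.Fade(0, OVERLAP)  # ramps the old tail 1 -> 0
fade_in = torchaudio.transforms.Fade(OVERLAP, 0)   # ramps the new head 0 -> 1
ones = torch.ones(2, OVERLAP)
blended = fade_out(ones) + fade_in(ones)
print(torch.allclose(blended, ones))  # True
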
requirements.txt
ADDED
@@ -0,0 +1,4 @@
torch==2.2.1
torchvision==0.17.2
torchaudio==2.2.1
opencv-python
src/models/aero.py
ADDED
@@ -0,0 +1,524 @@
"""
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
"""
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from src.models.utils import capture_init
from src.models.spec import spectro, ispectro
from src.models.modules import DConv, ScaledEmbedding, FTB

import logging
logger = logging.getLogger(__name__)


def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, is_first=False, freq_attn=False, freq_dim=None, norm=True, context=0,
                 dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This is used by both the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: list of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if stride == 1 and kernel_size % 2 == 0 and kernel_size > 1:
            kernel_size -= 1
        if pad:
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        klass = nn.Conv2d
        self.chin = chin
        self.chout = chout
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.freq_attn = freq_attn
        self.freq_dim = freq_dim
        self.norm = norm
        self.pad = pad
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            if pad != 0:
                pad = [pad, 0]
            # klass = nn.Conv2d
        else:
            kernel_size = [1, kernel_size]
            stride = [1, stride]
            if pad != 0:
                pad = [0, pad]

        self.is_first = is_first

        if is_first:
            self.pre_conv = nn.Conv2d(chin, chout, [1, 1])
            chin = chout

        if self.freq_attn:
            self.freq_attn_block = FTB(input_dim=freq_dim, in_channel=chin)

        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))

        if self.is_first:
            x = self.pre_conv(x)

        if self.freq_attn:
            x = self.freq_attn_block(x)

        x = self.conv(x)

        x = F.gelu(self.norm1(x))
        if self.dconv:
            x = self.dconv(x)

        if self.rewrite:
            x = self.norm2(self.rewrite(x))
            x = F.glu(x, dim=1)

        return x


class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if stride == 1 and kernel_size % 2 == 0 and kernel_size > 1:
            kernel_size -= 1
        if pad:
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv2d
        klass_tr = nn.ConvTranspose2d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
        else:
            kernel_size = [1, kernel_size]
            stride = [1, stride]
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = torch.cat([x, skip], dim=1)

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                y = self.dconv(y)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z


class Aero(nn.Module):
    """
    Deep model for Audio Super Resolution.
    """

    @capture_init
    def __init__(self,
                 # Channels
                 in_channels=1,
                 out_channels=1,
                 audio_channels=2,
                 channels=48,
                 growth=2,
                 # STFT
                 nfft=512,
                 hop_length=64,
                 end_iters=0,
                 cac=True,
                 # Main structure
                 rewrite=True,
                 hybrid=False,
                 hybrid_old=False,
                 # Frequency branch
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 strides=[4, 4, 2, 2],
                 context=1,
                 context_enc=0,
                 freq_ends=4,
                 enc_freq_attn=4,
                 # Normalization
                 norm_starts=2,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_time_attn=2,
                 dconv_lstm=2,
                 dconv_init=1e-3,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 lr_sr=4000,
                 hr_sr=16000,
                 spec_upsample=True,
                 act_func='snake',
                 debug=False):
        """
        Args:
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this requires careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            rewrite (bool): add 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            strides: strides for encoder and decoder layers.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            enc_freq_attn: adds freq attention layers in the encoder starting at this layer.
            dconv_time_attn: adds time attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick
            lr_sr: source low-resolution sample rate
            hr_sr: target high-resolution sample rate
            spec_upsample: if true, upsamples in the spectral domain, otherwise performs sinc-interpolation beforehand
            act_func: 'snake'/'relu'
            debug: if true, prints out input dimensions throughout model layers.
        """
        super().__init__()
        self.cac = cac
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.audio_channels = audio_channels
        self.kernel_size = kernel_size
        self.context = context
        self.strides = strides
        self.depth = len(strides)
        self.channels = channels
        self.lr_sr = lr_sr
        self.hr_sr = hr_sr
        self.spec_upsample = spec_upsample

        self.scale = hr_sr / lr_sr if self.spec_upsample else 1

        self.nfft = nfft
        self.hop_length = int(hop_length // self.scale)  # this is for the input signal
        self.win_length = int(self.nfft // self.scale)  # this is for the input signal
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        self.debug = debug

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        chin_z = self.in_channels
        if self.cac:
            chin_z *= 2
        chout_z = channels
        freqs = nfft // 2

        for index in range(self.depth):
            freq_attn = index >= enc_freq_attn
            lstm = index >= dconv_lstm
            time_attn = index >= dconv_time_attn
            norm = index >= norm_starts
            freq = index <= freq_ends
            stri = strides[index]
            ker = kernel_size

            pad = True
            if freq and freqs < kernel_size:
                ker = freqs

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'time_attn': time_attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'act_func': act_func,
                    'reshape': True,
                    'freq_dim': freqs // strides[index] if freq else freqs
                }
            }

            kw_dec = dict(kw)

            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1, context=context_enc,
                            is_first=index == 0, freq_attn=freq_attn, freq_dim=freqs,
                            **kw)

            self.encoder.append(enc)
            if index == 0:
                chin = self.out_channels
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(2 * chout_z, chin_z, dconv=dconv_mode & 2,
                            last=index == 0, context=context, **kw_dec)

            self.decoder.insert(0, dec)

            chin_z = chout_z
            chout_z = int(growth * chout_z)

            if freq:
                freqs //= strides[index]

            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    def _spec(self, x, scale=False):
        if np.mod(x.shape[-1], self.hop_length):
            x = F.pad(x, (0, self.hop_length - np.mod(x.shape[-1], self.hop_length)))
        hl = self.hop_length
        nfft = self.nfft
        win_length = self.win_length

        if scale:
            hl = int(hl * self.scale)
            win_length = int(win_length * self.scale)

        z = spectro(x, nfft, hl, win_length=win_length)[..., :-1, :]
        return z

    def _ispec(self, z):
        hl = int(self.hop_length * self.scale)
        win_length = int(self.win_length * self.scale)
        z = F.pad(z, (0, 0, 0, 1))
        x = ispectro(z, hl, win_length=win_length)
        return x

    def _move_complex_to_channels_dim(self, z):
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _convert_to_complex(self, x):
        """
        :param x: signal of shape [Batch, Channels, 2, Freq, TimeFrames]
        :return: complex signal of shape [Batch, Channels, Freq, TimeFrames]
        """
        out = x.permute(0, 1, 3, 4, 2)
        out = torch.view_as_complex(out.contiguous())
        return out

    def forward(self, mix, return_spec=False, return_lr_spec=False):
        x = mix
        length = x.shape[-1]

        if self.debug:
            logger.info(f'hdemucs in shape: {x.shape}')

        z = self._spec(x)
        x = self._move_complex_to_channels_dim(z)

        if self.debug:
            logger.info(f'x spec shape: {x.shape}')

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            x = encode(x, inject)
            if self.debug:
                logger.info(f'encoder {idx} out shape: {x.shape}')
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non-equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x = decode(x, skip, lengths.pop(-1))

            if self.debug:
                logger.info(f'decoder {idx} out shape: {x.shape}')

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0

        x = x.view(B, self.out_channels, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        if self.debug:
            logger.info(f'post view shape: {x.shape}')

        x_spec_complex = self._convert_to_complex(x)

        if self.debug:
            logger.info(f'x_spec_complex shape: {x_spec_complex.shape}')

        x = self._ispec(x_spec_complex)

        if self.debug:
            logger.info(f'hdemucs out shape: {x.shape}')

        x = x[..., :int(length * self.scale)]

        if self.debug:
            logger.info(f'hdemucs out - trimmed shape: {x.shape}')

        if return_spec:
            if return_lr_spec:
                return x, x_spec_complex, z
            else:
                return x, x_spec_complex

        return x
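
A minimal shape sketch (untrained weights, CPU, the full src tree from this commit importable; constructor defaults plus the stereo/44.1 kHz settings from the config above). With lr_sr == hr_sr the scale is 1, so the output length should match the input:

# Sketch: AERO maps [B, C, T] waveforms to [B, C, int(T * hr_sr / lr_sr)] via STFT -> U-net -> ISTFT.
import torch
from src.models.aero import Aero

model = Aero(in_channels=2, out_channels=2, nfft=512, hop_length=256,
             lr_sr=44100, hr_sr=44100)
model.eval()
with torch.no_grad():
    y = model(torch.randn(1, 2, 44100))  # one second of stereo noise
print(y.shape)  # expected: torch.Size([1, 2, 44100])
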
src/models/modelFactory.py
ADDED
@@ -0,0 +1,15 @@
from src.models.aero import Aero
from src.models.seanet import Seanet
from yaml import safe_load

def get_model(model_name="aero"):
    if model_name == 'aero':
        with open("conf/experiment/aero_441-441_512_256.yaml") as f:
            generator = Aero(**safe_load(f)["aero"])
    elif model_name == 'seanet':
        with open("conf/experiment/seanet_441-441.yaml") as f:
            generator = Seanet(**safe_load(f)["seanet"])

    models = {'generator': generator}

    return models
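
A minimal usage sketch (assuming the working directory is the repo root, so the relative conf/ paths resolve):

# Sketch: build the untrained generator exactly as processAudio._load_model does.
from src.models import modelFactory

generator = modelFactory.get_model("aero")["generator"]
print(sum(p.numel() for p in generator.parameters()), "parameters")

Note that any other model_name falls through both branches, so referencing generator then raises UnboundLocalError; callers only ever pass "aero" or "seanet".
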
src/models/modules.py
ADDED
@@ -0,0 +1,327 @@
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm

import math

from src.models.snake import Snake
from src.models.utils import unfold

import typing as tp

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be split into overlapping
    chunks and the LSTM applied separately on each chunk.
    """

    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x


class LocalState(nn.Module):
    """Local state allows attention based only on data (no positional embedding),
    while setting a constraint on the time window (e.g. decaying penalty term).
    Also a failed experiment with trying to provide some frequency-based attention.
    """

    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        # self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2] ** 0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            tmp = torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
            dots += tmp
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)

        result = result.reshape(B, -1, T)
        return x + self.proj(result)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales residual outputs diagonally, close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x


class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """

    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, time_attn=False, heads=4, ndecay=4, lstm=False,
                 act_func='gelu', freq_dim=None, reshape=False,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for LayerScale.
            norm: use GroupNorm.
            time_attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            act_func: activation to use: 'gelu', 'snake', or 'relu'.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """

        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        self.time_attn = time_attn
        self.lstm = lstm
        self.reshape = reshape
        self.act_func = act_func
        self.freq_dim = freq_dim

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        self.hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if act_func == 'gelu':
            act = nn.GELU
        elif act_func == 'snake':
            act = Snake
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            layer = nn.ModuleDict()
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            conv1 = nn.ModuleList([nn.Conv1d(channels, self.hidden, kernel, dilation=dilation, padding=padding),
                                   norm_fn(self.hidden)])
            act_layer = act(freq_dim) if act_func == 'snake' else act()
            conv2 = nn.ModuleList([nn.Conv1d(self.hidden, 2 * channels, 1),
                                   norm_fn(2 * channels), nn.GLU(1),
                                   LayerScale(channels, init)])

            layer.update({'conv1': nn.Sequential(*conv1), 'act': act_layer, 'conv2': nn.Sequential(*conv2)})
            if lstm:
                layer.update({'lstm': BLSTM(self.hidden, layers=2, max_steps=200, skip=True)})
            if time_attn:
                layer.update({'time_attn': LocalState(self.hidden, heads=heads, ndecay=ndecay)})

            self.layers.append(layer)

    def forward(self, x):

        if self.reshape:
            B, C, Fr, T = x.shape
            x = x.permute(0, 2, 1, 3).reshape(-1, C, T)

        for layer in self.layers:
            skip = x

            x = layer['conv1'](x)

            if self.act_func == 'snake' and self.reshape:
                x = x.view(B, Fr, self.hidden, T).permute(0, 2, 3, 1)
            x = layer['act'](x)
            if self.act_func == 'snake' and self.reshape:
                x = x.permute(0, 3, 1, 2).reshape(-1, self.hidden, T)

            if self.lstm:
                x = layer['lstm'](x)
            if self.time_attn:
                x = layer['time_attn'](x)

            x = layer['conv2'](x)
            x = skip + x

        if self.reshape:
            x = x.view(B, Fr, C, T).permute(0, 2, 1, 3)

        return x


class ScaledEmbedding(nn.Module):
    """
    Boost learning rate for embeddings (with `scale`).
    Also, can make embeddings continuous with `smooth`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussians, the overall scale grows as sqrt(n), so we normalize by that.
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        return self.embedding.weight * self.scale

    def forward(self, x):
        out = self.embedding(x) * self.scale
        return out


class FTB(nn.Module):

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        super(FTB, self).__init__()
        self.input_dim = input_dim
        self.in_channel = in_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel),
            nn.ReLU()
        )

        self.conv1d = nn.Sequential(
            nn.Conv1d(r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel),
            nn.ReLU()
        )
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel),
            nn.ReLU()
        )

    def forward(self, inputs):
        '''
        inputs should be [Batch, Ca, Dim, Time]
        '''
        # T-F attention
        conv1_out = self.conv1(inputs)
        B, C, D, T = conv1_out.size()
        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
        conv1d_out = self.conv1d(reshape1_out)
        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])

        # now is also [B,C,D,T]
        att_out = conv1d_out * inputs

        # transpose to [B,C,T,D]
        att_out = torch.transpose(att_out, 2, 3)
        freqfc_out = self.freq_fc(att_out)
        att_out = torch.transpose(freqfc_out, 2, 3)

        cat_out = torch.cat([att_out, inputs], 1)
        outputs = self.conv2(cat_out)
        return outputs
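
A minimal shape sketch of the attention block above (channels must be divisible by heads):

# Sketch: LocalState is a shape-preserving self-attention over time, [B, C, T] -> [B, C, T].
import torch
from src.models.modules import LocalState

attn = LocalState(channels=32, heads=4, ndecay=4)
x = torch.randn(2, 32, 128)
print(attn(x).shape)  # torch.Size([2, 32, 128])
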
src/models/seanet.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn
import math
from src.models.utils import capture_init, weights_init
from src.models.modules import WNConv1d, WNConvTranspose1d
from torchaudio.functional import resample
from torch.nn import functional as F


class ResnetBlock(nn.Module):
    def __init__(self, dim, dilation=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(dilation),
            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
            nn.LeakyReLU(0.2),
            WNConv1d(dim, dim, kernel_size=1),
        )
        self.shortcut = WNConv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)


class Seanet(nn.Module):

    @capture_init
    def __init__(self,
                 latent_space_size=128,
                 ngf=32, n_residual_layers=3,
                 resample=1,
                 normalize=True,
                 floor=1e-3,
                 ratios=[8, 8, 2, 2],
                 in_channels=1,
                 out_channels=1,
                 lr_sr=16000,
                 hr_sr=16000,
                 upsample=True):
        super().__init__()

        self.resample = resample
        self.normalize = normalize
        self.floor = floor
        self.lr_sr = lr_sr
        self.hr_sr = hr_sr
        self.scale_factor = int(self.hr_sr / self.lr_sr)
        self.upsample = upsample

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.ratios = ratios
        mult = int(2 ** len(ratios))

        # Innermost layers: project between the conv feature maps and the latent space.
        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(latent_space_size, mult * ngf, kernel_size=7, padding=0),
        ]

        encoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(mult * ngf, latent_space_size, kernel_size=7, padding=0)
        ]

        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        # One strided conv (encoder) / transposed conv (decoder) block per ratio,
        # each paired with a stack of dilated residual layers.
        for i, r in enumerate(ratios):
            encoder_block = [
                nn.LeakyReLU(0.2),
                WNConv1d(mult * ngf // 2,
                         mult * ngf,
                         kernel_size=r * 2,
                         stride=r,
                         padding=r // 2 + r % 2,
                         ),
            ]

            decoder_block = [
                nn.LeakyReLU(0.2),
                WNConvTranspose1d(
                    mult * ngf,
                    mult * ngf // 2,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2,
                    output_padding=r % 2,
                ),
            ]

            for j in range(n_residual_layers - 1, -1, -1):
                encoder_block = [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] + encoder_block

            for j in range(n_residual_layers):
                decoder_block += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]

            mult //= 2

            self.encoder.insert(0, nn.Sequential(*encoder_block))
            self.decoder.append(nn.Sequential(*decoder_block))

        # Outermost layers: map between audio channels and the first feature map.
        encoder_wrapper_conv_layer = [
            nn.ReflectionPad1d(3),
            WNConv1d(in_channels, ngf, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))

        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(ngf, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        self.apply(weights_init)

    def estimate_output_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        no time steps are left over in a convolution, i.e. so that for all
        layers, (size of the input - kernel_size) % stride == 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        depth = len(self.ratios)
        for idx in range(depth - 1, -1, -1):
            stride = self.ratios[idx]
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            length = math.ceil((length - kernel_size + 2 * padding) / stride) + 1
            length = max(length, 1)
        for idx in range(depth):
            stride = self.ratios[idx]
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            output_padding = stride % 2
            length = (length - 1) * stride + kernel_size - 2 * padding + output_padding
        return int(length)

    def pad_to_valid_length(self, signal):
        valid_length = self.estimate_output_length(signal.shape[-1])
        padding_len = valid_length - signal.shape[-1]
        signal = F.pad(signal, (0, padding_len))
        return signal, padding_len

    def forward(self, signal):
        target_len = signal.shape[-1]
        if self.upsample:
            target_len *= self.scale_factor
        if self.normalize:
            mono = signal.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            signal = signal / (self.floor + std)
        else:
            std = 1
        x = signal
        if self.upsample:
            x = resample(x, self.lr_sr, self.hr_sr)

        x, padding_len = self.pad_to_valid_length(x)
        # U-Net style skips: each encoder input is added to the matching decoder output.
        skips = []
        for i, encode in enumerate(self.encoder):
            skips.append(x)
            x = encode(x)
        for j, decode in enumerate(self.decoder):
            x = decode(x)
            skip = skips.pop(-1)
            x = x + skip
        if target_len < x.shape[-1]:
            x = x[..., :target_len]
        return std * x
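A minimal smoke test for the model above, as a sketch: the shapes, sample rates, and random input are illustrative only, and it assumes WNConv1d/WNConvTranspose1d from src/models/modules.py behave like standard weight-normalized convolutions.

# Hypothetical smoke test (not part of the commit): run random mono audio
# through Seanet and check the valid-length padding arithmetic.
import torch
from src.models.seanet import Seanet

model = Seanet(in_channels=1, out_channels=1, lr_sr=16000, hr_sr=16000)
signal = torch.randn(4, 1, 16000)        # [batch, channels, samples]
with torch.no_grad():
    out = model(signal)
print(out.shape)                         # trimmed back to [4, 1, 16000]

# pad_to_valid_length pads so every strided conv divides the signal evenly.
padded, pad_len = model.pad_to_valid_length(signal)
assert padded.shape[-1] == model.estimate_output_length(signal.shape[-1])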
src/models/snake.py
ADDED
@@ -0,0 +1,66 @@
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
from torch.distributions.exponential import Exponential

class Snake(nn.Module):
    r'''
    Implementation of the serpentine-like sine-based periodic activation function:
    .. math::
         Snake_a := x + \frac{1}{a} \sin^2(ax) = x - \frac{1}{2a}\cos(2ax) + \frac{1}{2a}
    This activation function is able to better extrapolate to previously unseen data,
    especially in the case of learning periodic functions.

    Shape:
        - Input: (N, *) where * means any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    Parameters:
        - a - trainable parameter

    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, a=None, trainable=True):
        '''
        Initialization.
        Args:
            in_features: shape of the input
            a: trainable parameter
            trainable: sets `a` as a trainable parameter

            Higher values of `a` give a higher-frequency activation; 5-50 is a
            good starting point if you already think your data is periodic,
            and something lower, e.g. 0.5, if you think not. If `a` is omitted,
            it is drawn at random (a mix of frequencies), and either way it is
            trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features if isinstance(in_features, list) else [in_features]

        # Initialize `a`
        if a is not None:
            self.a = Parameter(torch.ones(self.in_features) * a)  # create a tensor out of alpha
        else:
            m = Exponential(torch.tensor([0.1]))
            self.a = Parameter((m.rsample(self.in_features)).squeeze())  # random init = mix of frequencies

        # Was `self.a.requiresGrad`, which silently set an unused attribute.
        self.a.requires_grad = trainable

    def extra_repr(self) -> str:
        return 'in_features={}'.format(self.in_features)

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + (1/a) * sin^2(a * x)
        '''
        return x + (1.0 / self.a) * pow(sin(x * self.a), 2)
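A quick sketch of the activation in use; the tiny network and shapes below are illustrative only. Note that `a` has the shape of the last input dimension, so it broadcasts naturally after a Linear layer.

# Illustrative: Snake after a Linear layer, plus an elementwise check
# against the formula in the docstring above.
import torch
from torch import nn
from src.models.snake import Snake

net = nn.Sequential(nn.Linear(1, 64), Snake(64), nn.Linear(64, 1))
x = torch.linspace(-3, 3, 100).unsqueeze(-1)   # [100, 1]
y = net(x)

act = Snake(64)
z = torch.randn(8, 64)
assert torch.allclose(act(z), z + torch.sin(z * act.a) ** 2 / act.a)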
src/models/spec.py
ADDED
@@ -0,0 +1,39 @@
"""Convenience wrappers to perform STFT and iSTFT.

This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
"""

import torch as th


def spectro(x, n_fft=512, hop_length=None, pad=0, win_length=None):
    *other, length = x.shape
    x = x.reshape(-1, length)
    # Default the window length before building the window, so hann_window never sees None.
    win_length = win_length or n_fft
    z = th.stft(x,
                n_fft * (1 + pad),
                hop_length or n_fft // 4,
                window=th.hann_window(win_length).to(x),
                win_length=win_length,
                normalized=True,
                center=True,
                return_complex=True,
                pad_mode='reflect')
    _, freqs, frame = z.shape
    return z.view(*other, freqs, frame)


def ispectro(z, hop_length=None, length=None, pad=0, win_length=None):
    *other, freqs, frames = z.shape
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = win_length or n_fft // (1 + pad)
    x = th.istft(z,
                 n_fft,
                 hop_length or n_fft // 2,
                 window=th.hann_window(win_length).to(z.real),
                 win_length=win_length,
                 normalized=True,
                 length=length,
                 center=True)
    _, length = x.shape
    return x.view(*other, length)
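A round-trip sketch for the two helpers; the shapes and STFT parameters are illustrative:

# Illustrative: spectro followed by ispectro reconstructs the waveform
# (up to float error) when the original length is passed back in.
import torch
from src.models.spec import spectro, ispectro

x = torch.randn(2, 1, 16000)                     # [batch, channels, samples]
z = spectro(x, n_fft=512, hop_length=128, win_length=512)
print(z.shape)                                   # [2, 1, 257, 126], complex
x_hat = ispectro(z, hop_length=128, length=x.shape[-1], win_length=512)
assert torch.allclose(x, x_hat, atol=1e-4)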
src/models/utils.py
ADDED
@@ -0,0 +1,50 @@
import functools
import math

from torch.nn import functional as F
import torchaudio


def capture_init(init):
    """capture_init.

    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_args_kwargs`
    """

    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / stride)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def write(wav, filename, sr):
    # Scale down only when the peak exceeds 1, to prevent clipping
    wav = wav / max(wav.abs().max().item(), 1)
    torchaudio.save(filename, wav.cpu(), sr)
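A short sketch of capture_init and unfold in use; the toy class and values are hypothetical:

# Illustrative: capture_init records constructor arguments (handy for
# re-building a model from a checkpoint); unfold extracts overlapping
# frames via as_strided without copying.
import torch
from src.models.utils import capture_init, unfold

class Toy(torch.nn.Module):
    @capture_init
    def __init__(self, depth=4):
        super().__init__()
        self.depth = depth

t = Toy(depth=2)
print(t._init_args_kwargs)     # ((), {'depth': 2})

a = torch.arange(10.)
frames = unfold(a, kernel_size=4, stride=2)
print(frames.shape)            # torch.Size([5, 4]) -- F = ceil(10 / 2) frames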
src/utils.py
ADDED
@@ -0,0 +1,52 @@
import functools
import math

from torch.nn import functional as F

from contextlib import contextmanager


def capture_init(init):
    """capture_init.

    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_args_kwargs`
    """

    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / stride)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def colorize(text, color):
    """
    Display text with some ANSI color in the terminal.
    """
    code = f"\033[{color}m"
    restore = "\033[0m"
    return "".join([code, text, restore])


def bold(text):
    """
    Display text in bold in the terminal.
    """
    return colorize(text, "1")
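A tiny usage sketch for the terminal helpers; the messages are illustrative:

# Illustrative: wrap status messages in ANSI escape codes.
from src.utils import bold, colorize

print(bold("Epoch 3 complete"))
print(colorize("loss improved", "32"))   # "32" = ANSI green foreground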