Initial commit of Radio Upscaling UI (minus models)
Browse files- app.py +54 -0
- conf/experiment/aero_441-441_512_256.yaml +77 -0
- conf/experiment/seanet_441-441.yaml +46 -0
- html/directions.html +6 -0
- html/information.html +16 -0
- processAudio.py +87 -0
- requirements.txt +4 -0
- src/models/aero.py +524 -0
- src/models/modelFactory.py +15 -0
- src/models/modules.py +327 -0
- src/models/seanet.py +177 -0
- src/models/snake.py +66 -0
- src/models/spec.py +39 -0
- src/models/utils.py +50 -0
- src/utils.py +52 -0
app.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from torchaudio.functional import resample
|
5 |
+
|
6 |
+
from processAudio import upscaleAudio
|
7 |
+
|
8 |
+
class Object(object):
|
9 |
+
pass
|
10 |
+
|
11 |
+
def processAudio(model: gr.Dropdown, audioData: gr.Audio):
|
12 |
+
if len(audioData[1].shape) == 1: #Convert mono to stereo
|
13 |
+
lrAudio = torch.tensor(np.array([
|
14 |
+
audioData[1].copy().astype(np.float32)/32768,
|
15 |
+
audioData[1].copy().astype(np.float32)/32768
|
16 |
+
]))
|
17 |
+
elif audioData[1].shape[1] > 2:
|
18 |
+
raise gr.Error("Audio with more than 2 channels is not supported.")
|
19 |
+
else: #re-order channel data from [samples, 2] to [2, samples]
|
20 |
+
lrAudio = torch.tensor(audioData[1].copy().astype(np.float32)/32768).transpose(0,1)
|
21 |
+
if audioData[0] != 44100:
|
22 |
+
lrAudio = resample(lrAudio, audioData[0], 44100)
|
23 |
+
hrAudio=upscaleAudio(lrAudio, "models/" + model)
|
24 |
+
hrAudio=hrAudio / max(hrAudio.abs().max().item(), 1)
|
25 |
+
outAudio=(hrAudio*32767).numpy().astype(np.int16).transpose(1,0)
|
26 |
+
return tuple([44100, outAudio])
|
27 |
+
|
28 |
+
with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as layout:
|
29 |
+
with gr.Row():
|
30 |
+
gr.Markdown("<h2>Broadcast Audio Upscaler</h2>")
|
31 |
+
with gr.Row():
|
32 |
+
with open("html/directions.html", "r") as directionsHtml:
|
33 |
+
gr.Markdown(directionsHtml.read())
|
34 |
+
with gr.Row():
|
35 |
+
modelSelect = gr.Dropdown(
|
36 |
+
[
|
37 |
+
["FM Radio Super Resolution","FM_Radio_SR.th"],
|
38 |
+
["AM Radio Super Resolution (Beta)","AM_Radio_SR.th"]
|
39 |
+
],
|
40 |
+
label="Select Model:",
|
41 |
+
value="FM_Radio_SR.th",
|
42 |
+
)
|
43 |
+
with gr.Row():
|
44 |
+
with gr.Column():
|
45 |
+
audioFileSelect = gr.Audio(label="Audio File (Mono or Stereo):",sources="upload")
|
46 |
+
with gr.Column():
|
47 |
+
audioOutput = gr.Audio(show_download_button=True, label="Restored Audio:", sources=[])
|
48 |
+
with gr.Row():
|
49 |
+
submit = gr.Button("Process Audio", variant="primary").click(fn=processAudio, inputs=[modelSelect, audioFileSelect], outputs=audioOutput)
|
50 |
+
with gr.Row():
|
51 |
+
with gr.Accordion("More Information:", open=False):
|
52 |
+
with open("html/information.html", "r") as informationHtml:
|
53 |
+
gr.Markdown(informationHtml.read())
|
54 |
+
layout.launch()
|
conf/experiment/aero_441-441_512_256.yaml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package experiment
|
2 |
+
name: aero-nfft=${experiment.nfft}-hl=${experiment.hop_length}
|
3 |
+
|
4 |
+
# Dataset related
|
5 |
+
lr_sr: 44100 # low resolution sample rate, added to support BWE. Should be included in training cfg
|
6 |
+
hr_sr: 44100 # high resolution sample rate. Should be included in training cfg
|
7 |
+
segment: 2
|
8 |
+
stride: 4 # in seconds, how much to stride between training examples
|
9 |
+
pad: true # if training sample is too short, pad it
|
10 |
+
upsample: false
|
11 |
+
batch_size: 6
|
12 |
+
nfft: 512
|
13 |
+
hop_length: 256
|
14 |
+
|
15 |
+
# models related
|
16 |
+
model: aero
|
17 |
+
aero: # see aero.py for a detailed description
|
18 |
+
in_channels: 2
|
19 |
+
out_channels: 2
|
20 |
+
# Channels
|
21 |
+
channels: 64 #Small: 48 Large: 64
|
22 |
+
growth: 2
|
23 |
+
# STFT
|
24 |
+
nfft: 512
|
25 |
+
hop_length: 256
|
26 |
+
end_iters: 0
|
27 |
+
cac: true
|
28 |
+
# Main structure
|
29 |
+
rewrite: true
|
30 |
+
hybrid: false
|
31 |
+
hybrid_old: false
|
32 |
+
# Frequency Branch
|
33 |
+
freq_emb: 0.2
|
34 |
+
emb_scale: 10
|
35 |
+
emb_smooth: true
|
36 |
+
# Convolutions
|
37 |
+
kernel_size: 8
|
38 |
+
strides: [ 4,4,2,2 ]
|
39 |
+
context: 1
|
40 |
+
context_enc: 0
|
41 |
+
freq_ends: 4
|
42 |
+
enc_freq_attn: 0
|
43 |
+
# normalization
|
44 |
+
norm_starts: 2
|
45 |
+
norm_groups: 4
|
46 |
+
# DConv residual branch
|
47 |
+
dconv_mode: 1
|
48 |
+
dconv_depth: 2
|
49 |
+
dconv_comp: 4
|
50 |
+
dconv_time_attn: 2
|
51 |
+
dconv_lstm: 2
|
52 |
+
dconv_init: 0.001
|
53 |
+
# Weight init
|
54 |
+
rescale: 0.1
|
55 |
+
lr_sr: 44100
|
56 |
+
hr_sr: 44100
|
57 |
+
spec_upsample: true
|
58 |
+
act_func: snake
|
59 |
+
debug: false
|
60 |
+
|
61 |
+
adversarial: True
|
62 |
+
features_loss_lambda: 100
|
63 |
+
only_features_loss: False
|
64 |
+
only_adversarial_loss: False
|
65 |
+
discriminator_models: [ msd, mpd ] #msd_melgan/msd_hifi/mpd/hifi
|
66 |
+
melgan_discriminator:
|
67 |
+
n_layers: 5
|
68 |
+
num_D: 3
|
69 |
+
downsampling_factor: 1
|
70 |
+
ndf: 16
|
71 |
+
channels: 2
|
72 |
+
msd:
|
73 |
+
hidden: 32 #Small: 16 Large: 32
|
74 |
+
channels: 2
|
75 |
+
mpd:
|
76 |
+
hidden: 32 #Small: 16 Large: 32
|
77 |
+
channels: 2
|
conf/experiment/seanet_441-441.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package experiment
|
2 |
+
name: seanet-nfft=${experiment.nfft}-hl=${experiment.hop_length}
|
3 |
+
|
4 |
+
# Dataset related
|
5 |
+
lr_sr: 44100 # low resolution sample rate, added to support BWE. Should be included in training cfg
|
6 |
+
hr_sr: 44100 # high resolution sample rate. Should be included in training cfg
|
7 |
+
segment: 2
|
8 |
+
stride: 2 # in seconds, how much to stride between training examples
|
9 |
+
pad: true # if training sample is too short, pad it
|
10 |
+
upsample: false
|
11 |
+
batch_size: 5
|
12 |
+
nfft: 1024
|
13 |
+
hop_length: 512
|
14 |
+
|
15 |
+
# models related
|
16 |
+
model: seanet
|
17 |
+
seanet:
|
18 |
+
latent_space_size: 256
|
19 |
+
ngf: 32
|
20 |
+
n_residual_layers: 4
|
21 |
+
resample: 1
|
22 |
+
normalize: False
|
23 |
+
floor: 0.0005
|
24 |
+
ratios: [ 8,8,2,2 ]
|
25 |
+
lr_sr: 44100
|
26 |
+
hr_sr: 44100
|
27 |
+
in_channels: 2
|
28 |
+
out_channels: 2
|
29 |
+
|
30 |
+
adversarial: True
|
31 |
+
features_loss_lambda: 100
|
32 |
+
discriminator_models: [ msd, mpd ] #msd_melgan/msd_hifi/mpd/hifi
|
33 |
+
only_features_loss: False
|
34 |
+
only_adversarial_loss: False
|
35 |
+
msd:
|
36 |
+
hidden: 16
|
37 |
+
channels: 2
|
38 |
+
mpd:
|
39 |
+
hidden: 16
|
40 |
+
channels: 2
|
41 |
+
melgan_discriminator:
|
42 |
+
n_layers: 5
|
43 |
+
num_D: 3
|
44 |
+
downsampling_factor: 1
|
45 |
+
ndf: 16
|
46 |
+
channels: 2
|
html/directions.html
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h3>Directions:</h3>
|
2 |
+
<ol>
|
3 |
+
<li>Select a model from the model drop-down.</li>
|
4 |
+
<li>Select an audio file.</li>
|
5 |
+
<li>After the file loads, press "Process Audio" to process the file.</li>
|
6 |
+
</ol>
|
html/information.html
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p>This tool is based on the "aero" audio upscaling project (<a href="https://github.com/slp-rl/aero">original version</a>, <a href="https://github.com/pokepress/aero">modified version used here</a>). It takes digitized radio recordings and attempts to reduce noise/interference while recreating lost frequencies.</p>
|
2 |
+
|
3 |
+
<B>The FM model should be able to:</B>
|
4 |
+
<ul>
|
5 |
+
<li>Reduce stereo hiss</li>
|
6 |
+
<li>Reduce certain kinds of static/interference</li>
|
7 |
+
<li>Reduce 19khz stereo pilot</li>
|
8 |
+
<li>Reduce overly warm bass frequencies (boost them 3 or so dB if you miss them)</li>
|
9 |
+
<li>Restore lost high & low frequencies (generally only up to 16khz since most training data is MP3-based)</li>
|
10 |
+
</ul>
|
11 |
+
|
12 |
+
<B>The AM model should be able to:</B>
|
13 |
+
<ul>
|
14 |
+
<li>Reduce certain kinds of static/interference</li>
|
15 |
+
<li>Restore lost high & low frequencies (capacity tends to diminish after 10khz)</li>
|
16 |
+
</ul>
|
processAudio.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
|
4 |
+
import time
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import logging
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import torchaudio
|
11 |
+
from torchaudio.functional import resample
|
12 |
+
from gradio import Progress
|
13 |
+
|
14 |
+
from src.models import modelFactory
|
15 |
+
from src.utils import bold
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
SEGMENT_DURATION_SEC = 5
|
20 |
+
SEGMENT_OVERLAP_SAMPLES = 2048
|
21 |
+
SERIALIZE_KEY_STATE = 'state'
|
22 |
+
|
23 |
+
def _load_model(checkpoint_file="models/FM_Radio_SR.th",model_name="aero"):
|
24 |
+
checkpoint_file = Path(checkpoint_file)
|
25 |
+
model = modelFactory.get_model(model_name)['generator']
|
26 |
+
package = torch.load(checkpoint_file, 'cpu')
|
27 |
+
if 'state' in package.keys(): #raw model file
|
28 |
+
logger.info(bold(f'Loading model {model_name} from file.'))
|
29 |
+
model.load_state_dict(package[SERIALIZE_KEY_STATE])
|
30 |
+
|
31 |
+
return model
|
32 |
+
|
33 |
+
def crossfade_and_blend(out_clip, in_clip):
|
34 |
+
fade_out = torchaudio.transforms.Fade(0,SEGMENT_OVERLAP_SAMPLES)
|
35 |
+
fade_in = torchaudio.transforms.Fade(SEGMENT_OVERLAP_SAMPLES, 0)
|
36 |
+
return fade_out(out_clip) + fade_in(in_clip)
|
37 |
+
|
38 |
+
def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):
|
39 |
+
|
40 |
+
model = _load_model(checkpoint_file,model_name)
|
41 |
+
device = torch.device('cpu')
|
42 |
+
#model.cuda()
|
43 |
+
|
44 |
+
logger.info(f'lr wav shape: {lr_sig.shape}')
|
45 |
+
|
46 |
+
segment_duration_samples = sr * SEGMENT_DURATION_SEC
|
47 |
+
n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
|
48 |
+
logger.info(f'number of chunks: {n_chunks}')
|
49 |
+
|
50 |
+
lr_chunks = []
|
51 |
+
for i in range(n_chunks):
|
52 |
+
start = i * segment_duration_samples
|
53 |
+
end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
|
54 |
+
lr_chunks.append(lr_sig[:, start:end])
|
55 |
+
|
56 |
+
pr_chunks = []
|
57 |
+
|
58 |
+
lr_segment_overlap_samples = int((sr/hr_sr) * SEGMENT_OVERLAP_SAMPLES)
|
59 |
+
|
60 |
+
model.eval()
|
61 |
+
pred_start = time.time()
|
62 |
+
with torch.no_grad():
|
63 |
+
previous_chunk = None
|
64 |
+
progress(0)
|
65 |
+
for i, lr_chunk in enumerate(lr_chunks):
|
66 |
+
progress(i/n_chunks)
|
67 |
+
pr_chunk = None
|
68 |
+
if previous_chunk is not None:
|
69 |
+
combined_chunk = torch.cat((previous_chunk[...,-lr_segment_overlap_samples:], lr_chunk), 1)
|
70 |
+
pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
|
71 |
+
pr_chunk = pr_combined_chunk[...,SEGMENT_OVERLAP_SAMPLES:]
|
72 |
+
pr_chunks[-1][...,-SEGMENT_OVERLAP_SAMPLES:] = crossfade_and_blend(pr_chunks[-1][...,-SEGMENT_OVERLAP_SAMPLES:], pr_combined_chunk.cpu()[...,:SEGMENT_OVERLAP_SAMPLES] )
|
73 |
+
else:
|
74 |
+
pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
|
75 |
+
logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
|
76 |
+
logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
|
77 |
+
pr_chunks.append(pr_chunk.cpu())
|
78 |
+
previous_chunk = lr_chunk
|
79 |
+
|
80 |
+
pred_duration = time.time() - pred_start
|
81 |
+
logger.info(f'prediction duration: {pred_duration}')
|
82 |
+
|
83 |
+
pr = torch.concat(pr_chunks, dim=-1)
|
84 |
+
|
85 |
+
logger.info(f'pr wav shape: {pr.shape}')
|
86 |
+
|
87 |
+
return pr
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.2.1
|
2 |
+
torchvision=0.17.2
|
3 |
+
torchaudio==2.2.1
|
4 |
+
opencv-python
|
src/models/aero.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
|
3 |
+
"""
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from src.models.utils import capture_init
|
10 |
+
from src.models.spec import spectro, ispectro
|
11 |
+
from src.models.modules import DConv, ScaledEmbedding, FTB
|
12 |
+
|
13 |
+
import logging
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def rescale_conv(conv, reference):
|
18 |
+
std = conv.weight.std().detach()
|
19 |
+
scale = (std / reference) ** 0.5
|
20 |
+
conv.weight.data /= scale
|
21 |
+
if conv.bias is not None:
|
22 |
+
conv.bias.data /= scale
|
23 |
+
|
24 |
+
|
25 |
+
def rescale_module(module, reference):
|
26 |
+
for sub in module.modules():
|
27 |
+
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
|
28 |
+
rescale_conv(sub, reference)
|
29 |
+
|
30 |
+
|
31 |
+
class HEncLayer(nn.Module):
|
32 |
+
def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
|
33 |
+
freq=True, dconv=True, is_first=False, freq_attn=False, freq_dim=None, norm=True, context=0,
|
34 |
+
dconv_kw={}, pad=True,
|
35 |
+
rewrite=True):
|
36 |
+
"""Encoder layer. This used both by the time and the frequency branch.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
chin: number of input channels.
|
40 |
+
chout: number of output channels.
|
41 |
+
norm_groups: number of groups for group norm.
|
42 |
+
empty: used to make a layer with just the first conv. this is used
|
43 |
+
before merging the time and freq. branches.
|
44 |
+
freq: this is acting on frequencies.
|
45 |
+
dconv: insert DConv residual branches.
|
46 |
+
norm: use GroupNorm.
|
47 |
+
context: context size for the 1x1 conv.
|
48 |
+
dconv_kw: list of kwargs for the DConv class.
|
49 |
+
pad: pad the input. Padding is done so that the output size is
|
50 |
+
always the input size / stride.
|
51 |
+
rewrite: add 1x1 conv at the end of the layer.
|
52 |
+
"""
|
53 |
+
super().__init__()
|
54 |
+
norm_fn = lambda d: nn.Identity() # noqa
|
55 |
+
if norm:
|
56 |
+
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
|
57 |
+
if stride == 1 and kernel_size % 2 == 0 and kernel_size > 1:
|
58 |
+
kernel_size -= 1
|
59 |
+
if pad:
|
60 |
+
pad = (kernel_size - stride) // 2
|
61 |
+
else:
|
62 |
+
pad = 0
|
63 |
+
klass = nn.Conv2d
|
64 |
+
self.chin = chin
|
65 |
+
self.chout = chout
|
66 |
+
self.freq = freq
|
67 |
+
self.kernel_size = kernel_size
|
68 |
+
self.stride = stride
|
69 |
+
self.empty = empty
|
70 |
+
self.freq_attn = freq_attn
|
71 |
+
self.freq_dim = freq_dim
|
72 |
+
self.norm = norm
|
73 |
+
self.pad = pad
|
74 |
+
if freq:
|
75 |
+
kernel_size = [kernel_size, 1]
|
76 |
+
stride = [stride, 1]
|
77 |
+
if pad != 0:
|
78 |
+
pad = [pad, 0]
|
79 |
+
# klass = nn.Conv2d
|
80 |
+
else:
|
81 |
+
kernel_size = [1, kernel_size]
|
82 |
+
stride = [1, stride]
|
83 |
+
if pad != 0:
|
84 |
+
pad = [0, pad]
|
85 |
+
|
86 |
+
self.is_first = is_first
|
87 |
+
|
88 |
+
if is_first:
|
89 |
+
self.pre_conv = nn.Conv2d(chin, chout, [1, 1])
|
90 |
+
chin = chout
|
91 |
+
|
92 |
+
if self.freq_attn:
|
93 |
+
self.freq_attn_block = FTB(input_dim=freq_dim, in_channel=chin)
|
94 |
+
|
95 |
+
self.conv = klass(chin, chout, kernel_size, stride, pad)
|
96 |
+
if self.empty:
|
97 |
+
return
|
98 |
+
self.norm1 = norm_fn(chout)
|
99 |
+
self.rewrite = None
|
100 |
+
if rewrite:
|
101 |
+
self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
|
102 |
+
self.norm2 = norm_fn(2 * chout)
|
103 |
+
|
104 |
+
self.dconv = None
|
105 |
+
if dconv:
|
106 |
+
self.dconv = DConv(chout, **dconv_kw)
|
107 |
+
|
108 |
+
def forward(self, x, inject=None):
|
109 |
+
"""
|
110 |
+
`inject` is used to inject the result from the time branch into the frequency branch,
|
111 |
+
when both have the same stride.
|
112 |
+
"""
|
113 |
+
|
114 |
+
if not self.freq:
|
115 |
+
le = x.shape[-1]
|
116 |
+
if not le % self.stride == 0:
|
117 |
+
x = F.pad(x, (0, self.stride - (le % self.stride)))
|
118 |
+
|
119 |
+
if self.is_first:
|
120 |
+
x = self.pre_conv(x)
|
121 |
+
|
122 |
+
if self.freq_attn:
|
123 |
+
x = self.freq_attn_block(x)
|
124 |
+
|
125 |
+
x = self.conv(x)
|
126 |
+
|
127 |
+
x = F.gelu(self.norm1(x))
|
128 |
+
if self.dconv:
|
129 |
+
x = self.dconv(x)
|
130 |
+
|
131 |
+
if self.rewrite:
|
132 |
+
x = self.norm2(self.rewrite(x))
|
133 |
+
x = F.glu(x, dim=1)
|
134 |
+
|
135 |
+
return x
|
136 |
+
|
137 |
+
|
138 |
+
class HDecLayer(nn.Module):
|
139 |
+
def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
|
140 |
+
freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
|
141 |
+
context_freq=True, rewrite=True):
|
142 |
+
"""
|
143 |
+
Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
|
144 |
+
"""
|
145 |
+
super().__init__()
|
146 |
+
norm_fn = lambda d: nn.Identity() # noqa
|
147 |
+
if norm:
|
148 |
+
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
|
149 |
+
if stride == 1 and kernel_size % 2 == 0 and kernel_size > 1:
|
150 |
+
kernel_size -= 1
|
151 |
+
if pad:
|
152 |
+
pad = (kernel_size - stride) // 2
|
153 |
+
else:
|
154 |
+
pad = 0
|
155 |
+
self.pad = pad
|
156 |
+
self.last = last
|
157 |
+
self.freq = freq
|
158 |
+
self.chin = chin
|
159 |
+
self.empty = empty
|
160 |
+
self.stride = stride
|
161 |
+
self.kernel_size = kernel_size
|
162 |
+
self.norm = norm
|
163 |
+
self.context_freq = context_freq
|
164 |
+
klass = nn.Conv2d
|
165 |
+
klass_tr = nn.ConvTranspose2d
|
166 |
+
if freq:
|
167 |
+
kernel_size = [kernel_size, 1]
|
168 |
+
stride = [stride, 1]
|
169 |
+
else:
|
170 |
+
kernel_size = [1, kernel_size]
|
171 |
+
stride = [1, stride]
|
172 |
+
self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
|
173 |
+
self.norm2 = norm_fn(chout)
|
174 |
+
if self.empty:
|
175 |
+
return
|
176 |
+
self.rewrite = None
|
177 |
+
if rewrite:
|
178 |
+
if context_freq:
|
179 |
+
self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
|
180 |
+
else:
|
181 |
+
self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
|
182 |
+
[0, context])
|
183 |
+
self.norm1 = norm_fn(2 * chin)
|
184 |
+
|
185 |
+
self.dconv = None
|
186 |
+
if dconv:
|
187 |
+
self.dconv = DConv(chin, **dconv_kw)
|
188 |
+
|
189 |
+
def forward(self, x, skip, length):
|
190 |
+
if self.freq and x.dim() == 3:
|
191 |
+
B, C, T = x.shape
|
192 |
+
x = x.view(B, self.chin, -1, T)
|
193 |
+
|
194 |
+
if not self.empty:
|
195 |
+
x = torch.cat([x, skip], dim=1)
|
196 |
+
|
197 |
+
if self.rewrite:
|
198 |
+
y = F.glu(self.norm1(self.rewrite(x)), dim=1)
|
199 |
+
else:
|
200 |
+
y = x
|
201 |
+
if self.dconv:
|
202 |
+
y = self.dconv(y)
|
203 |
+
else:
|
204 |
+
y = x
|
205 |
+
assert skip is None
|
206 |
+
z = self.norm2(self.conv_tr(y))
|
207 |
+
if self.freq:
|
208 |
+
if self.pad:
|
209 |
+
z = z[..., self.pad:-self.pad, :]
|
210 |
+
else:
|
211 |
+
z = z[..., self.pad:self.pad + length]
|
212 |
+
assert z.shape[-1] == length, (z.shape[-1], length)
|
213 |
+
if not self.last:
|
214 |
+
z = F.gelu(z)
|
215 |
+
return z
|
216 |
+
|
217 |
+
|
218 |
+
class Aero(nn.Module):
|
219 |
+
"""
|
220 |
+
Deep model for Audio Super Resolution.
|
221 |
+
"""
|
222 |
+
|
223 |
+
@capture_init
|
224 |
+
def __init__(self,
|
225 |
+
# Channels
|
226 |
+
in_channels=1,
|
227 |
+
out_channels=1,
|
228 |
+
audio_channels=2,
|
229 |
+
channels=48,
|
230 |
+
growth=2,
|
231 |
+
# STFT
|
232 |
+
nfft=512,
|
233 |
+
hop_length=64,
|
234 |
+
end_iters=0,
|
235 |
+
cac=True,
|
236 |
+
# Main structure
|
237 |
+
rewrite=True,
|
238 |
+
hybrid=False,
|
239 |
+
hybrid_old=False,
|
240 |
+
# Frequency branch
|
241 |
+
freq_emb=0.2,
|
242 |
+
emb_scale=10,
|
243 |
+
emb_smooth=True,
|
244 |
+
# Convolutions
|
245 |
+
kernel_size=8,
|
246 |
+
strides=[4, 4, 2, 2],
|
247 |
+
context=1,
|
248 |
+
context_enc=0,
|
249 |
+
freq_ends=4,
|
250 |
+
enc_freq_attn=4,
|
251 |
+
# Normalization
|
252 |
+
norm_starts=2,
|
253 |
+
norm_groups=4,
|
254 |
+
# DConv residual branch
|
255 |
+
dconv_mode=1,
|
256 |
+
dconv_depth=2,
|
257 |
+
dconv_comp=4,
|
258 |
+
dconv_time_attn=2,
|
259 |
+
dconv_lstm=2,
|
260 |
+
dconv_init=1e-3,
|
261 |
+
# Weight init
|
262 |
+
rescale=0.1,
|
263 |
+
# Metadata
|
264 |
+
lr_sr=4000,
|
265 |
+
hr_sr=16000,
|
266 |
+
spec_upsample=True,
|
267 |
+
act_func='snake',
|
268 |
+
debug=False):
|
269 |
+
"""
|
270 |
+
Args:
|
271 |
+
sources (list[str]): list of source names.
|
272 |
+
audio_channels (int): input/output audio channels.
|
273 |
+
channels (int): initial number of hidden channels.
|
274 |
+
growth: increase the number of hidden channels by this factor at each layer.
|
275 |
+
nfft: number of fft bins. Note that changing this require careful computation of
|
276 |
+
various shape parameters and will not work out of the box for hybrid models.
|
277 |
+
end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
|
278 |
+
cac: uses complex as channels, i.e. complex numbers are 2 channels each
|
279 |
+
in input and output. no further processing is done before ISTFT.
|
280 |
+
depth (int): number of layers in the encoder and in the decoder.
|
281 |
+
rewrite (bool): add 1x1 convolution to each layer.
|
282 |
+
hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
|
283 |
+
hybrid_old: some models trained for MDX had a padding bug. This replicates
|
284 |
+
this bug to avoid retraining them.
|
285 |
+
freq_emb: add frequency embedding after the first frequency layer if > 0,
|
286 |
+
the actual value controls the weight of the embedding.
|
287 |
+
emb_scale: equivalent to scaling the embedding learning rate
|
288 |
+
emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
|
289 |
+
kernel_size: kernel_size for encoder and decoder layers.
|
290 |
+
stride: stride for encoder and decoder layers.
|
291 |
+
context: context for 1x1 conv in the decoder.
|
292 |
+
context_enc: context for 1x1 conv in the encoder.
|
293 |
+
norm_starts: layer at which group norm starts being used.
|
294 |
+
decoder layers are numbered in reverse order.
|
295 |
+
norm_groups: number of groups for group norm.
|
296 |
+
dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
|
297 |
+
dconv_depth: depth of residual DConv branch.
|
298 |
+
dconv_comp: compression of DConv branch.
|
299 |
+
dconv_freq_attn: adds freq attention layers in DConv branch starting at this layer.
|
300 |
+
dconv_time_attn: adds time attention layers in DConv branch starting at this layer.
|
301 |
+
dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
|
302 |
+
dconv_init: initial scale for the DConv branch LayerScale.
|
303 |
+
rescale: weight recaling trick
|
304 |
+
lr_sr: source low-resolution sample-rate
|
305 |
+
hr_sr: target high-resolution sample-rate
|
306 |
+
spec_upsample: if true, upsamples in the spectral domain, otherwise performs sinc-interpolation beforehand
|
307 |
+
act_func: 'snake'/'relu'
|
308 |
+
debug: if true, prints out input dimensions throughout model layers.
|
309 |
+
"""
|
310 |
+
super().__init__()
|
311 |
+
self.cac = cac
|
312 |
+
self.in_channels = in_channels
|
313 |
+
self.out_channels = out_channels
|
314 |
+
self.audio_channels = audio_channels
|
315 |
+
self.kernel_size = kernel_size
|
316 |
+
self.context = context
|
317 |
+
self.strides = strides
|
318 |
+
self.depth = len(strides)
|
319 |
+
self.channels = channels
|
320 |
+
self.lr_sr = lr_sr
|
321 |
+
self.hr_sr = hr_sr
|
322 |
+
self.spec_upsample = spec_upsample
|
323 |
+
|
324 |
+
self.scale = hr_sr / lr_sr if self.spec_upsample else 1
|
325 |
+
|
326 |
+
self.nfft = nfft
|
327 |
+
self.hop_length = int(hop_length // self.scale) # this is for the input signal
|
328 |
+
self.win_length = int(self.nfft // self.scale) # this is for the input signal
|
329 |
+
self.end_iters = end_iters
|
330 |
+
self.freq_emb = None
|
331 |
+
self.hybrid = hybrid
|
332 |
+
self.hybrid_old = hybrid_old
|
333 |
+
self.debug = debug
|
334 |
+
|
335 |
+
self.encoder = nn.ModuleList()
|
336 |
+
self.decoder = nn.ModuleList()
|
337 |
+
|
338 |
+
chin_z = self.in_channels
|
339 |
+
if self.cac:
|
340 |
+
chin_z *= 2
|
341 |
+
chout_z = channels
|
342 |
+
freqs = nfft // 2
|
343 |
+
|
344 |
+
for index in range(self.depth):
|
345 |
+
freq_attn = index >= enc_freq_attn
|
346 |
+
lstm = index >= dconv_lstm
|
347 |
+
time_attn = index >= dconv_time_attn
|
348 |
+
norm = index >= norm_starts
|
349 |
+
freq = index <= freq_ends
|
350 |
+
stri = strides[index]
|
351 |
+
ker = kernel_size
|
352 |
+
|
353 |
+
pad = True
|
354 |
+
if freq and freqs < kernel_size:
|
355 |
+
ker = freqs
|
356 |
+
|
357 |
+
kw = {
|
358 |
+
'kernel_size': ker,
|
359 |
+
'stride': stri,
|
360 |
+
'freq': freq,
|
361 |
+
'pad': pad,
|
362 |
+
'norm': norm,
|
363 |
+
'rewrite': rewrite,
|
364 |
+
'norm_groups': norm_groups,
|
365 |
+
'dconv_kw': {
|
366 |
+
'lstm': lstm,
|
367 |
+
'time_attn': time_attn,
|
368 |
+
'depth': dconv_depth,
|
369 |
+
'compress': dconv_comp,
|
370 |
+
'init': dconv_init,
|
371 |
+
'act_func': act_func,
|
372 |
+
'reshape': True,
|
373 |
+
'freq_dim': freqs // strides[index] if freq else freqs
|
374 |
+
}
|
375 |
+
}
|
376 |
+
|
377 |
+
kw_dec = dict(kw)
|
378 |
+
|
379 |
+
enc = HEncLayer(chin_z, chout_z,
|
380 |
+
dconv=dconv_mode & 1, context=context_enc,
|
381 |
+
is_first=index == 0, freq_attn=freq_attn, freq_dim=freqs,
|
382 |
+
**kw)
|
383 |
+
|
384 |
+
self.encoder.append(enc)
|
385 |
+
if index == 0:
|
386 |
+
chin = self.out_channels
|
387 |
+
chin_z = chin
|
388 |
+
if self.cac:
|
389 |
+
chin_z *= 2
|
390 |
+
dec = HDecLayer(2 * chout_z, chin_z, dconv=dconv_mode & 2,
|
391 |
+
last=index == 0, context=context, **kw_dec)
|
392 |
+
|
393 |
+
self.decoder.insert(0, dec)
|
394 |
+
|
395 |
+
chin_z = chout_z
|
396 |
+
chout_z = int(growth * chout_z)
|
397 |
+
|
398 |
+
if freq:
|
399 |
+
freqs //= strides[index]
|
400 |
+
|
401 |
+
if index == 0 and freq_emb:
|
402 |
+
self.freq_emb = ScaledEmbedding(
|
403 |
+
freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
|
404 |
+
self.freq_emb_scale = freq_emb
|
405 |
+
|
406 |
+
if rescale:
|
407 |
+
rescale_module(self, reference=rescale)
|
408 |
+
|
409 |
+
def _spec(self, x, scale=False):
|
410 |
+
if np.mod(x.shape[-1], self.hop_length):
|
411 |
+
x = F.pad(x, (0, self.hop_length - np.mod(x.shape[-1], self.hop_length)))
|
412 |
+
hl = self.hop_length
|
413 |
+
nfft = self.nfft
|
414 |
+
win_length = self.win_length
|
415 |
+
|
416 |
+
if scale:
|
417 |
+
hl = int(hl * self.scale)
|
418 |
+
win_length = int(win_length * self.scale)
|
419 |
+
|
420 |
+
z = spectro(x, nfft, hl, win_length=win_length)[..., :-1, :]
|
421 |
+
return z
|
422 |
+
|
423 |
+
def _ispec(self, z):
|
424 |
+
hl = int(self.hop_length * self.scale)
|
425 |
+
win_length = int(self.win_length * self.scale)
|
426 |
+
z = F.pad(z, (0, 0, 0, 1))
|
427 |
+
x = ispectro(z, hl, win_length=win_length)
|
428 |
+
return x
|
429 |
+
|
430 |
+
def _move_complex_to_channels_dim(self, z):
|
431 |
+
B, C, Fr, T = z.shape
|
432 |
+
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
|
433 |
+
m = m.reshape(B, C * 2, Fr, T)
|
434 |
+
return m
|
435 |
+
|
436 |
+
def _convert_to_complex(self, x):
|
437 |
+
"""
|
438 |
+
|
439 |
+
:param x: signal of shape [Batch, Channels, 2, Freq, TimeFrames]
|
440 |
+
:return: complex signal of shape [Batch, Channels, Freq, TimeFrames]
|
441 |
+
"""
|
442 |
+
out = x.permute(0, 1, 3, 4, 2)
|
443 |
+
out = torch.view_as_complex(out.contiguous())
|
444 |
+
return out
|
445 |
+
|
446 |
+
def forward(self, mix, return_spec=False, return_lr_spec=False):
|
447 |
+
x = mix
|
448 |
+
length = x.shape[-1]
|
449 |
+
|
450 |
+
if self.debug:
|
451 |
+
logger.info(f'hdemucs in shape: {x.shape}')
|
452 |
+
|
453 |
+
z = self._spec(x)
|
454 |
+
x = self._move_complex_to_channels_dim(z)
|
455 |
+
|
456 |
+
if self.debug:
|
457 |
+
logger.info(f'x spec shape: {x.shape}')
|
458 |
+
|
459 |
+
B, C, Fq, T = x.shape
|
460 |
+
|
461 |
+
# unlike previous Demucs, we always normalize because it is easier.
|
462 |
+
mean = x.mean(dim=(1, 2, 3), keepdim=True)
|
463 |
+
std = x.std(dim=(1, 2, 3), keepdim=True)
|
464 |
+
x = (x - mean) / (1e-5 + std)
|
465 |
+
|
466 |
+
# okay, this is a giant mess I know...
|
467 |
+
saved = [] # skip connections, freq.
|
468 |
+
lengths = [] # saved lengths to properly remove padding, freq branch.
|
469 |
+
for idx, encode in enumerate(self.encoder):
|
470 |
+
lengths.append(x.shape[-1])
|
471 |
+
inject = None
|
472 |
+
x = encode(x, inject)
|
473 |
+
if self.debug:
|
474 |
+
logger.info(f'encoder {idx} out shape: {x.shape}')
|
475 |
+
if idx == 0 and self.freq_emb is not None:
|
476 |
+
# add frequency embedding to allow for non equivariant convolutions
|
477 |
+
# over the frequency axis.
|
478 |
+
frs = torch.arange(x.shape[-2], device=x.device)
|
479 |
+
emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
|
480 |
+
x = x + self.freq_emb_scale * emb
|
481 |
+
|
482 |
+
saved.append(x)
|
483 |
+
|
484 |
+
x = torch.zeros_like(x)
|
485 |
+
# initialize everything to zero (signal will go through u-net skips).
|
486 |
+
|
487 |
+
for idx, decode in enumerate(self.decoder):
|
488 |
+
skip = saved.pop(-1)
|
489 |
+
x = decode(x, skip, lengths.pop(-1))
|
490 |
+
|
491 |
+
if self.debug:
|
492 |
+
logger.info(f'decoder {idx} out shape: {x.shape}')
|
493 |
+
|
494 |
+
# Let's make sure we used all stored skip connections.
|
495 |
+
assert len(saved) == 0
|
496 |
+
|
497 |
+
x = x.view(B, self.out_channels, -1, Fq, T)
|
498 |
+
x = x * std[:, None] + mean[:, None]
|
499 |
+
|
500 |
+
if self.debug:
|
501 |
+
logger.info(f'post view shape: {x.shape}')
|
502 |
+
|
503 |
+
x_spec_complex = self._convert_to_complex(x)
|
504 |
+
|
505 |
+
if self.debug:
|
506 |
+
logger.info(f'x_spec_complex shape: {x_spec_complex.shape}')
|
507 |
+
|
508 |
+
x = self._ispec(x_spec_complex)
|
509 |
+
|
510 |
+
if self.debug:
|
511 |
+
logger.info(f'hdemucs out shape: {x.shape}')
|
512 |
+
|
513 |
+
x = x[..., :int(length * self.scale)]
|
514 |
+
|
515 |
+
if self.debug:
|
516 |
+
logger.info(f'hdemucs out - trimmed shape: {x.shape}')
|
517 |
+
|
518 |
+
if return_spec:
|
519 |
+
if return_lr_spec:
|
520 |
+
return x, x_spec_complex, z
|
521 |
+
else:
|
522 |
+
return x, x_spec_complex
|
523 |
+
|
524 |
+
return x
|
src/models/modelFactory.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.models.aero import Aero
|
2 |
+
from src.models.seanet import Seanet
|
3 |
+
from yaml import safe_load
|
4 |
+
|
5 |
+
def get_model(model_name="aero"):
|
6 |
+
if model_name == 'aero':
|
7 |
+
with open("conf/experiment/aero_441-441_512_256.yaml") as f:
|
8 |
+
generator = Aero(**safe_load(f)["aero"])
|
9 |
+
elif model_name == 'seanet':
|
10 |
+
with open("conf/experiment/seanet_441-441.yaml") as f:
|
11 |
+
generator = Seanet(**safe_load(f)["seanet"])
|
12 |
+
|
13 |
+
models = {'generator': generator}
|
14 |
+
|
15 |
+
return models
|
src/models/modules.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn.utils.parametrizations import weight_norm
|
4 |
+
|
5 |
+
import math
|
6 |
+
|
7 |
+
from src.models.snake import Snake
|
8 |
+
from src.models.utils import unfold
|
9 |
+
|
10 |
+
import typing as tp
|
11 |
+
|
12 |
+
def WNConv1d(*args, **kwargs):
|
13 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
14 |
+
|
15 |
+
|
16 |
+
def WNConvTranspose1d(*args, **kwargs):
|
17 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
18 |
+
|
19 |
+
class BLSTM(nn.Module):
|
20 |
+
"""
|
21 |
+
BiLSTM with same hidden units as input dim.
|
22 |
+
If `max_steps` is not None, input will be splitting in overlapping
|
23 |
+
chunks and the LSTM applied separately on each chunk.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, dim, layers=1, max_steps=None, skip=False):
|
27 |
+
super().__init__()
|
28 |
+
assert max_steps is None or max_steps % 4 == 0
|
29 |
+
self.max_steps = max_steps
|
30 |
+
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
|
31 |
+
self.linear = nn.Linear(2 * dim, dim)
|
32 |
+
self.skip = skip
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
B, C, T = x.shape
|
36 |
+
y = x
|
37 |
+
framed = False
|
38 |
+
if self.max_steps is not None and T > self.max_steps:
|
39 |
+
width = self.max_steps
|
40 |
+
stride = width // 2
|
41 |
+
frames = unfold(x, width, stride)
|
42 |
+
nframes = frames.shape[2]
|
43 |
+
framed = True
|
44 |
+
x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
|
45 |
+
|
46 |
+
x = x.permute(2, 0, 1)
|
47 |
+
|
48 |
+
x = self.lstm(x)[0]
|
49 |
+
x = self.linear(x)
|
50 |
+
x = x.permute(1, 2, 0)
|
51 |
+
if framed:
|
52 |
+
out = []
|
53 |
+
frames = x.reshape(B, -1, C, width)
|
54 |
+
limit = stride // 2
|
55 |
+
for k in range(nframes):
|
56 |
+
if k == 0:
|
57 |
+
out.append(frames[:, k, :, :-limit])
|
58 |
+
elif k == nframes - 1:
|
59 |
+
out.append(frames[:, k, :, limit:])
|
60 |
+
else:
|
61 |
+
out.append(frames[:, k, :, limit:-limit])
|
62 |
+
out = torch.cat(out, -1)
|
63 |
+
out = out[..., :T]
|
64 |
+
x = out
|
65 |
+
if self.skip:
|
66 |
+
x = x + y
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
class LocalState(nn.Module):
|
71 |
+
"""Local state allows to have attention based only on data (no positional embedding),
|
72 |
+
but while setting a constraint on the time window (e.g. decaying penalty term).
|
73 |
+
Also a failed experiments with trying to provide some frequency based attention.
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
|
77 |
+
super().__init__()
|
78 |
+
assert channels % heads == 0, (channels, heads)
|
79 |
+
self.heads = heads
|
80 |
+
self.nfreqs = nfreqs
|
81 |
+
self.ndecay = ndecay
|
82 |
+
self.content = nn.Conv1d(channels, channels, 1)
|
83 |
+
self.query = nn.Conv1d(channels, channels, 1)
|
84 |
+
self.key = nn.Conv1d(channels, channels, 1)
|
85 |
+
if nfreqs:
|
86 |
+
self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
|
87 |
+
if ndecay:
|
88 |
+
self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
|
89 |
+
# Initialize decay close to zero (there is a sigmoid), for maximum initial window.
|
90 |
+
self.query_decay.weight.data *= 0.01
|
91 |
+
assert self.query_decay.bias is not None # stupid type checker
|
92 |
+
self.query_decay.bias.data[:] = -2
|
93 |
+
# self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
|
94 |
+
self.proj = nn.Conv1d(channels, channels, 1)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
B, C, T = x.shape
|
98 |
+
heads = self.heads
|
99 |
+
indexes = torch.arange(T, device=x.device, dtype=x.dtype)
|
100 |
+
# left index are keys, right index are queries
|
101 |
+
delta = indexes[:, None] - indexes[None, :]
|
102 |
+
|
103 |
+
queries = self.query(x).view(B, heads, -1, T)
|
104 |
+
keys = self.key(x).view(B, heads, -1, T)
|
105 |
+
# t are keys, s are queries
|
106 |
+
dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
|
107 |
+
dots /= keys.shape[2] ** 0.5
|
108 |
+
if self.nfreqs:
|
109 |
+
periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
|
110 |
+
freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
|
111 |
+
freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
|
112 |
+
tmp = torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
|
113 |
+
dots += tmp
|
114 |
+
if self.ndecay:
|
115 |
+
decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
|
116 |
+
decay_q = self.query_decay(x).view(B, heads, -1, T)
|
117 |
+
decay_q = torch.sigmoid(decay_q) / 2
|
118 |
+
decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay ** 0.5
|
119 |
+
dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
|
120 |
+
|
121 |
+
# Kill self reference.
|
122 |
+
dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
|
123 |
+
weights = torch.softmax(dots, dim=2)
|
124 |
+
|
125 |
+
content = self.content(x).view(B, heads, -1, T)
|
126 |
+
result = torch.einsum("bhts,bhct->bhcs", weights, content)
|
127 |
+
|
128 |
+
result = result.reshape(B, -1, T)
|
129 |
+
return x + self.proj(result)
|
130 |
+
|
131 |
+
|
132 |
+
class LayerScale(nn.Module):
|
133 |
+
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
|
134 |
+
This rescales diagonaly residual outputs close to 0 initially, then learnt.
|
135 |
+
"""
|
136 |
+
|
137 |
+
def __init__(self, channels: int, init: float = 0):
|
138 |
+
super().__init__()
|
139 |
+
self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
|
140 |
+
self.scale.data[:] = init
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
return self.scale[:, None] * x
|
144 |
+
|
145 |
+
|
146 |
+
class DConv(nn.Module):
|
147 |
+
"""
|
148 |
+
New residual branches in each encoder layer.
|
149 |
+
This alternates dilated convolutions, potentially with LSTMs and attention.
|
150 |
+
Also before entering each residual branch, dimension is projected on a smaller subspace,
|
151 |
+
e.g. of dim `channels // compress`.
|
152 |
+
"""
|
153 |
+
|
154 |
+
def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
|
155 |
+
norm=True, time_attn=False, heads=4, ndecay=4, lstm=False,
|
156 |
+
act_func='gelu', freq_dim=None, reshape=False,
|
157 |
+
kernel=3, dilate=True):
|
158 |
+
"""
|
159 |
+
Args:
|
160 |
+
channels: input/output channels for residual branch.
|
161 |
+
compress: amount of channel compression inside the branch.
|
162 |
+
depth: number of layers in the residual branch. Each layer has its own
|
163 |
+
projection, and potentially LSTM and attention.
|
164 |
+
init: initial scale for LayerNorm.
|
165 |
+
norm: use GroupNorm.
|
166 |
+
time_attn: use LocalAttention.
|
167 |
+
heads: number of heads for the LocalAttention.
|
168 |
+
ndecay: number of decay controls in the LocalAttention.
|
169 |
+
lstm: use LSTM.
|
170 |
+
gelu: Use GELU activation.
|
171 |
+
kernel: kernel size for the (dilated) convolutions.
|
172 |
+
dilate: if true, use dilation, increasing with the depth.
|
173 |
+
"""
|
174 |
+
|
175 |
+
super().__init__()
|
176 |
+
assert kernel % 2 == 1
|
177 |
+
self.channels = channels
|
178 |
+
self.compress = compress
|
179 |
+
self.depth = abs(depth)
|
180 |
+
dilate = depth > 0
|
181 |
+
|
182 |
+
self.time_attn = time_attn
|
183 |
+
self.lstm = lstm
|
184 |
+
self.reshape = reshape
|
185 |
+
self.act_func = act_func
|
186 |
+
self.freq_dim = freq_dim
|
187 |
+
|
188 |
+
norm_fn: tp.Callable[[int], nn.Module]
|
189 |
+
norm_fn = lambda d: nn.Identity() # noqa
|
190 |
+
if norm:
|
191 |
+
norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
|
192 |
+
|
193 |
+
self.hidden = int(channels / compress)
|
194 |
+
|
195 |
+
act: tp.Type[nn.Module]
|
196 |
+
if act_func == 'gelu':
|
197 |
+
act = nn.GELU
|
198 |
+
elif act_func == 'snake':
|
199 |
+
act = Snake
|
200 |
+
else:
|
201 |
+
act = nn.ReLU
|
202 |
+
|
203 |
+
self.layers = nn.ModuleList([])
|
204 |
+
for d in range(self.depth):
|
205 |
+
layer = nn.ModuleDict()
|
206 |
+
dilation = 2 ** d if dilate else 1
|
207 |
+
padding = dilation * (kernel // 2)
|
208 |
+
conv1 = nn.ModuleList([nn.Conv1d(channels, self.hidden, kernel, dilation=dilation, padding=padding),
|
209 |
+
norm_fn(self.hidden)])
|
210 |
+
act_layer = act(freq_dim) if act_func == 'snake' else act()
|
211 |
+
conv2 = nn.ModuleList([nn.Conv1d(self.hidden, 2 * channels, 1),
|
212 |
+
norm_fn(2 * channels), nn.GLU(1),
|
213 |
+
LayerScale(channels, init)])
|
214 |
+
|
215 |
+
layer.update({'conv1': nn.Sequential(*conv1), 'act': act_layer, 'conv2': nn.Sequential(*conv2)})
|
216 |
+
if lstm:
|
217 |
+
layer.update({'lstm': BLSTM(self.hidden, layers=2, max_steps=200, skip=True)})
|
218 |
+
if time_attn:
|
219 |
+
layer.update({'time_attn': LocalState(self.hidden, heads=heads, ndecay=ndecay)})
|
220 |
+
|
221 |
+
self.layers.append(layer)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
|
225 |
+
if self.reshape:
|
226 |
+
B, C, Fr, T = x.shape
|
227 |
+
x = x.permute(0, 2, 1, 3).reshape(-1, C, T)
|
228 |
+
|
229 |
+
for layer in self.layers:
|
230 |
+
skip = x
|
231 |
+
|
232 |
+
x = layer['conv1'](x)
|
233 |
+
|
234 |
+
if self.act_func == 'snake' and self.reshape:
|
235 |
+
x = x.view(B, Fr, self.hidden, T).permute(0, 2, 3, 1)
|
236 |
+
x = layer['act'](x)
|
237 |
+
if self.act_func == 'snake' and self.reshape:
|
238 |
+
x = x.permute(0, 3, 1, 2).reshape(-1, self.hidden, T)
|
239 |
+
|
240 |
+
if self.lstm:
|
241 |
+
x = layer['lstm'](x)
|
242 |
+
if self.time_attn:
|
243 |
+
x = layer['time_attn'](x)
|
244 |
+
|
245 |
+
x = layer['conv2'](x)
|
246 |
+
x = skip + x
|
247 |
+
|
248 |
+
if self.reshape:
|
249 |
+
x = x.view(B, Fr, C, T).permute(0, 2, 1, 3)
|
250 |
+
|
251 |
+
return x
|
252 |
+
|
253 |
+
|
254 |
+
class ScaledEmbedding(nn.Module):
|
255 |
+
"""
|
256 |
+
Boost learning rate for embeddings (with `scale`).
|
257 |
+
Also, can make embeddings continuous with `smooth`.
|
258 |
+
"""
|
259 |
+
|
260 |
+
def __init__(self, num_embeddings: int, embedding_dim: int,
|
261 |
+
scale: float = 10., smooth=False):
|
262 |
+
super().__init__()
|
263 |
+
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
|
264 |
+
if smooth:
|
265 |
+
weight = torch.cumsum(self.embedding.weight.data, dim=0)
|
266 |
+
# when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
|
267 |
+
weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
|
268 |
+
self.embedding.weight.data[:] = weight
|
269 |
+
self.embedding.weight.data /= scale
|
270 |
+
self.scale = scale
|
271 |
+
|
272 |
+
@property
|
273 |
+
def weight(self):
|
274 |
+
return self.embedding.weight * self.scale
|
275 |
+
|
276 |
+
def forward(self, x):
|
277 |
+
out = self.embedding(x) * self.scale
|
278 |
+
return out
|
279 |
+
|
280 |
+
|
281 |
+
class FTB(nn.Module):
|
282 |
+
|
283 |
+
def __init__(self, input_dim=257, in_channel=9, r_channel=5):
|
284 |
+
super(FTB, self).__init__()
|
285 |
+
self.input_dim = input_dim
|
286 |
+
self.in_channel = in_channel
|
287 |
+
self.conv1 = nn.Sequential(
|
288 |
+
nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
|
289 |
+
nn.BatchNorm2d(r_channel),
|
290 |
+
nn.ReLU()
|
291 |
+
)
|
292 |
+
|
293 |
+
self.conv1d = nn.Sequential(
|
294 |
+
nn.Conv1d(r_channel * input_dim, in_channel, kernel_size=9, padding=4),
|
295 |
+
nn.BatchNorm1d(in_channel),
|
296 |
+
nn.ReLU()
|
297 |
+
)
|
298 |
+
self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)
|
299 |
+
|
300 |
+
self.conv2 = nn.Sequential(
|
301 |
+
nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
|
302 |
+
nn.BatchNorm2d(in_channel),
|
303 |
+
nn.ReLU()
|
304 |
+
)
|
305 |
+
|
306 |
+
def forward(self, inputs):
|
307 |
+
'''
|
308 |
+
inputs should be [Batch, Ca, Dim, Time]
|
309 |
+
'''
|
310 |
+
# T-F attention
|
311 |
+
conv1_out = self.conv1(inputs)
|
312 |
+
B, C, D, T = conv1_out.size()
|
313 |
+
reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
|
314 |
+
conv1d_out = self.conv1d(reshape1_out)
|
315 |
+
conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])
|
316 |
+
|
317 |
+
# now is also [B,C,D,T]
|
318 |
+
att_out = conv1d_out * inputs
|
319 |
+
|
320 |
+
# tranpose to [B,C,T,D]
|
321 |
+
att_out = torch.transpose(att_out, 2, 3)
|
322 |
+
freqfc_out = self.freq_fc(att_out)
|
323 |
+
att_out = torch.transpose(freqfc_out, 2, 3)
|
324 |
+
|
325 |
+
cat_out = torch.cat([att_out, inputs], 1)
|
326 |
+
outputs = self.conv2(cat_out)
|
327 |
+
return outputs
|
src/models/seanet.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import math
|
3 |
+
from src.models.utils import capture_init, weights_init
|
4 |
+
from src.models.modules import WNConv1d, WNConvTranspose1d
|
5 |
+
from torchaudio.functional import resample
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
class ResnetBlock(nn.Module):
|
9 |
+
def __init__(self, dim, dilation=1):
|
10 |
+
super().__init__()
|
11 |
+
self.block = nn.Sequential(
|
12 |
+
nn.LeakyReLU(0.2),
|
13 |
+
nn.ReflectionPad1d(dilation),
|
14 |
+
WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
|
15 |
+
nn.LeakyReLU(0.2),
|
16 |
+
WNConv1d(dim, dim, kernel_size=1),
|
17 |
+
)
|
18 |
+
self.shortcut = WNConv1d(dim, dim, kernel_size=1)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return self.shortcut(x) + self.block(x)
|
22 |
+
|
23 |
+
|
24 |
+
class Seanet(nn.Module):
|
25 |
+
|
26 |
+
@capture_init
|
27 |
+
def __init__(self,
|
28 |
+
latent_space_size=128,
|
29 |
+
ngf=32, n_residual_layers=3,
|
30 |
+
resample=1,
|
31 |
+
normalize=True,
|
32 |
+
floor=1e-3,
|
33 |
+
ratios=[8, 8, 2, 2],
|
34 |
+
in_channels=1,
|
35 |
+
out_channels=1,
|
36 |
+
lr_sr=16000,
|
37 |
+
hr_sr=16000,
|
38 |
+
upsample=True):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.resample = resample
|
42 |
+
self.normalize = normalize
|
43 |
+
self.floor = floor
|
44 |
+
self.lr_sr = lr_sr
|
45 |
+
self.hr_sr = hr_sr
|
46 |
+
self.scale_factor = int(self.hr_sr / self.lr_sr)
|
47 |
+
self.upsample = upsample
|
48 |
+
|
49 |
+
self.encoder = nn.ModuleList()
|
50 |
+
self.decoder = nn.ModuleList()
|
51 |
+
|
52 |
+
self.ratios = ratios
|
53 |
+
mult = int(2 ** len(ratios))
|
54 |
+
|
55 |
+
decoder_wrapper_conv_layer = [
|
56 |
+
nn.LeakyReLU(0.2),
|
57 |
+
nn.ReflectionPad1d(3),
|
58 |
+
WNConv1d(latent_space_size, mult * ngf, kernel_size=7, padding=0),
|
59 |
+
]
|
60 |
+
|
61 |
+
encoder_wrapper_conv_layer = [
|
62 |
+
nn.LeakyReLU(0.2),
|
63 |
+
nn.ReflectionPad1d(3),
|
64 |
+
WNConv1d(mult * ngf, latent_space_size, kernel_size=7, padding=0)
|
65 |
+
]
|
66 |
+
|
67 |
+
self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))
|
68 |
+
self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))
|
69 |
+
|
70 |
+
for i, r in enumerate(ratios):
|
71 |
+
encoder_block = [
|
72 |
+
nn.LeakyReLU(0.2),
|
73 |
+
WNConv1d(mult * ngf // 2,
|
74 |
+
mult * ngf,
|
75 |
+
kernel_size=r * 2,
|
76 |
+
stride=r,
|
77 |
+
padding=r // 2 + r % 2,
|
78 |
+
),
|
79 |
+
]
|
80 |
+
|
81 |
+
decoder_block = [
|
82 |
+
nn.LeakyReLU(0.2),
|
83 |
+
WNConvTranspose1d(
|
84 |
+
mult * ngf,
|
85 |
+
mult * ngf // 2,
|
86 |
+
kernel_size=r * 2,
|
87 |
+
stride=r,
|
88 |
+
padding=r // 2 + r % 2,
|
89 |
+
output_padding=r % 2,
|
90 |
+
),
|
91 |
+
]
|
92 |
+
|
93 |
+
for j in range(n_residual_layers - 1, -1, -1):
|
94 |
+
encoder_block = [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] + encoder_block
|
95 |
+
|
96 |
+
for j in range(n_residual_layers):
|
97 |
+
decoder_block += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]
|
98 |
+
|
99 |
+
mult //= 2
|
100 |
+
|
101 |
+
self.encoder.insert(0, nn.Sequential(*encoder_block))
|
102 |
+
self.decoder.append(nn.Sequential(*decoder_block))
|
103 |
+
|
104 |
+
encoder_wrapper_conv_layer = [
|
105 |
+
nn.ReflectionPad1d(3),
|
106 |
+
WNConv1d(in_channels, ngf, kernel_size=7, padding=0),
|
107 |
+
nn.Tanh(),
|
108 |
+
]
|
109 |
+
self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))
|
110 |
+
|
111 |
+
decoder_wrapper_conv_layer = [
|
112 |
+
nn.LeakyReLU(0.2),
|
113 |
+
nn.ReflectionPad1d(3),
|
114 |
+
WNConv1d(ngf, out_channels, kernel_size=7, padding=0),
|
115 |
+
nn.Tanh(),
|
116 |
+
]
|
117 |
+
self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))
|
118 |
+
|
119 |
+
self.apply(weights_init)
|
120 |
+
|
121 |
+
def estimate_output_length(self, length):
|
122 |
+
"""
|
123 |
+
Return the nearest valid length to use with the model so that
|
124 |
+
there is no time steps left over in a convolutions, e.g. for all
|
125 |
+
layers, size of the input - kernel_size % stride = 0.
|
126 |
+
|
127 |
+
If the mixture has a valid length, the estimated sources
|
128 |
+
will have exactly the same length.
|
129 |
+
"""
|
130 |
+
depth = len(self.ratios)
|
131 |
+
for idx in range(depth - 1, -1, -1):
|
132 |
+
stride = self.ratios[idx]
|
133 |
+
kernel_size = 2 * stride
|
134 |
+
padding = stride // 2 + stride % 2
|
135 |
+
length = math.ceil((length - kernel_size + 2 * padding) / stride) + 1
|
136 |
+
length = max(length, 1)
|
137 |
+
for idx in range(depth):
|
138 |
+
stride = self.ratios[idx]
|
139 |
+
kernel_size = 2 * stride
|
140 |
+
padding = stride // 2 + stride % 2
|
141 |
+
output_padding = stride % 2
|
142 |
+
length = (length - 1) * stride + kernel_size - 2 * padding + output_padding
|
143 |
+
return int(length)
|
144 |
+
|
145 |
+
def pad_to_valid_length(self, signal):
|
146 |
+
valid_length = self.estimate_output_length(signal.shape[-1])
|
147 |
+
padding_len = valid_length - signal.shape[-1]
|
148 |
+
signal = F.pad(signal, (0, padding_len))
|
149 |
+
return signal, padding_len
|
150 |
+
|
151 |
+
def forward(self, signal):
|
152 |
+
|
153 |
+
target_len = signal.shape[-1]
|
154 |
+
if self.upsample:
|
155 |
+
target_len *= self.scale_factor
|
156 |
+
if self.normalize:
|
157 |
+
mono = signal.mean(dim=1, keepdim=True)
|
158 |
+
std = mono.std(dim=-1, keepdim=True)
|
159 |
+
signal = signal / (self.floor + std)
|
160 |
+
else:
|
161 |
+
std = 1
|
162 |
+
x = signal
|
163 |
+
if self.upsample:
|
164 |
+
x = resample(x, self.lr_sr, self.hr_sr)
|
165 |
+
|
166 |
+
x, padding_len = self.pad_to_valid_length(x)
|
167 |
+
skips = []
|
168 |
+
for i, encode in enumerate(self.encoder):
|
169 |
+
skips.append(x)
|
170 |
+
x = encode(x)
|
171 |
+
for j, decode in enumerate(self.decoder):
|
172 |
+
x = decode(x)
|
173 |
+
skip = skips.pop(-1)
|
174 |
+
x = x + skip
|
175 |
+
if target_len < x.shape[-1]:
|
176 |
+
x = x[..., :target_len]
|
177 |
+
return std * x
|
src/models/snake.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, sin, pow
|
3 |
+
from torch.nn import Parameter
|
4 |
+
from torch.distributions.exponential import Exponential
|
5 |
+
|
6 |
+
class Snake(nn.Module):
|
7 |
+
'''
|
8 |
+
Implementation of the serpentine-like sine-based periodic activation function:
|
9 |
+
.. math::
|
10 |
+
Snake_a := x + \frac{1}{a} sin^2(ax) = x - \frac{1}{2a}cos{2ax} + \frac{1}{2a}
|
11 |
+
This activation function is able to better extrapolate to previously unseen data,
|
12 |
+
especially in the case of learning periodic functions
|
13 |
+
|
14 |
+
Shape:
|
15 |
+
- Input: (N, *) where * means, any number of additional
|
16 |
+
dimensions
|
17 |
+
- Output: (N, *), same shape as the input
|
18 |
+
|
19 |
+
Parameters:
|
20 |
+
- a - trainable parameter
|
21 |
+
|
22 |
+
References:
|
23 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
24 |
+
https://arxiv.org/abs/2006.08195
|
25 |
+
|
26 |
+
Examples:
|
27 |
+
>>> a1 = snake(256)
|
28 |
+
>>> x = torch.randn(256)
|
29 |
+
>>> x = a1(x)
|
30 |
+
'''
|
31 |
+
|
32 |
+
def __init__(self, in_features, a=None, trainable=True):
|
33 |
+
'''
|
34 |
+
Initialization.
|
35 |
+
Args:
|
36 |
+
in_features: shape of the input
|
37 |
+
a: trainable parameter
|
38 |
+
trainable: sets `a` as a trainable parameter
|
39 |
+
|
40 |
+
`a` is initialized to 1 by default, higher values = higher-frequency,
|
41 |
+
5-50 is a good starting point if you already think your data is periodic,
|
42 |
+
consider starting lower e.g. 0.5 if you think not, but don't worry,
|
43 |
+
`a` will be trained along with the rest of your model
|
44 |
+
'''
|
45 |
+
super(Snake, self).__init__()
|
46 |
+
self.in_features = in_features if isinstance(in_features, list) else [in_features]
|
47 |
+
|
48 |
+
# Initialize `a`
|
49 |
+
if a is not None:
|
50 |
+
self.a = Parameter(torch.ones(self.in_features) * a) # create a tensor out of alpha
|
51 |
+
else:
|
52 |
+
m = Exponential(torch.tensor([0.1]))
|
53 |
+
self.a = Parameter((m.rsample(self.in_features)).squeeze()) # random init = mix of frequencies
|
54 |
+
|
55 |
+
self.a.requiresGrad = trainable # set the training of `a` to true
|
56 |
+
|
57 |
+
def extra_repr(self) -> str:
|
58 |
+
return 'in_features={}'.format(self.in_features)
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
'''
|
62 |
+
Forward pass of the function.
|
63 |
+
Applies the function to the input elementwise.
|
64 |
+
Snake ∶= x + 1/a* sin^2 (xa)
|
65 |
+
'''
|
66 |
+
return x + (1.0 / self.a) * pow(sin(x * self.a), 2)
|
src/models/spec.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
|
3 |
+
"""
|
4 |
+
"""Conveniance wrapper to perform STFT and iSTFT"""
|
5 |
+
|
6 |
+
import torch as th
|
7 |
+
|
8 |
+
|
9 |
+
def spectro(x, n_fft=512, hop_length=None, pad=0, win_length=None):
|
10 |
+
*other, length = x.shape
|
11 |
+
x = x.reshape(-1, length)
|
12 |
+
z = th.stft(x,
|
13 |
+
n_fft * (1 + pad),
|
14 |
+
hop_length or n_fft // 4,
|
15 |
+
window=th.hann_window(win_length).to(x),
|
16 |
+
win_length=win_length or n_fft,
|
17 |
+
normalized=True,
|
18 |
+
center=True,
|
19 |
+
return_complex=True,
|
20 |
+
pad_mode='reflect')
|
21 |
+
_, freqs, frame = z.shape
|
22 |
+
return z.view(*other, freqs, frame)
|
23 |
+
|
24 |
+
|
25 |
+
def ispectro(z, hop_length=None, length=None, pad=0, win_length=None):
|
26 |
+
*other, freqs, frames = z.shape
|
27 |
+
n_fft = 2 * freqs - 2
|
28 |
+
z = z.view(-1, freqs, frames)
|
29 |
+
win_length = win_length or n_fft // (1 + pad)
|
30 |
+
x = th.istft(z,
|
31 |
+
n_fft,
|
32 |
+
hop_length or n_fft // 2,
|
33 |
+
window=th.hann_window(win_length).to(z.real),
|
34 |
+
win_length=win_length,
|
35 |
+
normalized=True,
|
36 |
+
length=length,
|
37 |
+
center=True)
|
38 |
+
_, length = x.shape
|
39 |
+
return x.view(*other, length)
|
src/models/utils.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import math
|
3 |
+
|
4 |
+
from torch.nn import functional as F
|
5 |
+
import torchaudio
|
6 |
+
|
7 |
+
|
8 |
+
def capture_init(init):
|
9 |
+
"""capture_init.
|
10 |
+
|
11 |
+
Decorate `__init__` with this, and you can then
|
12 |
+
recover the *args and **kwargs passed to it in `self._init_args_kwargs`
|
13 |
+
"""
|
14 |
+
|
15 |
+
@functools.wraps(init)
|
16 |
+
def __init__(self, *args, **kwargs):
|
17 |
+
self._init_args_kwargs = (args, kwargs)
|
18 |
+
init(self, *args, **kwargs)
|
19 |
+
|
20 |
+
return __init__
|
21 |
+
|
22 |
+
|
23 |
+
def unfold(a, kernel_size, stride):
|
24 |
+
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
|
25 |
+
with K the kernel size, by extracting frames with the given stride.
|
26 |
+
This will pad the input so that `F = ceil(T / K)`.
|
27 |
+
see https://github.com/pytorch/pytorch/issues/60466
|
28 |
+
"""
|
29 |
+
*shape, length = a.shape
|
30 |
+
n_frames = math.ceil(length / stride)
|
31 |
+
tgt_length = (n_frames - 1) * stride + kernel_size
|
32 |
+
a = F.pad(a, (0, tgt_length - length))
|
33 |
+
strides = list(a.stride())
|
34 |
+
assert strides[-1] == 1, 'data should be contiguous'
|
35 |
+
strides = strides[:-1] + [stride, 1]
|
36 |
+
return a.as_strided([*shape, n_frames, kernel_size], strides)
|
37 |
+
|
38 |
+
|
39 |
+
def weights_init(m):
|
40 |
+
classname = m.__class__.__name__
|
41 |
+
if classname.find("Conv") != -1:
|
42 |
+
m.weight.data.normal_(0.0, 0.02)
|
43 |
+
elif classname.find("BatchNorm2d") != -1:
|
44 |
+
m.weight.data.normal_(1.0, 0.02)
|
45 |
+
m.bias.data.fill_(0)
|
46 |
+
|
47 |
+
def write(wav, filename, sr):
|
48 |
+
# Normalize audio if it prevents clipping
|
49 |
+
wav = wav / max(wav.abs().max().item(), 1)
|
50 |
+
torchaudio.save(filename, wav.cpu(), sr)
|
src/utils.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import math
|
3 |
+
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from contextlib import contextmanager
|
7 |
+
|
8 |
+
|
9 |
+
def capture_init(init):
|
10 |
+
"""capture_init.
|
11 |
+
|
12 |
+
Decorate `__init__` with this, and you can then
|
13 |
+
recover the *args and **kwargs passed to it in `self._init_args_kwargs`
|
14 |
+
"""
|
15 |
+
|
16 |
+
@functools.wraps(init)
|
17 |
+
def __init__(self, *args, **kwargs):
|
18 |
+
self._init_args_kwargs = (args, kwargs)
|
19 |
+
init(self, *args, **kwargs)
|
20 |
+
|
21 |
+
return __init__
|
22 |
+
|
23 |
+
|
24 |
+
def unfold(a, kernel_size, stride):
|
25 |
+
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
|
26 |
+
with K the kernel size, by extracting frames with the given stride.
|
27 |
+
This will pad the input so that `F = ceil(T / K)`.
|
28 |
+
see https://github.com/pytorch/pytorch/issues/60466
|
29 |
+
"""
|
30 |
+
*shape, length = a.shape
|
31 |
+
n_frames = math.ceil(length / stride)
|
32 |
+
tgt_length = (n_frames - 1) * stride + kernel_size
|
33 |
+
a = F.pad(a, (0, tgt_length - length))
|
34 |
+
strides = list(a.stride())
|
35 |
+
assert strides[-1] == 1, 'data should be contiguous'
|
36 |
+
strides = strides[:-1] + [stride, 1]
|
37 |
+
return a.as_strided([*shape, n_frames, kernel_size], strides)
|
38 |
+
|
39 |
+
def colorize(text, color):
|
40 |
+
"""
|
41 |
+
Display text with some ANSI color in the terminal.
|
42 |
+
"""
|
43 |
+
code = f"\033[{color}m"
|
44 |
+
restore = "\033[0m"
|
45 |
+
return "".join([code, text, restore])
|
46 |
+
|
47 |
+
|
48 |
+
def bold(text):
|
49 |
+
"""
|
50 |
+
Display text in bold in the terminal.
|
51 |
+
"""
|
52 |
+
return colorize(text, "1")
|