sereich committed on
Commit f113387 · 1 Parent(s): 50ec1db

Initial commit of Radio Upscaling UI (minus models)

app.py ADDED
@@ -0,0 +1,54 @@
import gradio as gr
import torch
import numpy as np
from torchaudio.functional import resample

from processAudio import upscaleAudio

class Object(object):
    pass

def processAudio(model: gr.Dropdown, audioData: gr.Audio):
    if len(audioData[1].shape) == 1:  # Convert mono to stereo
        lrAudio = torch.tensor(np.array([
            audioData[1].copy().astype(np.float32) / 32768,
            audioData[1].copy().astype(np.float32) / 32768
        ]))
    elif audioData[1].shape[1] > 2:
        raise gr.Error("Audio with more than 2 channels is not supported.")
    else:  # re-order channel data from [samples, 2] to [2, samples]
        lrAudio = torch.tensor(audioData[1].copy().astype(np.float32) / 32768).transpose(0, 1)
    if audioData[0] != 44100:
        lrAudio = resample(lrAudio, audioData[0], 44100)
    hrAudio = upscaleAudio(lrAudio, "models/" + model)
    hrAudio = hrAudio / max(hrAudio.abs().max().item(), 1)
    outAudio = (hrAudio * 32767).numpy().astype(np.int16).transpose(1, 0)
    return tuple([44100, outAudio])

with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as layout:
    with gr.Row():
        gr.Markdown("<h2>Broadcast Audio Upscaler</h2>")
    with gr.Row():
        with open("html/directions.html", "r") as directionsHtml:
            gr.Markdown(directionsHtml.read())
    with gr.Row():
        modelSelect = gr.Dropdown(
            [
                ["FM Radio Super Resolution", "FM_Radio_SR.th"],
                ["AM Radio Super Resolution (Beta)", "AM_Radio_SR.th"]
            ],
            label="Select Model:",
            value="FM_Radio_SR.th",
        )
    with gr.Row():
        with gr.Column():
            audioFileSelect = gr.Audio(label="Audio File (Mono or Stereo):", sources="upload")
        with gr.Column():
            audioOutput = gr.Audio(show_download_button=True, label="Restored Audio:", sources=[])
    with gr.Row():
        submit = gr.Button("Process Audio", variant="primary").click(fn=processAudio, inputs=[modelSelect, audioFileSelect], outputs=audioOutput)
    with gr.Row():
        with gr.Accordion("More Information:", open=False):
            with open("html/information.html", "r") as informationHtml:
                gr.Markdown(informationHtml.read())
layout.launch()
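For reference, a minimal sketch of driving the same pipeline without the Gradio UI (the file name headless_upscale.py and the input/output paths are hypothetical, and it assumes a checkpoint such as models/FM_Radio_SR.th is present locally, since the models are not part of this commit):

# headless_upscale.py -- hypothetical helper, not part of this commit
import torchaudio
from torchaudio.functional import resample

from processAudio import upscaleAudio

wav, sr = torchaudio.load("input.wav")        # [channels, samples], float32 in [-1, 1]
if wav.shape[0] == 1:                         # duplicate mono to stereo, as app.py does
    wav = wav.repeat(2, 1)
if sr != 44100:                               # the checkpoints expect 44.1 kHz input
    wav = resample(wav, sr, 44100)

# A no-op progress callback avoids depending on Gradio's Progress outside the UI.
out = upscaleAudio(wav, "models/FM_Radio_SR.th", progress=lambda *a, **kw: None)
out = out / max(out.abs().max().item(), 1)    # normalize only if the peak would clip
torchaudio.save("restored.wav", out, 44100)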
conf/experiment/aero_441-441_512_256.yaml ADDED
@@ -0,0 +1,77 @@
# @package experiment
name: aero-nfft=${experiment.nfft}-hl=${experiment.hop_length}

# Dataset related
lr_sr: 44100 # low resolution sample rate, added to support BWE. Should be included in training cfg
hr_sr: 44100 # high resolution sample rate. Should be included in training cfg
segment: 2
stride: 4 # in seconds, how much to stride between training examples
pad: true # if training sample is too short, pad it
upsample: false
batch_size: 6
nfft: 512
hop_length: 256

# models related
model: aero
aero: # see aero.py for a detailed description
  in_channels: 2
  out_channels: 2
  # Channels
  channels: 64 # Small: 48, Large: 64
  growth: 2
  # STFT
  nfft: 512
  hop_length: 256
  end_iters: 0
  cac: true
  # Main structure
  rewrite: true
  hybrid: false
  hybrid_old: false
  # Frequency Branch
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  strides: [ 4,4,2,2 ]
  context: 1
  context_enc: 0
  freq_ends: 4
  enc_freq_attn: 0
  # normalization
  norm_starts: 2
  norm_groups: 4
  # DConv residual branch
  dconv_mode: 1
  dconv_depth: 2
  dconv_comp: 4
  dconv_time_attn: 2
  dconv_lstm: 2
  dconv_init: 0.001
  # Weight init
  rescale: 0.1
  lr_sr: 44100
  hr_sr: 44100
  spec_upsample: true
  act_func: snake
  debug: false

adversarial: True
features_loss_lambda: 100
only_features_loss: False
only_adversarial_loss: False
discriminator_models: [ msd, mpd ] # msd_melgan/msd_hifi/mpd/hifi
melgan_discriminator:
  n_layers: 5
  num_D: 3
  downsampling_factor: 1
  ndf: 16
  channels: 2
msd:
  hidden: 32 # Small: 16, Large: 32
  channels: 2
mpd:
  hidden: 32 # Small: 16, Large: 32
  channels: 2
conf/experiment/seanet_441-441.yaml ADDED
@@ -0,0 +1,46 @@
# @package experiment
name: seanet-nfft=${experiment.nfft}-hl=${experiment.hop_length}

# Dataset related
lr_sr: 44100 # low resolution sample rate, added to support BWE. Should be included in training cfg
hr_sr: 44100 # high resolution sample rate. Should be included in training cfg
segment: 2
stride: 2 # in seconds, how much to stride between training examples
pad: true # if training sample is too short, pad it
upsample: false
batch_size: 5
nfft: 1024
hop_length: 512

# models related
model: seanet
seanet:
  latent_space_size: 256
  ngf: 32
  n_residual_layers: 4
  resample: 1
  normalize: False
  floor: 0.0005
  ratios: [ 8,8,2,2 ]
  lr_sr: 44100
  hr_sr: 44100
  in_channels: 2
  out_channels: 2

adversarial: True
features_loss_lambda: 100
discriminator_models: [ msd, mpd ] # msd_melgan/msd_hifi/mpd/hifi
only_features_loss: False
only_adversarial_loss: False
msd:
  hidden: 16
  channels: 2
mpd:
  hidden: 16
  channels: 2
melgan_discriminator:
  n_layers: 5
  num_D: 3
  downsampling_factor: 1
  ndf: 16
  channels: 2
html/directions.html ADDED
@@ -0,0 +1,6 @@
<h3>Directions:</h3>
<ol>
  <li>Select a model from the model drop-down.</li>
  <li>Select an audio file.</li>
  <li>After the file loads, press "Process Audio" to process the file.</li>
</ol>
html/information.html ADDED
@@ -0,0 +1,16 @@
<p>This tool is based on the "aero" audio upscaling project (<a href="https://github.com/slp-rl/aero">original version</a>, <a href="https://github.com/pokepress/aero">modified version used here</a>). It takes digitized radio recordings and attempts to reduce noise/interference while recreating lost frequencies.</p>

<B>The FM model should be able to:</B>
<ul>
  <li>Reduce stereo hiss</li>
  <li>Reduce certain kinds of static/interference</li>
  <li>Reduce the 19 kHz stereo pilot tone</li>
  <li>Reduce overly warm bass frequencies (boost them 3 or so dB if you miss them)</li>
  <li>Restore lost high &amp; low frequencies (generally only up to 16 kHz since most training data is MP3-based)</li>
</ul>

<B>The AM model should be able to:</B>
<ul>
  <li>Reduce certain kinds of static/interference</li>
  <li>Restore lost high &amp; low frequencies (capacity tends to diminish above 10 kHz)</li>
</ul>
processAudio.py ADDED
@@ -0,0 +1,87 @@
import math
import os

import time

import torch
import logging
from pathlib import Path

import torchaudio
from torchaudio.functional import resample
from gradio import Progress

from src.models import modelFactory
from src.utils import bold

logger = logging.getLogger(__name__)

SEGMENT_DURATION_SEC = 5
SEGMENT_OVERLAP_SAMPLES = 2048
SERIALIZE_KEY_STATE = 'state'

def _load_model(checkpoint_file="models/FM_Radio_SR.th", model_name="aero"):
    checkpoint_file = Path(checkpoint_file)
    model = modelFactory.get_model(model_name)['generator']
    package = torch.load(checkpoint_file, 'cpu')
    if 'state' in package.keys():  # raw model file
        logger.info(bold(f'Loading model {model_name} from file.'))
        model.load_state_dict(package[SERIALIZE_KEY_STATE])

    return model

def crossfade_and_blend(out_clip, in_clip):
    fade_out = torchaudio.transforms.Fade(0, SEGMENT_OVERLAP_SAMPLES)
    fade_in = torchaudio.transforms.Fade(SEGMENT_OVERLAP_SAMPLES, 0)
    return fade_out(out_clip) + fade_in(in_clip)

def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):

    model = _load_model(checkpoint_file, model_name)
    device = torch.device('cpu')
    # model.cuda()

    logger.info(f'lr wav shape: {lr_sig.shape}')

    segment_duration_samples = sr * SEGMENT_DURATION_SEC
    n_chunks = math.ceil(lr_sig.shape[-1] / segment_duration_samples)
    logger.info(f'number of chunks: {n_chunks}')

    lr_chunks = []
    for i in range(n_chunks):
        start = i * segment_duration_samples
        end = min((i + 1) * segment_duration_samples, lr_sig.shape[-1])
        lr_chunks.append(lr_sig[:, start:end])

    pr_chunks = []

    lr_segment_overlap_samples = int((sr / hr_sr) * SEGMENT_OVERLAP_SAMPLES)

    model.eval()
    pred_start = time.time()
    with torch.no_grad():
        previous_chunk = None
        progress(0)
        for i, lr_chunk in enumerate(lr_chunks):
            progress(i / n_chunks)
            pr_chunk = None
            if previous_chunk is not None:
                combined_chunk = torch.cat((previous_chunk[..., -lr_segment_overlap_samples:], lr_chunk), 1)
                pr_combined_chunk = model(combined_chunk.unsqueeze(0).to(device)).squeeze(0)
                pr_chunk = pr_combined_chunk[..., SEGMENT_OVERLAP_SAMPLES:]
                pr_chunks[-1][..., -SEGMENT_OVERLAP_SAMPLES:] = crossfade_and_blend(pr_chunks[-1][..., -SEGMENT_OVERLAP_SAMPLES:], pr_combined_chunk.cpu()[..., :SEGMENT_OVERLAP_SAMPLES])
            else:
                pr_chunk = model(lr_chunk.unsqueeze(0).to(device)).squeeze(0)
            logger.info(f'lr chunk {i} shape: {lr_chunk.shape}')
            logger.info(f'pr chunk {i} shape: {pr_chunk.shape}')
            pr_chunks.append(pr_chunk.cpu())
            previous_chunk = lr_chunk

    pred_duration = time.time() - pred_start
    logger.info(f'prediction duration: {pred_duration}')

    pr = torch.concat(pr_chunks, dim=-1)

    logger.info(f'pr wav shape: {pr.shape}')

    return pr
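upscaleAudio processes the signal in 5-second chunks and hides the seams by re-processing a 2048-sample overlap and crossfading it into the previous chunk. A small self-contained sketch of that blend, using the same torchaudio.transforms.Fade calls as crossfade_and_blend with made-up constant signals, shows why the default linear fade keeps the overall gain flat across the seam:

import torch
import torchaudio

OVERLAP = 2048

# Tail of the previous chunk and head of the re-processed chunk; constants stand in
# for real audio so the constant-gain property is easy to see.
tail = torch.ones(2, OVERLAP) * 0.5
head = torch.ones(2, OVERLAP) * 0.5

fade_out = torchaudio.transforms.Fade(0, OVERLAP)   # gain ramps 1 -> 0 over the overlap
fade_in = torchaudio.transforms.Fade(OVERLAP, 0)    # gain ramps 0 -> 1 over the overlap
blended = fade_out(tail) + fade_in(head)

# With the default linear fade shape the two ramps sum to 1 at every sample, so
# identical inputs come back unchanged and real audio gets a click-free transition.
print(torch.allclose(blended, torch.full_like(blended, 0.5)))  # True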
requirements.txt ADDED
@@ -0,0 +1,4 @@
torch==2.2.1
torchvision==0.17.2
torchaudio==2.2.1
opencv-python
src/models/aero.py ADDED
@@ -0,0 +1,524 @@
"""
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
"""
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from src.models.utils import capture_init
from src.models.spec import spectro, ispectro
from src.models.modules import DConv, ScaledEmbedding, FTB

import logging
logger = logging.getLogger(__name__)


def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, is_first=False, freq_attn=False, freq_dim=None, norm=True, context=0,
                 dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This is used by both the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: list of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if stride == 1 and kernel_size % 2 == 0 and kernel_size > 1:
            kernel_size -= 1
        if pad:
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        klass = nn.Conv2d
        self.chin = chin
        self.chout = chout
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.freq_attn = freq_attn
        self.freq_dim = freq_dim
        self.norm = norm
        self.pad = pad
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            if pad != 0:
                pad = [pad, 0]
            # klass = nn.Conv2d
        else:
            kernel_size = [1, kernel_size]
            stride = [1, stride]
            if pad != 0:
                pad = [0, pad]

        self.is_first = is_first

        if is_first:
            self.pre_conv = nn.Conv2d(chin, chout, [1, 1])
            chin = chout

        if self.freq_attn:
            self.freq_attn_block = FTB(input_dim=freq_dim, in_channel=chin)

        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))

        if self.is_first:
            x = self.pre_conv(x)

        if self.freq_attn:
            x = self.freq_attn_block(x)

        x = self.conv(x)

        x = F.gelu(self.norm1(x))
        if self.dconv:
            x = self.dconv(x)

        if self.rewrite:
            x = self.norm2(self.rewrite(x))
            x = F.glu(x, dim=1)

        return x


class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if stride == 1 and kernel_size % 2 == 0 and kernel_size > 1:
            kernel_size -= 1
        if pad:
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv2d
        klass_tr = nn.ConvTranspose2d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
        else:
            kernel_size = [1, kernel_size]
            stride = [1, stride]
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = torch.cat([x, skip], dim=1)

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                y = self.dconv(y)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z


class Aero(nn.Module):
    """
    Deep model for Audio Super Resolution.
    """

    @capture_init
    def __init__(self,
                 # Channels
                 in_channels=1,
                 out_channels=1,
                 audio_channels=2,
                 channels=48,
                 growth=2,
                 # STFT
                 nfft=512,
                 hop_length=64,
                 end_iters=0,
                 cac=True,
                 # Main structure
                 rewrite=True,
                 hybrid=False,
                 hybrid_old=False,
                 # Frequency branch
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 strides=[4, 4, 2, 2],
                 context=1,
                 context_enc=0,
                 freq_ends=4,
                 enc_freq_attn=4,
                 # Normalization
                 norm_starts=2,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_time_attn=2,
                 dconv_lstm=2,
                 dconv_init=1e-3,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 lr_sr=4000,
                 hr_sr=16000,
                 spec_upsample=True,
                 act_func='snake',
                 debug=False):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this requires careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_freq_attn: adds freq attention layers in DConv branch starting at this layer.
            dconv_time_attn: adds time attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick
            lr_sr: source low-resolution sample rate
            hr_sr: target high-resolution sample rate
            spec_upsample: if true, upsamples in the spectral domain, otherwise performs sinc interpolation beforehand
            act_func: 'snake'/'relu'
            debug: if true, prints out input dimensions throughout model layers.
        """
        super().__init__()
        self.cac = cac
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.audio_channels = audio_channels
        self.kernel_size = kernel_size
        self.context = context
        self.strides = strides
        self.depth = len(strides)
        self.channels = channels
        self.lr_sr = lr_sr
        self.hr_sr = hr_sr
        self.spec_upsample = spec_upsample

        self.scale = hr_sr / lr_sr if self.spec_upsample else 1

        self.nfft = nfft
        self.hop_length = int(hop_length // self.scale)  # this is for the input signal
        self.win_length = int(self.nfft // self.scale)  # this is for the input signal
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        self.debug = debug

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        chin_z = self.in_channels
        if self.cac:
            chin_z *= 2
        chout_z = channels
        freqs = nfft // 2

        for index in range(self.depth):
            freq_attn = index >= enc_freq_attn
            lstm = index >= dconv_lstm
            time_attn = index >= dconv_time_attn
            norm = index >= norm_starts
            freq = index <= freq_ends
            stri = strides[index]
            ker = kernel_size

            pad = True
            if freq and freqs < kernel_size:
                ker = freqs

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'time_attn': time_attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'act_func': act_func,
                    'reshape': True,
                    'freq_dim': freqs // strides[index] if freq else freqs
                }
            }

            kw_dec = dict(kw)

            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1, context=context_enc,
                            is_first=index == 0, freq_attn=freq_attn, freq_dim=freqs,
                            **kw)

            self.encoder.append(enc)
            if index == 0:
                chin = self.out_channels
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(2 * chout_z, chin_z, dconv=dconv_mode & 2,
                            last=index == 0, context=context, **kw_dec)

            self.decoder.insert(0, dec)

            chin_z = chout_z
            chout_z = int(growth * chout_z)

            if freq:
                freqs //= strides[index]

            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    def _spec(self, x, scale=False):
        if np.mod(x.shape[-1], self.hop_length):
            x = F.pad(x, (0, self.hop_length - np.mod(x.shape[-1], self.hop_length)))
        hl = self.hop_length
        nfft = self.nfft
        win_length = self.win_length

        if scale:
            hl = int(hl * self.scale)
            win_length = int(win_length * self.scale)

        z = spectro(x, nfft, hl, win_length=win_length)[..., :-1, :]
        return z

    def _ispec(self, z):
        hl = int(self.hop_length * self.scale)
        win_length = int(self.win_length * self.scale)
        z = F.pad(z, (0, 0, 0, 1))
        x = ispectro(z, hl, win_length=win_length)
        return x

    def _move_complex_to_channels_dim(self, z):
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _convert_to_complex(self, x):
        """
        :param x: signal of shape [Batch, Channels, 2, Freq, TimeFrames]
        :return: complex signal of shape [Batch, Channels, Freq, TimeFrames]
        """
        out = x.permute(0, 1, 3, 4, 2)
        out = torch.view_as_complex(out.contiguous())
        return out

    def forward(self, mix, return_spec=False, return_lr_spec=False):
        x = mix
        length = x.shape[-1]

        if self.debug:
            logger.info(f'hdemucs in shape: {x.shape}')

        z = self._spec(x)
        x = self._move_complex_to_channels_dim(z)

        if self.debug:
            logger.info(f'x spec shape: {x.shape}')

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            x = encode(x, inject)
            if self.debug:
                logger.info(f'encoder {idx} out shape: {x.shape}')
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non-equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x = decode(x, skip, lengths.pop(-1))

            if self.debug:
                logger.info(f'decoder {idx} out shape: {x.shape}')

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0

        x = x.view(B, self.out_channels, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        if self.debug:
            logger.info(f'post view shape: {x.shape}')

        x_spec_complex = self._convert_to_complex(x)

        if self.debug:
            logger.info(f'x_spec_complex shape: {x_spec_complex.shape}')

        x = self._ispec(x_spec_complex)

        if self.debug:
            logger.info(f'hdemucs out shape: {x.shape}')

        x = x[..., :int(length * self.scale)]

        if self.debug:
            logger.info(f'hdemucs out - trimmed shape: {x.shape}')

        if return_spec:
            if return_lr_spec:
                return x, x_spec_complex, z
            else:
                return x, x_spec_complex

        return x
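As a shape sanity check, a small sketch (assuming the conf/experiment/aero_441-441_512_256.yaml values above and an untrained model, since no checkpoint ships with this commit) that instantiates Aero the same way modelFactory.get_model does and pushes one second of silence through it:

import torch
from yaml import safe_load
from src.models.aero import Aero

with open("conf/experiment/aero_441-441_512_256.yaml") as f:
    cfg = safe_load(f)["aero"]

model = Aero(**cfg).eval()      # in_channels=2, out_channels=2, lr_sr == hr_sr == 44100
x = torch.zeros(1, 2, 44100)    # [batch, channels, samples], one second at 44.1 kHz
with torch.no_grad():
    y = model(x)
print(y.shape)                  # expected torch.Size([1, 2, 44100]); scale is 1 since lr_sr == hr_sr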
src/models/modelFactory.py ADDED
@@ -0,0 +1,15 @@
from src.models.aero import Aero
from src.models.seanet import Seanet
from yaml import safe_load

def get_model(model_name="aero"):
    if model_name == 'aero':
        with open("conf/experiment/aero_441-441_512_256.yaml") as f:
            generator = Aero(**safe_load(f)["aero"])
    elif model_name == 'seanet':
        with open("conf/experiment/seanet_441-441.yaml") as f:
            generator = Seanet(**safe_load(f)["seanet"])

    models = {'generator': generator}

    return models
src/models/modules.py ADDED
@@ -0,0 +1,327 @@
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm

import math

from src.models.snake import Snake
from src.models.utils import unfold

import typing as tp

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, the input will be split into overlapping
    chunks and the LSTM applied separately on each chunk.
    """

    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x


class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).
    Also a failed experiment with trying to provide some frequency based attention.
    """

    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        # self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2] ** 0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            tmp = torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
            dots += tmp
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)

        result = result.reshape(B, -1, T)
        return x + self.proj(result)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales residual outputs diagonally, close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x


class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """

    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, time_attn=False, heads=4, ndecay=4, lstm=False,
                 act_func='gelu', freq_dim=None, reshape=False,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for LayerNorm.
            norm: use GroupNorm.
            time_attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            act_func: activation function ('gelu'/'snake'/'relu').
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """

        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        self.time_attn = time_attn
        self.lstm = lstm
        self.reshape = reshape
        self.act_func = act_func
        self.freq_dim = freq_dim

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        self.hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if act_func == 'gelu':
            act = nn.GELU
        elif act_func == 'snake':
            act = Snake
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            layer = nn.ModuleDict()
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            conv1 = nn.ModuleList([nn.Conv1d(channels, self.hidden, kernel, dilation=dilation, padding=padding),
                                   norm_fn(self.hidden)])
            act_layer = act(freq_dim) if act_func == 'snake' else act()
            conv2 = nn.ModuleList([nn.Conv1d(self.hidden, 2 * channels, 1),
                                   norm_fn(2 * channels), nn.GLU(1),
                                   LayerScale(channels, init)])

            layer.update({'conv1': nn.Sequential(*conv1), 'act': act_layer, 'conv2': nn.Sequential(*conv2)})
            if lstm:
                layer.update({'lstm': BLSTM(self.hidden, layers=2, max_steps=200, skip=True)})
            if time_attn:
                layer.update({'time_attn': LocalState(self.hidden, heads=heads, ndecay=ndecay)})

            self.layers.append(layer)

    def forward(self, x):

        if self.reshape:
            B, C, Fr, T = x.shape
            x = x.permute(0, 2, 1, 3).reshape(-1, C, T)

        for layer in self.layers:
            skip = x

            x = layer['conv1'](x)

            if self.act_func == 'snake' and self.reshape:
                x = x.view(B, Fr, self.hidden, T).permute(0, 2, 3, 1)
            x = layer['act'](x)
            if self.act_func == 'snake' and self.reshape:
                x = x.permute(0, 3, 1, 2).reshape(-1, self.hidden, T)

            if self.lstm:
                x = layer['lstm'](x)
            if self.time_attn:
                x = layer['time_attn'](x)

            x = layer['conv2'](x)
            x = skip + x

        if self.reshape:
            x = x.view(B, Fr, C, T).permute(0, 2, 1, 3)

        return x


class ScaledEmbedding(nn.Module):
    """
    Boost learning rate for embeddings (with `scale`).
    Also, can make embeddings continuous with `smooth`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussians, the overall scale grows as sqrt(n), so we normalize by that.
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        return self.embedding.weight * self.scale

    def forward(self, x):
        out = self.embedding(x) * self.scale
        return out


class FTB(nn.Module):

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        super(FTB, self).__init__()
        self.input_dim = input_dim
        self.in_channel = in_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel),
            nn.ReLU()
        )

        self.conv1d = nn.Sequential(
            nn.Conv1d(r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel),
            nn.ReLU()
        )
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel),
            nn.ReLU()
        )

    def forward(self, inputs):
        '''
        inputs should be [Batch, Ca, Dim, Time]
        '''
        # T-F attention
        conv1_out = self.conv1(inputs)
        B, C, D, T = conv1_out.size()
        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
        conv1d_out = self.conv1d(reshape1_out)
        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])

        # now is also [B,C,D,T]
        att_out = conv1d_out * inputs

        # transpose to [B,C,T,D]
        att_out = torch.transpose(att_out, 2, 3)
        freqfc_out = self.freq_fc(att_out)
        att_out = torch.transpose(freqfc_out, 2, 3)

        cat_out = torch.cat([att_out, inputs], 1)
        outputs = self.conv2(cat_out)
        return outputs
src/models/seanet.py ADDED
@@ -0,0 +1,177 @@
import torch.nn as nn
import math
from src.models.utils import capture_init, weights_init
from src.models.modules import WNConv1d, WNConvTranspose1d
from torchaudio.functional import resample
from torch.nn import functional as F

class ResnetBlock(nn.Module):
    def __init__(self, dim, dilation=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(dilation),
            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
            nn.LeakyReLU(0.2),
            WNConv1d(dim, dim, kernel_size=1),
        )
        self.shortcut = WNConv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)


class Seanet(nn.Module):

    @capture_init
    def __init__(self,
                 latent_space_size=128,
                 ngf=32, n_residual_layers=3,
                 resample=1,
                 normalize=True,
                 floor=1e-3,
                 ratios=[8, 8, 2, 2],
                 in_channels=1,
                 out_channels=1,
                 lr_sr=16000,
                 hr_sr=16000,
                 upsample=True):
        super().__init__()

        self.resample = resample
        self.normalize = normalize
        self.floor = floor
        self.lr_sr = lr_sr
        self.hr_sr = hr_sr
        self.scale_factor = int(self.hr_sr / self.lr_sr)
        self.upsample = upsample

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.ratios = ratios
        mult = int(2 ** len(ratios))

        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(latent_space_size, mult * ngf, kernel_size=7, padding=0),
        ]

        encoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(mult * ngf, latent_space_size, kernel_size=7, padding=0)
        ]

        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        for i, r in enumerate(ratios):
            encoder_block = [
                nn.LeakyReLU(0.2),
                WNConv1d(mult * ngf // 2,
                         mult * ngf,
                         kernel_size=r * 2,
                         stride=r,
                         padding=r // 2 + r % 2,
                         ),
            ]

            decoder_block = [
                nn.LeakyReLU(0.2),
                WNConvTranspose1d(
                    mult * ngf,
                    mult * ngf // 2,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2,
                    output_padding=r % 2,
                ),
            ]

            for j in range(n_residual_layers - 1, -1, -1):
                encoder_block = [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] + encoder_block

            for j in range(n_residual_layers):
                decoder_block += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]

            mult //= 2

            self.encoder.insert(0, nn.Sequential(*encoder_block))
            self.decoder.append(nn.Sequential(*decoder_block))

        encoder_wrapper_conv_layer = [
            nn.ReflectionPad1d(3),
            WNConv1d(in_channels, ngf, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.encoder.insert(0, nn.Sequential(*encoder_wrapper_conv_layer))

        decoder_wrapper_conv_layer = [
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            WNConv1d(ngf, out_channels, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.decoder.append(nn.Sequential(*decoder_wrapper_conv_layer))

        self.apply(weights_init)

    def estimate_output_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in the convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        depth = len(self.ratios)
        for idx in range(depth - 1, -1, -1):
            stride = self.ratios[idx]
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            length = math.ceil((length - kernel_size + 2 * padding) / stride) + 1
            length = max(length, 1)
        for idx in range(depth):
            stride = self.ratios[idx]
            kernel_size = 2 * stride
            padding = stride // 2 + stride % 2
            output_padding = stride % 2
            length = (length - 1) * stride + kernel_size - 2 * padding + output_padding
        return int(length)

    def pad_to_valid_length(self, signal):
        valid_length = self.estimate_output_length(signal.shape[-1])
        padding_len = valid_length - signal.shape[-1]
        signal = F.pad(signal, (0, padding_len))
        return signal, padding_len

    def forward(self, signal):

        target_len = signal.shape[-1]
        if self.upsample:
            target_len *= self.scale_factor
        if self.normalize:
            mono = signal.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            signal = signal / (self.floor + std)
        else:
            std = 1
        x = signal
        if self.upsample:
            x = resample(x, self.lr_sr, self.hr_sr)

        x, padding_len = self.pad_to_valid_length(x)
        skips = []
        for i, encode in enumerate(self.encoder):
            skips.append(x)
            x = encode(x)
        for j, decode in enumerate(self.decoder):
            x = decode(x)
            skip = skips.pop(-1)
            x = x + skip
        if target_len < x.shape[-1]:
            x = x[..., :target_len]
        return std * x
src/models/snake.py ADDED
@@ -0,0 +1,66 @@
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
from torch.distributions.exponential import Exponential

class Snake(nn.Module):
    '''
    Implementation of the serpentine-like sine-based periodic activation function:
    .. math::
        Snake_a := x + \frac{1}{a} \sin^2(ax) = x - \frac{1}{2a}\cos(2ax) + \frac{1}{2a}
    This activation function is able to better extrapolate to previously unseen data,
    especially in the case of learning periodic functions.

    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input

    Parameters:
        - a - trainable parameter

    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, a=None, trainable=True):
        '''
        Initialization.
        Args:
            in_features: shape of the input
            a: trainable parameter
            trainable: sets `a` as a trainable parameter

            `a` is initialized to 1 by default; higher values = higher frequency.
            5-50 is a good starting point if you already think your data is periodic;
            consider starting lower (e.g. 0.5) if you think it is not. Either way,
            `a` will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features if isinstance(in_features, list) else [in_features]

        # Initialize `a`
        if a is not None:
            self.a = Parameter(torch.ones(self.in_features) * a)  # create a tensor out of alpha
        else:
            m = Exponential(torch.tensor([0.1]))
            self.a = Parameter((m.rsample(self.in_features)).squeeze())  # random init = mix of frequencies

        self.a.requires_grad = trainable  # set whether `a` is trained

    def extra_repr(self) -> str:
        return 'in_features={}'.format(self.in_features)

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2(ax)
        '''
        return x + (1.0 / self.a) * pow(sin(x * self.a), 2)
src/models/spec.py ADDED
@@ -0,0 +1,39 @@
"""
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
"""
"""Convenience wrapper to perform STFT and iSTFT"""

import torch as th


def spectro(x, n_fft=512, hop_length=None, pad=0, win_length=None):
    *other, length = x.shape
    x = x.reshape(-1, length)
    z = th.stft(x,
                n_fft * (1 + pad),
                hop_length or n_fft // 4,
                window=th.hann_window(win_length or n_fft).to(x),  # fall back to n_fft when win_length is not given
                win_length=win_length or n_fft,
                normalized=True,
                center=True,
                return_complex=True,
                pad_mode='reflect')
    _, freqs, frame = z.shape
    return z.view(*other, freqs, frame)


def ispectro(z, hop_length=None, length=None, pad=0, win_length=None):
    *other, freqs, frames = z.shape
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = win_length or n_fft // (1 + pad)
    x = th.istft(z,
                 n_fft,
                 hop_length or n_fft // 2,
                 window=th.hann_window(win_length).to(z.real),
                 win_length=win_length,
                 normalized=True,
                 length=length,
                 center=True)
    _, length = x.shape
    return x.view(*other, length)
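A quick round-trip sketch of the two helpers, with parameter values chosen to mirror how aero.py calls them (this is only an illustration of the expected shapes and reconstruction error, not part of the commit):

import torch as th
from src.models.spec import spectro, ispectro

x = th.randn(2, 44100)                                     # [channels, samples]
z = spectro(x, n_fft=512, hop_length=256, win_length=512)  # complex [channels, freqs=257, frames]
y = ispectro(z, hop_length=256, length=x.shape[-1], win_length=512)

print(z.shape, y.shape)               # [2, 257, frames] and [2, 44100]
print((x - y).abs().max())            # reconstruction error on the order of float32 rounding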
src/models/utils.py ADDED
@@ -0,0 +1,50 @@
import functools
import math

from torch.nn import functional as F
import torchaudio


def capture_init(init):
    """capture_init.

    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_args_kwargs`
    """

    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / stride)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def write(wav, filename, sr):
    # Normalize audio to prevent clipping
    wav = wav / max(wav.abs().max().item(), 1)
    torchaudio.save(filename, wav.cpu(), sr)
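unfold is easiest to see on a tiny tensor; a short sketch of the framing it performs (stride 2, kernel 4, with zero padding on the right so the last frame is full length):

import torch
from src.models.utils import unfold

a = torch.arange(10.)
frames = unfold(a, kernel_size=4, stride=2)
print(frames.shape)   # torch.Size([5, 4]) -- ceil(10 / 2) frames of 4 samples each
print(frames)
# tensor([[0., 1., 2., 3.],
#         [2., 3., 4., 5.],
#         [4., 5., 6., 7.],
#         [6., 7., 8., 9.],
#         [8., 9., 0., 0.]])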
src/utils.py ADDED
@@ -0,0 +1,52 @@
import functools
import math

from torch.nn import functional as F

from contextlib import contextmanager


def capture_init(init):
    """capture_init.

    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_args_kwargs`
    """

    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / stride)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)

def colorize(text, color):
    """
    Display text with some ANSI color in the terminal.
    """
    code = f"\033[{color}m"
    restore = "\033[0m"
    return "".join([code, text, restore])


def bold(text):
    """
    Display text in bold in the terminal.
    """
    return colorize(text, "1")