jlmarrugom committed
Commit 5ec3488 · 1 Parent(s): 446d1e4

Upload 36 files
main.py ADDED
@@ -0,0 +1,37 @@
+ from voicefixer.base import VoiceFixer
+ import streamlit as st
+ from audio_recorder_streamlit import audio_recorder
+ from io import BytesIO
+ import soundfile as sf
+
+ st.set_page_config(page_title="VoiceFixer app", page_icon=":notes:")
+ st.title("Voice Fixer App :notes:")
+ st.write(
+     """
+     This app combines the [VoiceFixer model](https://github.com/haoheliu/voicefixer) with a custom
+     Streamlit component that [records audio](https://github.com/Joooohan/audio-recorder-streamlit) online.
+     The app currently shows strong results when removing background noise, but
+     speech enhancement is less noticeable.
+     """)
+ # Config files live under voicefixer/base and voicefixer/vocoder/config;
+ # the checkpoints were uploaded to the Hugging Face Hub.
+ voicefixer = VoiceFixer()
+ audio_bytes = audio_recorder(
+     pause_threshold=1.5
+ )
+ try:
+     data, samplerate = sf.read(BytesIO(audio_bytes))
+     print(samplerate)
+     sf.write("original.wav", data, samplerate)
+     st.audio(audio_bytes, format="audio/wav")
+     if data.shape[0] >= 10000:
+         voicefixer.restore(input="original.wav",  # low quality .wav/.flac file
+                            output="enhanced_output.wav",
+                            cuda=False,  # GPU acceleration
+                            mode=0)
+         st.write("The audio without background noise, slightly enhanced :ocean:")
+         st.audio("enhanced_output.wav")
+
+     else: st.warning("The recorded audio is too short, try again :relieved:")
+ except Exception:
+     st.info("Try recording some audio :relieved:")
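For reference, the same restore() call used by the app can be exercised outside Streamlit. A minimal sketch, assuming a local recording exists on disk (the file names here are only illustrative):

    from voicefixer.base import VoiceFixer

    voicefixer = VoiceFixer()
    # mode=0 is the default restoration mode; cuda=False keeps everything on CPU
    voicefixer.restore(input="input.wav", output="restored.wav", cuda=False, mode=0)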
packages.txt ADDED
@@ -0,0 +1,2 @@
+ ffmpeg
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ audio_recorder_streamlit>=0.0.7
+ soundfile>=0.9.0
+ huggingface-hub>=0.11.1
+ librosa>=0.8.1,<0.9.0
+ torch>=1.7.0
+ matplotlib
+ progressbar
+ torchlibrosa==0.0.7
+ GitPython
+ streamlit>=1.12.0
+ pyyaml
voicefixer/__init__.py ADDED
@@ -0,0 +1,14 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ """
+ @File    : __init__.py
+ @Contact : [email protected]
+ @License : (C)Copyright 2020-2100
+
+ @Modify Time      @Author     @Version    @Description
+ ------------      -------     --------    -----------
+ 9/14/21 12:31 AM  Haohe Liu   1.0         None
+ """
+
+ from voicefixer.vocoder.base import Vocoder
+ from voicefixer.base import VoiceFixer
voicefixer/__main__.py ADDED
@@ -0,0 +1,170 @@
+ #!/usr/bin/python3
+ from genericpath import exists
+ import os.path
+ import argparse
+ from voicefixer import VoiceFixer
+ import torch
+ import os
+
+
+ def writefile(infile, outfile, mode, append_mode, cuda, verbose=False):
+     if append_mode is True:
+         outbasename, outext = os.path.splitext(os.path.basename(outfile))
+         outfile = os.path.join(
+             os.path.dirname(outfile), "{}-mode{}{}".format(outbasename, mode, outext)
+         )
+     if verbose:
+         print("Processing {}, mode={}".format(infile, mode))
+     voicefixer.restore(input=infile, output=outfile, cuda=cuda, mode=int(mode))
+
+ def check_arguments(args):
+     process_file, process_folder = len(args.infile) != 0, len(args.infolder) != 0
+     # assert len(args.infile) == 0 and len(args.outfile) == 0 or process_file, \
+     #     "Error: You should give the input and output file path at the same time. The input and output file path we receive is %s and %s" % (args.infile, args.outfile)
+     # assert len(args.infolder) == 0 and len(args.outfolder) == 0 or process_folder, \
+     #     "Error: You should give the input and output folder path at the same time. The input and output folder path we receive is %s and %s" % (args.infolder, args.outfolder)
+     assert (
+         process_file or process_folder
+     ), "Error: You need to specify an input file path (--infile) or an input folder path (--infolder) to proceed. For more information please run: voicefixer -h"
+
+     # if(args.cuda and not torch.cuda.is_available()):
+     #     print("Warning: You set --cuda while no cuda device found on your machine. We will use CPU instead.")
+     if process_file:
+         assert os.path.exists(args.infile), (
+             "Error: The input file %s is not found." % args.infile
+         )
+         output_dirname = os.path.dirname(args.outfile)
+         if len(output_dirname) > 1:
+             os.makedirs(output_dirname, exist_ok=True)
+     if process_folder:
+         assert os.path.exists(args.infolder), (
+             "Error: The input folder %s is not found." % args.infolder
+         )
+         output_dirname = args.outfolder
+         if len(output_dirname) > 1:
+             os.makedirs(args.outfolder, exist_ok=True)
+
+     return process_file, process_folder
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="VoiceFixer - restores degraded speech"
+     )
+     parser.add_argument(
+         "-i",
+         "--infile",
+         type=str,
+         default="",
+         help="An input file to be processed by VoiceFixer.",
+     )
+     parser.add_argument(
+         "-o",
+         "--outfile",
+         type=str,
+         default="outfile.wav",
+         help="An output file to store the result.",
+     )
+
+     parser.add_argument(
+         "-ifdr",
+         "--infolder",
+         type=str,
+         default="",
+         help="Input folder. Place all the .wav files that need processing in this folder.",
+     )
+     parser.add_argument(
+         "-ofdr",
+         "--outfolder",
+         type=str,
+         default="outfolder",
+         help="Output folder. The processed files will be stored in this folder.",
+     )
+
+     parser.add_argument(
+         "--mode", help="mode", choices=["0", "1", "2", "all"], default="0"
+     )
+     parser.add_argument('--disable-cuda', help='Set this flag if you do not want to use your GPU.', default=False, action="store_true")
+     parser.add_argument(
+         "--silent",
+         help="Set this flag if you do not want to see any message.",
+         default=False,
+         action="store_true",
+     )
+
+     args = parser.parse_args()
+
+     if torch.cuda.is_available() and not args.disable_cuda:
+         cuda = True
+     else:
+         cuda = False
+
+     process_file, process_folder = check_arguments(args)
+
+     if not args.silent:
+         print("Initializing VoiceFixer")
+     voicefixer = VoiceFixer()
+
+     if not args.silent:
+         print("Start processing the input file %s." % args.infile)
+
+     if process_file:
+         audioext = os.path.splitext(os.path.basename(args.infile))[-1]
+         if audioext != ".wav":
+             raise ValueError(
+                 "Error: Cannot process the input file. We only support the .wav format currently. Please convert your %s file to .wav. Thanks."
+                 % audioext
+             )
+         if args.mode == "all":
+             for file_mode in range(3):
+                 writefile(
+                     args.infile,
+                     args.outfile,
+                     file_mode,
+                     True,
+                     cuda,
+                     verbose=not args.silent,
+                 )
+         else:
+             writefile(
+                 args.infile,
+                 args.outfile,
+                 args.mode,
+                 False,
+                 cuda,
+                 verbose=not args.silent,
+             )
+
+     if process_folder:
+         if not args.silent:
+             files = [
+                 file
+                 for file in os.listdir(args.infolder)
+                 if (os.path.splitext(os.path.basename(file))[-1] == ".wav")
+             ]
+             print(
+                 "Found %s .wav files in the input folder %s. Start processing."
+                 % (len(files), args.infolder)
+             )
+         for file in os.listdir(args.infolder):
+             outbasename, outext = os.path.splitext(os.path.basename(file))
+             in_file = os.path.join(args.infolder, file)
+             out_file = os.path.join(args.outfolder, file)
+
+             if args.mode == "all":
+                 for file_mode in range(3):
+                     writefile(
+                         in_file,
+                         out_file,
+                         file_mode,
+                         True,
+                         cuda,
+                         verbose=not args.silent,
+                     )
+             else:
+                 writefile(
+                     in_file, out_file, args.mode, False, cuda, verbose=not args.silent
+                 )
+
+     if not args.silent:
+         print("Done")
voicefixer/base.py ADDED
@@ -0,0 +1,145 @@
+ import librosa.display
+ from voicefixer.tools.pytorch_util import *
+ from voicefixer.tools.wav import *
+ from voicefixer.restorer.model import VoiceFixer as voicefixer_fe
+ import os
+ from huggingface_hub import hf_hub_download
+
+ path_to_ckpt = hf_hub_download(repo_id="jlmarrugom/voice_fixer", filename="vf.ckpt")
+
+
+ EPS = 1e-8
+
+
+ class VoiceFixer(nn.Module):
+     def __init__(self):
+         super(VoiceFixer, self).__init__()
+         self._model = voicefixer_fe(channels=2, sample_rate=44100)
+         # print(os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/epoch=15_trimed_bn.ckpt"))
+         self.analysis_module_ckpt = path_to_ckpt  # "models/vf.ckpt"
+         if not os.path.exists(self.analysis_module_ckpt):
+             raise RuntimeError("Error 0: The checkpoint for the analysis module (vf.ckpt) was not found. \
+                 By default the checkpoint should be downloaded automatically by this program, so something may have gone wrong. \
+                 Alternatively, you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/vf.ckpt?download=1.")
+         self._model.load_state_dict(
+             torch.load(
+                 self.analysis_module_ckpt
+             )
+         )
+         self._model.eval()
+
+     def _load_wav_energy(self, path, sample_rate, threshold=0.95):
+         wav_10k, _ = librosa.load(path, sr=sample_rate)
+         stft = np.log10(np.abs(librosa.stft(wav_10k)) + 1.0)
+         fbins = stft.shape[0]
+         e_stft = np.sum(stft, axis=1)
+         for i in range(e_stft.shape[0]):
+             e_stft[-i - 1] = np.sum(e_stft[: -i - 1])
+         total = e_stft[-1]
+         for i in range(e_stft.shape[0]):
+             if e_stft[i] < total * threshold:
+                 continue
+             else:
+                 break
+         return wav_10k, int((sample_rate // 2) * (i / fbins))
+
+     def _load_wav(self, path, sample_rate, threshold=0.95):
+         wav_10k, _ = librosa.load(path, sr=sample_rate)
+         return wav_10k
+
+     def _amp_to_original_f(self, mel_sp_est, mel_sp_target, cutoff=0.2):
+         freq_dim = mel_sp_target.size()[-1]
+         mel_sp_est_low, mel_sp_target_low = (
+             mel_sp_est[..., 5 : int(freq_dim * cutoff)],
+             mel_sp_target[..., 5 : int(freq_dim * cutoff)],
+         )
+         energy_est, energy_target = torch.mean(mel_sp_est_low, dim=(2, 3)), torch.mean(
+             mel_sp_target_low, dim=(2, 3)
+         )
+         amp_ratio = energy_target / energy_est
+         return mel_sp_est * amp_ratio[..., None, None], mel_sp_target
+
+     def _trim_center(self, est, ref):
+         diff = np.abs(est.shape[-1] - ref.shape[-1])
+         if est.shape[-1] == ref.shape[-1]:
+             return est, ref
+         elif est.shape[-1] > ref.shape[-1]:
+             min_len = min(est.shape[-1], ref.shape[-1])
+             est, ref = est[..., int(diff // 2) : -int(diff // 2)], ref
+             est, ref = est[..., :min_len], ref[..., :min_len]
+             return est, ref
+         else:
+             min_len = min(est.shape[-1], ref.shape[-1])
+             est, ref = est, ref[..., int(diff // 2) : -int(diff // 2)]
+             est, ref = est[..., :min_len], ref[..., :min_len]
+             return est, ref
+
+     def _pre(self, model, input, cuda):
+         input = input[None, None, ...]
+         input = torch.tensor(input)
+         input = try_tensor_cuda(input, cuda=cuda)
+         sp, _, _ = model.f_helper.wav_to_spectrogram_phase(input)
+         mel_orig = model.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
+         # return models.to_log(sp), models.to_log(mel_orig)
+         return sp, mel_orig
+
+     def remove_higher_frequency(self, wav, ratio=0.95):
+         stft = librosa.stft(wav)
+         real, img = np.real(stft), np.imag(stft)
+         mag = (real**2 + img**2) ** 0.5
+         cos, sin = real / (mag + EPS), img / (mag + EPS)
+         spec = np.abs(stft)  # [1025, T]
+         feature = spec.copy()
+         feature = np.log10(feature + EPS)
+         feature[feature < 0] = 0
+         energy_level = np.sum(feature, axis=1)
+         threshold = np.sum(energy_level) * ratio
+         curent_level, i = energy_level[0], 0
+         while i < energy_level.shape[0] and curent_level < threshold:
+             curent_level += energy_level[i + 1, ...]
+             i += 1
+         spec[i:, ...] = np.zeros_like(spec[i:, ...])
+         stft = spec * cos + 1j * spec * sin
+         return librosa.istft(stft)
+
+     @torch.no_grad()
+     def restore_inmem(self, wav_10k, cuda=False, mode=0, your_vocoder_func=None):
+         check_cuda_availability(cuda=cuda)
+         self._model = try_tensor_cuda(self._model, cuda=cuda)
+         if mode == 0:
+             self._model.eval()
+         elif mode == 1:
+             self._model.eval()
+         elif mode == 2:
+             self._model.train()  # More effective on seriously damaged speech
+         res = []
+         seg_length = 44100 * 30
+         break_point = seg_length
+         while break_point < wav_10k.shape[0] + seg_length:
+             segment = wav_10k[break_point - seg_length : break_point]
+             if mode == 1:
+                 segment = self.remove_higher_frequency(segment)
+             sp, mel_noisy = self._pre(self._model, segment, cuda)
+             out_model = self._model(sp, mel_noisy)
+             denoised_mel = from_log(out_model["mel"])
+             if your_vocoder_func is None:
+                 out = self._model.vocoder(denoised_mel, cuda=cuda)
+             else:
+                 out = your_vocoder_func(denoised_mel)
+             # unify energy
+             if torch.max(torch.abs(out)) > 1.0:
+                 out = out / torch.max(torch.abs(out))
+                 print("Warning: Exceed energy limit,", input)
+             # frame alignment
+             out, _ = self._trim_center(out, segment)
+             res.append(out)
+             break_point += seg_length
+         out = torch.cat(res, -1)
+         return tensor2numpy(out.squeeze(0))
+
+     def restore(self, input, output, cuda=False, mode=0, your_vocoder_func=None):
+         wav_10k = self._load_wav(input, sample_rate=44100)
+         out_np_wav = self.restore_inmem(
+             wav_10k, cuda=cuda, mode=mode, your_vocoder_func=your_vocoder_func
+         )
+         save_wave(out_np_wav, fname=output, sample_rate=44100)
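restore_inmem() walks the recording in 30-second windows at 44.1 kHz and concatenates the restored segments. A stripped-down sketch of that windowing loop on plain numpy arrays (the identity "process" function stands in for the model):

    import numpy as np

    def segment_30s(wav, sr=44100, process=lambda x: x):
        seg_length = sr * 30
        break_point = seg_length
        pieces = []
        while break_point < wav.shape[0] + seg_length:
            # the final window may be shorter than 30 s; slicing simply clamps it
            pieces.append(process(wav[break_point - seg_length : break_point]))
            break_point += seg_length
        return np.concatenate(pieces, axis=-1)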
voicefixer/restorer/__init__.py ADDED
@@ -0,0 +1,44 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ """
+ @File    : __init__.py
+ @Contact : [email protected]
+ @License : (C)Copyright 2020-2100
+
+ @Modify Time      @Author     @Version    @Description
+ ------------      -------     --------    -----------
+ 9/14/21 12:31 AM  Haohe Liu   1.0         None
+ """
+
+ import os
+ import torch
+ import urllib.request
+
+ meta = {
+     "voicefixer_fe": {
+         "path": os.path.join(
+             os.path.expanduser("~"),
+             ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt",
+         ),
+         "url": "https://zenodo.org/record/5600188/files/vf.ckpt?download=1",
+     },
+ }
+
+ if not os.path.exists(meta["voicefixer_fe"]["path"]):
+     os.makedirs(os.path.dirname(meta["voicefixer_fe"]["path"]), exist_ok=True)
+     print("Downloading the main structure of voicefixer")
+
+     urllib.request.urlretrieve(
+         meta["voicefixer_fe"]["url"], meta["voicefixer_fe"]["path"]
+     )
+     print(
+         "Weights downloaded in: {} Size: {}".format(
+             meta["voicefixer_fe"]["path"],
+             os.path.getsize(meta["voicefixer_fe"]["path"]),
+         )
+     )
+
+ # cmd = "wget " + meta["voicefixer_fe"]['url'] + " -O " + meta["voicefixer_fe"]['path']
+ # os.system(cmd)
+ # temp = torch.load(meta["voicefixer_fe"]['path'])
+ # torch.save(temp['state_dict'], os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt"))
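This module falls back to downloading vf.ckpt from Zenodo into ~/.cache, while the app's voicefixer/base.py pulls the same checkpoint from the Hugging Face Hub. A hedged sketch for inspecting what a downloaded checkpoint actually contains before calling load_state_dict (whether the weights are wrapped in a 'state_dict' entry depends on how the upstream run saved them, as the commented-out torch.save above suggests):

    import os
    import torch

    ckpt_path = os.path.join(
        os.path.expanduser("~"), ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt"
    )
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # unwrap a possible 'state_dict' entry before load_state_dict
    state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    print(len(state_dict), "entries in the checkpoint")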
voicefixer/restorer/model.py ADDED
@@ -0,0 +1,680 @@
1
+ # import pytorch_lightning as pl
2
+
3
+ import torch.utils
4
+ from voicefixer.tools.mel_scale import MelScale
5
+ import torch.utils.data
6
+ import matplotlib.pyplot as plt
7
+ import librosa.display
8
+ from voicefixer.vocoder.base import Vocoder
9
+ from voicefixer.tools.pytorch_util import *
10
+ from voicefixer.restorer.model_kqq_bn import UNetResComplex_100Mb
11
+ from voicefixer.tools.random_ import *
12
+ from voicefixer.tools.wav import *
13
+ from voicefixer.tools.modules.fDomainHelper import FDomainHelper
14
+
15
+ from voicefixer.tools.io import load_json, write_json
16
+ from matplotlib import cm
17
+
18
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
19
+ EPS = 1e-8
20
+
21
+
22
+ class BN_GRU(torch.nn.Module):
23
+ def __init__(
24
+ self,
25
+ input_dim,
26
+ hidden_dim,
27
+ layer=1,
28
+ bidirectional=False,
29
+ batchnorm=True,
30
+ dropout=0.0,
31
+ ):
32
+ super(BN_GRU, self).__init__()
33
+ self.batchnorm = batchnorm
34
+ if batchnorm:
35
+ self.bn = nn.BatchNorm2d(1)
36
+ self.gru = torch.nn.GRU(
37
+ input_size=input_dim,
38
+ hidden_size=hidden_dim,
39
+ num_layers=layer,
40
+ bidirectional=bidirectional,
41
+ dropout=dropout,
42
+ batch_first=True,
43
+ )
44
+ self.init_weights()
45
+
46
+ def init_weights(self):
47
+ for m in self.modules():
48
+ if type(m) in [nn.GRU, nn.LSTM, nn.RNN]:
49
+ for name, param in m.named_parameters():
50
+ if "weight_ih" in name:
51
+ torch.nn.init.xavier_uniform_(param.data)
52
+ elif "weight_hh" in name:
53
+ torch.nn.init.orthogonal_(param.data)
54
+ elif "bias" in name:
55
+ param.data.fill_(0)
56
+
57
+ def forward(self, inputs):
58
+ # (batch, 1, seq, feature)
59
+ if self.batchnorm:
60
+ inputs = self.bn(inputs)
61
+ out, _ = self.gru(inputs.squeeze(1))
62
+ return out.unsqueeze(1)
63
+
64
+
65
+ class Generator(nn.Module):
66
+ def __init__(self, n_mel, hidden, channels):
67
+ super(Generator, self).__init__()
68
+ # todo the currently running trail don't have dropout
69
+ self.denoiser = nn.Sequential(
70
+ nn.BatchNorm2d(1),
71
+ nn.Linear(n_mel, n_mel * 2),
72
+ nn.ReLU(inplace=True),
73
+ nn.BatchNorm2d(1),
74
+ nn.Linear(n_mel * 2, n_mel * 4),
75
+ nn.Dropout(0.5),
76
+ nn.ReLU(inplace=True),
77
+ BN_GRU(
78
+ input_dim=n_mel * 4,
79
+ hidden_dim=n_mel * 2,
80
+ bidirectional=True,
81
+ layer=2,
82
+ batchnorm=True,
83
+ ),
84
+ BN_GRU(
85
+ input_dim=n_mel * 4,
86
+ hidden_dim=n_mel * 2,
87
+ bidirectional=True,
88
+ layer=2,
89
+ batchnorm=True,
90
+ ),
91
+ nn.BatchNorm2d(1),
92
+ nn.ReLU(inplace=True),
93
+ nn.Linear(n_mel * 4, n_mel * 4),
94
+ nn.Dropout(0.5),
95
+ nn.BatchNorm2d(1),
96
+ nn.ReLU(inplace=True),
97
+ nn.Linear(n_mel * 4, n_mel),
98
+ nn.Sigmoid(),
99
+ )
100
+
101
+ self.unet = UNetResComplex_100Mb(channels=channels)
102
+
103
+ def forward(self, sp, mel_orig):
104
+ # Denoising
105
+ noisy = mel_orig.clone()
106
+ clean = self.denoiser(noisy) * noisy
107
+ x = to_log(clean.detach())
108
+ unet_in = torch.cat([to_log(mel_orig), x], dim=1)
109
+ # unet_in = lstm_out
110
+ unet_out = self.unet(unet_in)["mel"]
111
+ # masks
112
+ mel = unet_out + x
113
+ # todo mel and addition here are in log scales
114
+ return {
115
+ "mel": mel,
116
+ "lstm_out": unet_out,
117
+ "unet_out": unet_out,
118
+ "noisy": noisy,
119
+ "clean": clean,
120
+ }
121
+
122
+
123
+ class VoiceFixer(nn.Module):
124
+ def __init__(
125
+ self,
126
+ channels,
127
+ type_target="vocals",
128
+ nsrc=1,
129
+ loss="l1",
130
+ lr=0.002,
131
+ gamma=0.9,
132
+ batchsize=None,
133
+ frame_length=None,
134
+ sample_rate=None,
135
+ warm_up_steps=1000,
136
+ reduce_lr_steps=15000,
137
+ # datas
138
+ check_val_every_n_epoch=5,
139
+ ):
140
+ super(VoiceFixer, self).__init__()
141
+
142
+ if sample_rate == 44100:
143
+ window_size = 2048
144
+ hop_size = 441
145
+ n_mel = 128
146
+ elif sample_rate == 24000:
147
+ window_size = 768
148
+ hop_size = 240
149
+ n_mel = 80
150
+ elif sample_rate == 16000:
151
+ window_size = 512
152
+ hop_size = 160
153
+ n_mel = 80
154
+ else:
155
+ raise ValueError(
156
+ "Error: Sample rate " + str(sample_rate) + " not supported"
157
+ )
158
+
159
+ center = (True,)
160
+ pad_mode = "reflect"
161
+ window = "hann"
162
+ freeze_parameters = True
163
+
164
+ # self.save_hyperparameters()
165
+ self.nsrc = nsrc
166
+ self.type_target = type_target
167
+ self.channels = channels
168
+ self.lr = lr
169
+ self.generated = None
170
+ self.gamma = gamma
171
+ self.sample_rate = sample_rate
172
+ self.sample_rate = sample_rate
173
+ self.batchsize = batchsize
174
+ self.frame_length = frame_length
175
+ # self.hparams['channels'] = 2
176
+
177
+ # self.am = AudioMetrics()
178
+ # self.im = ImgMetrics()
179
+
180
+ self.vocoder = Vocoder(sample_rate=44100)
181
+
182
+ self.valid = None
183
+ self.fake = None
184
+
185
+ self.train_step = 0
186
+ self.val_step = 0
187
+ self.val_result_save_dir = None
188
+ self.val_result_save_dir_step = None
189
+ self.downsample_ratio = 2**6 # This number equals 2^{#encoder_blcoks}
190
+ self.check_val_every_n_epoch = check_val_every_n_epoch
191
+
192
+ self.f_helper = FDomainHelper(
193
+ window_size=window_size,
194
+ hop_size=hop_size,
195
+ center=center,
196
+ pad_mode=pad_mode,
197
+ window=window,
198
+ freeze_parameters=freeze_parameters,
199
+ )
200
+
201
+ hidden = window_size // 2 + 1
202
+
203
+ self.mel = MelScale(n_mels=n_mel, sample_rate=sample_rate, n_stft=hidden)
204
+
205
+ # masking
206
+ self.generator = Generator(n_mel, hidden, channels)
207
+
208
+ self.lr_lambda = lambda step: self.get_lr_lambda(
209
+ step,
210
+ gamma=self.gamma,
211
+ warm_up_steps=warm_up_steps,
212
+ reduce_lr_steps=reduce_lr_steps,
213
+ )
214
+
215
+ self.lr_lambda_2 = lambda step: self.get_lr_lambda(
216
+ step, gamma=self.gamma, warm_up_steps=10, reduce_lr_steps=reduce_lr_steps
217
+ )
218
+
219
+ self.mel_weight_44k_128 = (
220
+ torch.tensor(
221
+ [
222
+ 19.40951426,
223
+ 19.94047336,
224
+ 20.4859038,
225
+ 21.04629067,
226
+ 21.62194148,
227
+ 22.21335214,
228
+ 22.8210215,
229
+ 23.44529231,
230
+ 24.08660962,
231
+ 24.74541882,
232
+ 25.42234287,
233
+ 26.11770576,
234
+ 26.83212784,
235
+ 27.56615283,
236
+ 28.32007747,
237
+ 29.0947679,
238
+ 29.89060111,
239
+ 30.70832636,
240
+ 31.54828121,
241
+ 32.41121487,
242
+ 33.29780773,
243
+ 34.20865341,
244
+ 35.14437675,
245
+ 36.1056621,
246
+ 37.09332763,
247
+ 38.10795802,
248
+ 39.15039691,
249
+ 40.22119881,
250
+ 41.32154931,
251
+ 42.45172373,
252
+ 43.61293329,
253
+ 44.80609379,
254
+ 46.031602,
255
+ 47.29070223,
256
+ 48.58427549,
257
+ 49.91327905,
258
+ 51.27863232,
259
+ 52.68119708,
260
+ 54.1222372,
261
+ 55.60274206,
262
+ 57.12364703,
263
+ 58.68617876,
264
+ 60.29148652,
265
+ 61.94081306,
266
+ 63.63501986,
267
+ 65.37562658,
268
+ 67.16408954,
269
+ 69.00109084,
270
+ 70.88850318,
271
+ 72.82736101,
272
+ 74.81985537,
273
+ 76.86654792,
274
+ 78.96885475,
275
+ 81.12900906,
276
+ 83.34840929,
277
+ 85.62810662,
278
+ 87.97005418,
279
+ 90.37689804,
280
+ 92.84887686,
281
+ 95.38872881,
282
+ 97.99777002,
283
+ 100.67862715,
284
+ 103.43232942,
285
+ 106.26140638,
286
+ 109.16827015,
287
+ 112.15470471,
288
+ 115.22184756,
289
+ 118.37439245,
290
+ 121.6122689,
291
+ 124.93877158,
292
+ 128.35661454,
293
+ 131.86761321,
294
+ 135.47417938,
295
+ 139.18059494,
296
+ 142.98713744,
297
+ 146.89771854,
298
+ 150.91684347,
299
+ 155.0446638,
300
+ 159.28614648,
301
+ 163.64270198,
302
+ 168.12035831,
303
+ 172.71749158,
304
+ 177.44220154,
305
+ 182.29556933,
306
+ 187.28286676,
307
+ 192.40502126,
308
+ 197.6682721,
309
+ 203.07516896,
310
+ 208.63088733,
311
+ 214.33770931,
312
+ 220.19910108,
313
+ 226.22363072,
314
+ 232.41087124,
315
+ 238.76803591,
316
+ 245.30079083,
317
+ 252.01064464,
318
+ 258.90261676,
319
+ 265.98474,
320
+ 273.26010248,
321
+ 280.73496362,
322
+ 288.41440094,
323
+ 296.30489752,
324
+ 304.41180337,
325
+ 312.7377183,
326
+ 321.28877878,
327
+ 330.07870237,
328
+ 339.10812951,
329
+ 348.38276173,
330
+ 357.91393924,
331
+ 367.70513992,
332
+ 377.76413924,
333
+ 388.09467408,
334
+ 398.70920178,
335
+ 409.61813793,
336
+ 420.81980127,
337
+ 432.33215467,
338
+ 444.16083117,
339
+ 456.30919947,
340
+ 468.78589276,
341
+ 481.61325588,
342
+ 494.78824596,
343
+ 508.31969844,
344
+ 522.2238331,
345
+ 536.51163441,
346
+ 551.18859414,
347
+ 566.26142988,
348
+ 581.75006061,
349
+ 597.66210737,
350
+ ]
351
+ )
352
+ / 19.40951426
353
+ )
354
+ self.mel_weight_44k_128 = self.mel_weight_44k_128[None, None, None, ...]
355
+
356
+ self.g_loss_weight = 0.01
357
+ self.d_loss_weight = 1
358
+
359
+ def get_vocoder(self):
360
+ return self.vocoder
361
+
362
+ def get_f_helper(self):
363
+ return self.f_helper
364
+
365
+ def get_lr_lambda(self, step, gamma, warm_up_steps, reduce_lr_steps):
366
+ r"""Get lr_lambda for LambdaLR. E.g.,
367
+
368
+ .. code-block: python
369
+ lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)
370
+
371
+ from torch.optim.lr_scheduler import LambdaLR
372
+ LambdaLR(optimizer, lr_lambda)
373
+ """
374
+ if step <= warm_up_steps:
375
+ return step / warm_up_steps
376
+ else:
377
+ return gamma ** (step // reduce_lr_steps)
378
+
379
+ def init_weights(self, module: nn.Module):
380
+ for m in module.modules():
381
+ if type(m) in [nn.GRU, nn.LSTM, nn.RNN]:
382
+ for name, param in m.named_parameters():
383
+ if "weight_ih" in name:
384
+ torch.nn.init.xavier_uniform_(param.data)
385
+ elif "weight_hh" in name:
386
+ torch.nn.init.orthogonal_(param.data)
387
+ elif "bias" in name:
388
+ param.data.fill_(0)
389
+
390
+ def pre(self, input):
391
+ sp, _, _ = self.f_helper.wav_to_spectrogram_phase(input)
392
+ mel_orig = self.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
393
+ return sp, mel_orig
394
+
395
+ def forward(self, sp, mel_orig):
396
+ """
397
+ Args:
398
+ input: (batch_size, channels_num, segment_samples)
399
+
400
+ Outputs:
401
+ output_dict: {
402
+ 'wav': (batch_size, channels_num, segment_samples),
403
+ 'sp': (batch_size, channels_num, time_steps, freq_bins)}
404
+ """
405
+ return self.generator(sp, mel_orig)
406
+
407
+ def configure_optimizers(self):
408
+ optimizer_g = torch.optim.Adam(
409
+ [{"params": self.generator.parameters()}],
410
+ lr=self.lr,
411
+ amsgrad=True,
412
+ betas=(0.5, 0.999),
413
+ )
414
+ optimizer_d = torch.optim.Adam(
415
+ [{"params": self.discriminator.parameters()}],
416
+ lr=self.lr,
417
+ amsgrad=True,
418
+ betas=(0.5, 0.999),
419
+ )
420
+
421
+ scheduler_g = {
422
+ "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer_g, self.lr_lambda),
423
+ "interval": "step",
424
+ "frequency": 1,
425
+ }
426
+ scheduler_d = {
427
+ "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer_d, self.lr_lambda),
428
+ "interval": "step",
429
+ "frequency": 1,
430
+ }
431
+ return [optimizer_g, optimizer_d], [scheduler_g, scheduler_d]
432
+
433
+ def preprocess(self, batch, train=False, cutoff=None):
434
+ if train:
435
+ vocal = batch[self.type_target] # final target
436
+ noise = batch["noise_LR"] # augmented low resolution audio with noise
437
+ augLR = batch[
438
+ self.type_target + "_aug_LR"
439
+ ] # # augment low resolution audio
440
+ LR = batch[self.type_target + "_LR"]
441
+ # embed()
442
+ vocal, LR, augLR, noise = (
443
+ vocal.float().permute(0, 2, 1),
444
+ LR.float().permute(0, 2, 1),
445
+ augLR.float().permute(0, 2, 1),
446
+ noise.float().permute(0, 2, 1),
447
+ )
448
+ # LR, noise = self.add_random_noise(LR, noise)
449
+ snr, scale = [], []
450
+ for i in range(vocal.size()[0]):
451
+ (
452
+ vocal[i, ...],
453
+ LR[i, ...],
454
+ augLR[i, ...],
455
+ noise[i, ...],
456
+ _snr,
457
+ _scale,
458
+ ) = add_noise_and_scale_with_HQ_with_Aug(
459
+ vocal[i, ...],
460
+ LR[i, ...],
461
+ augLR[i, ...],
462
+ noise[i, ...],
463
+ snr_l=-5,
464
+ snr_h=45,
465
+ scale_lower=0.6,
466
+ scale_upper=1.0,
467
+ )
468
+ snr.append(_snr), scale.append(_scale)
469
+ # vocal, LR = self.amp_to_original_f(vocal, LR)
470
+ # noise = (noise * 0.0) + 1e-8 # todo
471
+ return vocal, augLR, LR, noise + augLR
472
+ else:
473
+ if cutoff is None:
474
+ LR_noisy = batch["noisy"]
475
+ LR = batch["vocals"]
476
+ vocals = batch["vocals"]
477
+ vocals, LR, LR_noisy = (
478
+ vocals.float().permute(0, 2, 1),
479
+ LR.float().permute(0, 2, 1),
480
+ LR_noisy.float().permute(0, 2, 1),
481
+ )
482
+ return vocals, LR, LR_noisy, batch["fname"][0]
483
+ else:
484
+ LR_noisy = batch["noisy" + "LR" + "_" + str(cutoff)]
485
+ LR = batch["vocals" + "LR" + "_" + str(cutoff)]
486
+ vocals = batch["vocals"]
487
+ vocals, LR, LR_noisy = (
488
+ vocals.float().permute(0, 2, 1),
489
+ LR.float().permute(0, 2, 1),
490
+ LR_noisy.float().permute(0, 2, 1),
491
+ )
492
+ return vocals, LR, LR_noisy, batch["fname"][0]
493
+
494
+ def training_step(self, batch, batch_nb, optimizer_idx):
495
+ # dict_keys(['vocals', 'vocals_aug', 'vocals_augLR', 'noise'])
496
+ config = load_json("temp_path.json")
497
+ if "g_loss_weight" not in config.keys():
498
+ config["g_loss_weight"] = self.g_loss_weight
499
+ config["d_loss_weight"] = self.d_loss_weight
500
+ write_json(config, "temp_path.json")
501
+ elif (
502
+ config["g_loss_weight"] != self.g_loss_weight
503
+ or config["d_loss_weight"] != self.d_loss_weight
504
+ ):
505
+ print(
506
+ "Update d_loss weight, from",
507
+ self.d_loss_weight,
508
+ "to",
509
+ config["d_loss_weight"],
510
+ )
511
+ print(
512
+ "Update g_loss weight, from",
513
+ self.g_loss_weight,
514
+ "to",
515
+ config["g_loss_weight"],
516
+ )
517
+ self.g_loss_weight = config["g_loss_weight"]
518
+ self.d_loss_weight = config["d_loss_weight"]
519
+
520
+ if optimizer_idx == 0:
521
+ self.vocal, self.augLR, _, self.LR_noisy = self.preprocess(
522
+ batch, train=True
523
+ )
524
+
525
+ for i in range(self.vocal.size()[0]):
526
+ save_wave(
527
+ tensor2numpy(self.vocal[i, ...]),
528
+ str(i) + "vocal" + ".wav",
529
+ sample_rate=44100,
530
+ )
531
+ save_wave(
532
+ tensor2numpy(self.LR_noisy[i, ...]),
533
+ str(i) + "LR_noisy" + ".wav",
534
+ sample_rate=44100,
535
+ )
536
+
537
+ # all_mel_e2e in non-log scale
538
+ _, self.mel_target = self.pre(self.vocal)
539
+ self.sp_LR_target, self.mel_LR_target = self.pre(self.augLR)
540
+ self.sp_LR_target_noisy, self.mel_LR_target_noisy = self.pre(self.LR_noisy)
541
+
542
+ if self.valid is None or self.valid.size()[0] != self.mel_target.size()[0]:
543
+ self.valid = torch.ones(
544
+ self.mel_target.size()[0], 1, self.mel_target.size()[2], 1
545
+ )
546
+ self.valid = self.valid.type_as(self.mel_target)
547
+ if self.fake is None or self.fake.size()[0] != self.mel_target.size()[0]:
548
+ self.fake = torch.zeros(
549
+ self.mel_target.size()[0], 1, self.mel_target.size()[2], 1
550
+ )
551
+ self.fake = self.fake.type_as(self.mel_target)
552
+
553
+ self.generated = self(self.sp_LR_target_noisy, self.mel_LR_target_noisy)
554
+
555
+ denoise_loss = self.l1loss(self.generated["clean"], self.mel_LR_target)
556
+ targ_loss = self.l1loss(self.generated["mel"], to_log(self.mel_target))
557
+
558
+ self.log(
559
+ "targ-l",
560
+ targ_loss,
561
+ on_step=True,
562
+ on_epoch=False,
563
+ logger=True,
564
+ sync_dist=True,
565
+ prog_bar=True,
566
+ )
567
+ self.log(
568
+ "noise-l",
569
+ denoise_loss,
570
+ on_step=True,
571
+ on_epoch=False,
572
+ logger=True,
573
+ sync_dist=True,
574
+ prog_bar=True,
575
+ )
576
+
577
+ loss = targ_loss + denoise_loss
578
+
579
+ if self.train_step >= 18000:
580
+ g_loss = self.bce_loss(
581
+ self.discriminator(self.generated["mel"]), self.valid
582
+ )
583
+ self.log(
584
+ "g_l",
585
+ g_loss,
586
+ on_step=True,
587
+ on_epoch=False,
588
+ logger=True,
589
+ sync_dist=True,
590
+ prog_bar=True,
591
+ )
592
+ # print("g_loss", g_loss)
593
+ all_loss = loss + self.g_loss_weight * g_loss
594
+ self.log(
595
+ "all_loss",
596
+ all_loss,
597
+ on_step=True,
598
+ on_epoch=True,
599
+ logger=True,
600
+ sync_dist=True,
601
+ )
602
+ else:
603
+ all_loss = loss
604
+ self.train_step += 0.5
605
+ return {"loss": all_loss}
606
+
607
+ elif optimizer_idx == 1:
608
+ if self.train_step >= 16000:
609
+ self.generated = self(self.sp_LR_target_noisy, self.mel_LR_target_noisy)
610
+ self.train_step += 0.5
611
+ real_loss = self.bce_loss(
612
+ self.discriminator(to_log(self.mel_target)), self.valid
613
+ )
614
+ self.log(
615
+ "r_l",
616
+ real_loss,
617
+ on_step=True,
618
+ on_epoch=False,
619
+ logger=True,
620
+ sync_dist=True,
621
+ prog_bar=True,
622
+ )
623
+ fake_loss = self.bce_loss(
624
+ self.discriminator(self.generated["mel"].detach()), self.fake
625
+ )
626
+ self.log(
627
+ "d_l",
628
+ fake_loss,
629
+ on_step=True,
630
+ on_epoch=False,
631
+ logger=True,
632
+ sync_dist=True,
633
+ prog_bar=True,
634
+ )
635
+ d_loss = self.d_loss_weight * (real_loss + fake_loss) / 2
636
+ self.log(
637
+ "discriminator_loss",
638
+ d_loss,
639
+ on_step=True,
640
+ on_epoch=True,
641
+ logger=True,
642
+ sync_dist=True,
643
+ )
644
+ return {"loss": d_loss}
645
+
646
+ def draw_and_save(
647
+ self, mel: torch.Tensor, path, clip_max=None, clip_min=None, needlog=True
648
+ ):
649
+ plt.figure(figsize=(15, 5))
650
+ if clip_min is None:
651
+ clip_max, clip_min = self.clip(mel)
652
+ mel = np.transpose(tensor2numpy(mel)[0, 0, ...], (1, 0))
653
+ # assert np.sum(mel < 0) == 0, str(np.sum(mel < 0)) + str(np.sum(mel < 0))
654
+
655
+ if needlog:
656
+ assert np.sum(mel < 0) == 0, str(np.sum(mel < 0)) + "-" + path
657
+ mel_log = np.log10(mel + EPS)
658
+ else:
659
+ mel_log = mel
660
+
661
+ # plt.imshow(mel)
662
+ librosa.display.specshow(
663
+ mel_log,
664
+ sr=44100,
665
+ x_axis="frames",
666
+ y_axis="mel",
667
+ cmap=cm.jet,
668
+ vmax=clip_max,
669
+ vmin=clip_min,
670
+ )
671
+ plt.colorbar()
672
+ plt.savefig(path)
673
+ plt.close()
674
+
675
+ def clip(self, *args):
676
+ val_max, val_min = [], []
677
+ for each in args:
678
+ val_max.append(torch.max(each))
679
+ val_min.append(torch.min(each))
680
+ return max(val_max), min(val_min)
voicefixer/restorer/model_kqq_bn.py ADDED
@@ -0,0 +1,186 @@
1
+ from voicefixer.restorer.modules import *
2
+
3
+ from voicefixer.tools.pytorch_util import *
4
+
5
+
6
+ class UNetResComplex_100Mb(nn.Module):
7
+ def __init__(self, channels, nsrc=1):
8
+ super(UNetResComplex_100Mb, self).__init__()
9
+ activation = "relu"
10
+ momentum = 0.01
11
+
12
+ self.nsrc = nsrc
13
+ self.channels = channels
14
+ self.downsample_ratio = 2**6 # This number equals 2^{#encoder_blcoks}
15
+
16
+ self.encoder_block1 = EncoderBlockRes(
17
+ in_channels=channels * nsrc,
18
+ out_channels=32,
19
+ downsample=(2, 2),
20
+ activation=activation,
21
+ momentum=momentum,
22
+ )
23
+ self.encoder_block2 = EncoderBlockRes(
24
+ in_channels=32,
25
+ out_channels=64,
26
+ downsample=(2, 2),
27
+ activation=activation,
28
+ momentum=momentum,
29
+ )
30
+ self.encoder_block3 = EncoderBlockRes(
31
+ in_channels=64,
32
+ out_channels=128,
33
+ downsample=(2, 2),
34
+ activation=activation,
35
+ momentum=momentum,
36
+ )
37
+ self.encoder_block4 = EncoderBlockRes(
38
+ in_channels=128,
39
+ out_channels=256,
40
+ downsample=(2, 2),
41
+ activation=activation,
42
+ momentum=momentum,
43
+ )
44
+ self.encoder_block5 = EncoderBlockRes(
45
+ in_channels=256,
46
+ out_channels=384,
47
+ downsample=(2, 2),
48
+ activation=activation,
49
+ momentum=momentum,
50
+ )
51
+ self.encoder_block6 = EncoderBlockRes(
52
+ in_channels=384,
53
+ out_channels=384,
54
+ downsample=(2, 2),
55
+ activation=activation,
56
+ momentum=momentum,
57
+ )
58
+ self.conv_block7 = ConvBlockRes(
59
+ in_channels=384,
60
+ out_channels=384,
61
+ size=3,
62
+ activation=activation,
63
+ momentum=momentum,
64
+ )
65
+ self.decoder_block1 = DecoderBlockRes(
66
+ in_channels=384,
67
+ out_channels=384,
68
+ stride=(2, 2),
69
+ activation=activation,
70
+ momentum=momentum,
71
+ )
72
+ self.decoder_block2 = DecoderBlockRes(
73
+ in_channels=384,
74
+ out_channels=384,
75
+ stride=(2, 2),
76
+ activation=activation,
77
+ momentum=momentum,
78
+ )
79
+ self.decoder_block3 = DecoderBlockRes(
80
+ in_channels=384,
81
+ out_channels=256,
82
+ stride=(2, 2),
83
+ activation=activation,
84
+ momentum=momentum,
85
+ )
86
+ self.decoder_block4 = DecoderBlockRes(
87
+ in_channels=256,
88
+ out_channels=128,
89
+ stride=(2, 2),
90
+ activation=activation,
91
+ momentum=momentum,
92
+ )
93
+ self.decoder_block5 = DecoderBlockRes(
94
+ in_channels=128,
95
+ out_channels=64,
96
+ stride=(2, 2),
97
+ activation=activation,
98
+ momentum=momentum,
99
+ )
100
+ self.decoder_block6 = DecoderBlockRes(
101
+ in_channels=64,
102
+ out_channels=32,
103
+ stride=(2, 2),
104
+ activation=activation,
105
+ momentum=momentum,
106
+ )
107
+
108
+ self.after_conv_block1 = ConvBlockRes(
109
+ in_channels=32,
110
+ out_channels=32,
111
+ size=3,
112
+ activation=activation,
113
+ momentum=momentum,
114
+ )
115
+
116
+ self.after_conv2 = nn.Conv2d(
117
+ in_channels=32,
118
+ out_channels=1,
119
+ kernel_size=(1, 1),
120
+ stride=(1, 1),
121
+ padding=(0, 0),
122
+ bias=True,
123
+ )
124
+
125
+ self.init_weights()
126
+
127
+ def init_weights(self):
128
+ init_layer(self.after_conv2)
129
+
130
+ def forward(self, sp):
131
+ """
132
+ Args:
133
+ input: (batch_size, channels_num, segment_samples)
134
+
135
+ Outputs:
136
+ output_dict: {
137
+ 'wav': (batch_size, channels_num, segment_samples),
138
+ 'sp': (batch_size, channels_num, time_steps, freq_bins)}
139
+ """
140
+
141
+ # Batch normalization
142
+ x = sp
143
+
144
+ # Pad spectrogram to be evenly divided by downsample ratio.
145
+ origin_len = x.shape[2] # time_steps
146
+ pad_len = (
147
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
148
+ - origin_len
149
+ )
150
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
151
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F)
152
+
153
+ # UNet
154
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
155
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
156
+ (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
157
+ (x4_pool, x4) = self.encoder_block4(
158
+ x3_pool
159
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
160
+ (x5_pool, x5) = self.encoder_block5(
161
+ x4_pool
162
+ ) # x5_pool: (bs, 512, T / 32, F / 32)
163
+ (x6_pool, x6) = self.encoder_block6(
164
+ x5_pool
165
+ ) # x6_pool: (bs, 1024, T / 64, F / 64)
166
+ x_center = self.conv_block7(x6_pool) # (bs, 2048, T / 64, F / 64)
167
+ x7 = self.decoder_block1(x_center, x6) # (bs, 1024, T / 32, F / 32)
168
+ x8 = self.decoder_block2(x7, x5) # (bs, 512, T / 16, F / 16)
169
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
170
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
171
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
172
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
173
+ x = self.after_conv_block1(x12) # (bs, 32, T, F)
174
+ x = self.after_conv2(x) # (bs, channels, T, F)
175
+
176
+ # Recover shape
177
+ x = F.pad(x, pad=(0, 1))
178
+ x = x[:, :, 0:origin_len, :]
179
+
180
+ output_dict = {"mel": x}
181
+ return output_dict
182
+
183
+
184
+ if __name__ == "__main__":
185
+ model = UNetResComplex_100Mb(channels=1)
186
+ print(model(torch.randn((1, 1, 101, 128)))["mel"].size())
voicefixer/restorer/modules.py ADDED
@@ -0,0 +1,217 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ class ConvBlockRes(nn.Module):
8
+ def __init__(self, in_channels, out_channels, size, activation, momentum):
9
+ super(ConvBlockRes, self).__init__()
10
+
11
+ self.activation = activation
12
+ if type(size) == type((3, 4)):
13
+ pad = size[0] // 2
14
+ size = size[0]
15
+ else:
16
+ pad = size // 2
17
+ size = size
18
+
19
+ self.conv1 = nn.Conv2d(
20
+ in_channels=in_channels,
21
+ out_channels=out_channels,
22
+ kernel_size=(size, size),
23
+ stride=(1, 1),
24
+ dilation=(1, 1),
25
+ padding=(pad, pad),
26
+ bias=False,
27
+ )
28
+
29
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
30
+ # self.abn1 = InPlaceABN(num_features=in_channels, momentum=momentum, activation='leaky_relu')
31
+
32
+ self.conv2 = nn.Conv2d(
33
+ in_channels=out_channels,
34
+ out_channels=out_channels,
35
+ kernel_size=(size, size),
36
+ stride=(1, 1),
37
+ dilation=(1, 1),
38
+ padding=(pad, pad),
39
+ bias=False,
40
+ )
41
+
42
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
43
+
44
+ # self.abn2 = InPlaceABN(num_features=out_channels, momentum=momentum, activation='leaky_relu')
45
+
46
+ if in_channels != out_channels:
47
+ self.shortcut = nn.Conv2d(
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ kernel_size=(1, 1),
51
+ stride=(1, 1),
52
+ padding=(0, 0),
53
+ )
54
+ self.is_shortcut = True
55
+ else:
56
+ self.is_shortcut = False
57
+
58
+ self.init_weights()
59
+
60
+ def init_weights(self):
61
+ init_bn(self.bn1)
62
+ init_layer(self.conv1)
63
+ init_layer(self.conv2)
64
+
65
+ if self.is_shortcut:
66
+ init_layer(self.shortcut)
67
+
68
+ def forward(self, x):
69
+ origin = x
70
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
71
+ x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
72
+
73
+ if self.is_shortcut:
74
+ return self.shortcut(origin) + x
75
+ else:
76
+ return origin + x
77
+
78
+
79
+ class EncoderBlockRes(nn.Module):
80
+ def __init__(self, in_channels, out_channels, downsample, activation, momentum):
81
+ super(EncoderBlockRes, self).__init__()
82
+ size = 3
83
+
84
+ self.conv_block1 = ConvBlockRes(
85
+ in_channels, out_channels, size, activation, momentum
86
+ )
87
+ self.conv_block2 = ConvBlockRes(
88
+ out_channels, out_channels, size, activation, momentum
89
+ )
90
+ self.conv_block3 = ConvBlockRes(
91
+ out_channels, out_channels, size, activation, momentum
92
+ )
93
+ self.conv_block4 = ConvBlockRes(
94
+ out_channels, out_channels, size, activation, momentum
95
+ )
96
+ self.downsample = downsample
97
+
98
+ def forward(self, x):
99
+ encoder = self.conv_block1(x)
100
+ encoder = self.conv_block2(encoder)
101
+ encoder = self.conv_block3(encoder)
102
+ encoder = self.conv_block4(encoder)
103
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
104
+ return encoder_pool, encoder
105
+
106
+
107
+ class DecoderBlockRes(nn.Module):
108
+ def __init__(self, in_channels, out_channels, stride, activation, momentum):
109
+ super(DecoderBlockRes, self).__init__()
110
+ size = 3
111
+ self.activation = activation
112
+
113
+ self.conv1 = torch.nn.ConvTranspose2d(
114
+ in_channels=in_channels,
115
+ out_channels=out_channels,
116
+ kernel_size=(size, size),
117
+ stride=stride,
118
+ padding=(0, 0),
119
+ output_padding=(0, 0),
120
+ bias=False,
121
+ dilation=(1, 1),
122
+ )
123
+
124
+ self.bn1 = nn.BatchNorm2d(in_channels)
125
+ self.conv_block2 = ConvBlockRes(
126
+ out_channels * 2, out_channels, size, activation, momentum
127
+ )
128
+ self.conv_block3 = ConvBlockRes(
129
+ out_channels, out_channels, size, activation, momentum
130
+ )
131
+ self.conv_block4 = ConvBlockRes(
132
+ out_channels, out_channels, size, activation, momentum
133
+ )
134
+ self.conv_block5 = ConvBlockRes(
135
+ out_channels, out_channels, size, activation, momentum
136
+ )
137
+
138
+ def init_weights(self):
139
+ init_layer(self.conv1)
140
+
141
+ def prune(self, x, both=False):
142
+ """Prune the shape of x after transpose convolution."""
143
+ if both:
144
+ x = x[:, :, 0:-1, 0:-1]
145
+ else:
146
+ x = x[:, :, 0:-1, :]
147
+ return x
148
+
149
+ def forward(self, input_tensor, concat_tensor, both=False):
150
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
151
+ x = self.prune(x, both=both)
152
+ x = torch.cat((x, concat_tensor), dim=1)
153
+ x = self.conv_block2(x)
154
+ x = self.conv_block3(x)
155
+ x = self.conv_block4(x)
156
+ x = self.conv_block5(x)
157
+ return x
158
+
159
+
160
+ def init_layer(layer):
161
+ """Initialize a Linear or Convolutional layer."""
162
+ nn.init.xavier_uniform_(layer.weight)
163
+
164
+ if hasattr(layer, "bias"):
165
+ if layer.bias is not None:
166
+ layer.bias.data.fill_(0.0)
167
+
168
+
169
+ def init_bn(bn):
170
+ """Initialize a Batchnorm layer."""
171
+ bn.bias.data.fill_(0.0)
172
+ bn.weight.data.fill_(1.0)
173
+
174
+
175
+ def init_gru(rnn):
176
+ """Initialize a GRU layer."""
177
+
178
+ def _concat_init(tensor, init_funcs):
179
+ (length, fan_out) = tensor.shape
180
+ fan_in = length // len(init_funcs)
181
+
182
+ for (i, init_func) in enumerate(init_funcs):
183
+ init_func(tensor[i * fan_in : (i + 1) * fan_in, :])
184
+
185
+ def _inner_uniform(tensor):
186
+ fan_in = nn.init._calculate_correct_fan(tensor, "fan_in")
187
+ nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
188
+
189
+ for i in range(rnn.num_layers):
190
+ _concat_init(
191
+ getattr(rnn, "weight_ih_l{}".format(i)),
192
+ [_inner_uniform, _inner_uniform, _inner_uniform],
193
+ )
194
+ torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0)
195
+
196
+ _concat_init(
197
+ getattr(rnn, "weight_hh_l{}".format(i)),
198
+ [_inner_uniform, _inner_uniform, nn.init.orthogonal_],
199
+ )
200
+ torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0)
201
+
202
+
203
+ from torch.cuda import init
204
+
205
+
206
+ def act(x, activation):
207
+ if activation == "relu":
208
+ return F.relu_(x)
209
+
210
+ elif activation == "leaky_relu":
211
+ return F.leaky_relu_(x, negative_slope=0.2)
212
+
213
+ elif activation == "swish":
214
+ return x * torch.sigmoid(x)
215
+
216
+ else:
217
+ raise Exception("Incorrect activation!")
voicefixer/tools/__init__.py ADDED
@@ -0,0 +1,11 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ """
+ @File    : __init__.py
+ @Contact : [email protected]
+ @License : (C)Copyright 2020-2100
+
+ @Modify Time      @Author     @Version    @Description
+ ------------      -------     --------    -----------
+ 9/14/21 12:28 AM  Haohe Liu   1.0         None
+ """
voicefixer/tools/base.py ADDED
@@ -0,0 +1,244 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import os
6
+ import torch.fft
7
+
8
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
9
+
10
+
11
+ def get_window(window_size, window_type, square_root_window=True):
12
+ """Return the window"""
13
+ window = {
14
+ "hamming": torch.hamming_window(window_size),
15
+ "hanning": torch.hann_window(window_size),
16
+ }[window_type]
17
+ if square_root_window:
18
+ window = torch.sqrt(window)
19
+ return window
20
+
21
+
22
+ def fft_point(dim):
23
+ assert dim > 0
24
+ num = math.log(dim, 2)
25
+ num_point = 2 ** (math.ceil(num))
26
+ return num_point
27
+
28
+
29
+ def pre_emphasis(signal, coefficient=0.97):
30
+ """Pre-emphasis original signal
31
+ y(n) = x(n) - a*x(n-1)
32
+ """
33
+ return np.append(signal[0], signal[1:] - coefficient * signal[:-1])
34
+
35
+
36
+ def de_emphasis(signal, coefficient=0.97):
37
+ """De-emphasis original signal
38
+ y(n) = x(n) + a*x(n-1)
39
+ """
40
+ length = signal.shape[0]
41
+ for i in range(1, length):
42
+ signal[i] = signal[i] + coefficient * signal[i - 1]
43
+ return signal
44
+
45
+
46
+ def seperate_magnitude(magnitude, phase):
47
+ real = torch.cos(phase) * magnitude
48
+ imagine = torch.sin(phase) * magnitude
49
+ expand_dim = len(list(real.size()))
50
+ return torch.stack((real, imagine), expand_dim)
51
+
52
+
53
+ def stft_single(
54
+ signal,
55
+ sample_rate=44100,
56
+ frame_length=46,
57
+ frame_shift=10,
58
+ window_type="hanning",
59
+ device=torch.device("cuda"),
60
+ square_root_window=True,
61
+ ):
62
+ """Compute the Short Time Fourier Transform.
63
+
64
+ Args:
65
+ signal: input speech signal,
66
+ sample_rate: waveform datas sample frequency (Hz)
67
+ frame_length: frame length in milliseconds
68
+ frame_shift: frame shift in milliseconds
69
+ window_type: type of window
70
+ square_root_window: square root window
71
+ Return:
72
+ fft: (n/2)+1 dim complex STFT restults
73
+ """
74
+ hop_length = int(
75
+ sample_rate * frame_shift / 1000
76
+ ) # The greater sample_rate, the greater hop_length
77
+ win_length = int(sample_rate * frame_length / 1000)
78
+ # num_point = fft_point(win_length)
79
+ num_point = win_length
80
+ window = get_window(num_point, window_type, square_root_window)
81
+ if "cuda" in str(device):
82
+ window = window.cuda(device)
83
+ feat = torch.stft(
84
+ signal,
85
+ n_fft=num_point,
86
+ hop_length=hop_length,
87
+ win_length=window.shape[0],
88
+ window=window,
89
+ )
90
+ real, imag = feat[..., 0], feat[..., 1]
91
+ return real.permute(0, 2, 1).unsqueeze(1), imag.permute(0, 2, 1).unsqueeze(1)
92
+
93
+
94
+ def istft(
95
+ real,
96
+ imag,
97
+ length,
98
+ sample_rate=44100,
99
+ frame_length=46,
100
+ frame_shift=10,
101
+ window_type="hanning",
102
+ preemphasis=0.0,
103
+ device=torch.device("cuda"),
104
+ square_root_window=True,
105
+ ):
106
+ """Convert frames to signal using overlap-and-add systhesis.
107
+ Args:
108
+ spectrum: magnitude spectrum [batchsize,x,y,2]
109
+ signal: wave signal to supply phase information
110
+ Return:
111
+ wav: synthesied output waveform
112
+ """
113
+ real = real.permute(0, 3, 2, 1)
114
+ imag = imag.permute(0, 3, 2, 1)
115
+ spectrum = torch.cat([real, imag], dim=-1)
116
+
117
+ hop_length = int(sample_rate * frame_shift / 1000)
118
+ win_length = int(sample_rate * frame_length / 1000)
119
+
120
+ # num_point = fft_point(win_length)
121
+ num_point = win_length
122
+ if "cuda" in str(device):
123
+ window = get_window(num_point, window_type, square_root_window).cuda(device)
124
+ else:
125
+ window = get_window(num_point, window_type, square_root_window)
126
+
127
+ wav = torch_istft(
128
+ spectrum,
129
+ num_point,
130
+ hop_length=hop_length,
131
+ win_length=window.shape[0],
132
+ window=window,
133
+ )
134
+ return wav[..., :length]
135
+
136
+
137
+ def torch_istft(
138
+ stft_matrix, # type: Tensor
139
+ n_fft, # type: int
140
+ hop_length=None, # type: Optional[int]
141
+ win_length=None, # type: Optional[int]
142
+ window=None, # type: Optional[Tensor]
143
+ center=True, # type: bool
144
+ pad_mode="reflect", # type: str
145
+ normalized=False, # type: bool
146
+ onesided=True, # type: bool
147
+ length=None, # type: Optional[int]
148
+ ):
149
+ # type: (...) -> Tensor
150
+
151
+ stft_matrix_dim = stft_matrix.dim()
152
+ assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim)
153
+
154
+ if stft_matrix_dim == 3:
155
+ # add a channel dimension
156
+ stft_matrix = stft_matrix.unsqueeze(0)
157
+
158
+ dtype = stft_matrix.dtype
159
+ device = stft_matrix.device
160
+ fft_size = stft_matrix.size(1)
161
+ assert (onesided and n_fft // 2 + 1 == fft_size) or (
162
+ not onesided and n_fft == fft_size
163
+ ), (
164
+ "one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
165
+ + "Given values were onesided: %s, n_fft: %d, fft_size: %d"
166
+ % ("True" if onesided else False, n_fft, fft_size)
167
+ )
168
+
169
+ # use stft defaults for Optionals
170
+ if win_length is None:
171
+ win_length = n_fft
172
+
173
+ if hop_length is None:
174
+ hop_length = int(win_length // 4)
175
+
176
+ # There must be overlap
177
+ assert 0 < hop_length <= win_length
178
+ assert 0 < win_length <= n_fft
179
+
180
+ if window is None:
181
+ window = torch.ones(win_length, requires_grad=False, device=device, dtype=dtype)
182
+
183
+ assert window.dim() == 1 and window.size(0) == win_length
184
+
185
+ if win_length != n_fft:
186
+ # center window with pad left and right zeros
187
+ left = (n_fft - win_length) // 2
188
+ window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
189
+ assert window.size(0) == n_fft
190
+ # win_length and n_fft are synonymous from here on
191
+
192
+ stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frames, fft_size, 2)
193
+ stft_matrix = torch.irfft(
194
+ stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
195
+ ) # size (channel, n_frames, n_fft)
196
+
197
+ assert stft_matrix.size(2) == n_fft
198
+ n_frames = stft_matrix.size(1)
199
+
200
+ ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frames, n_fft)
201
+ # each column of a channel is a frame which needs to be overlap added at the right place
202
+ ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frames)
203
+
204
+ eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze(
205
+ 1
206
+ ) # size (n_fft, 1, n_fft)
207
+
208
+ # this does overlap add where the frames of ytmp are added such that the i'th frame of
209
+ # ytmp is added starting at i*hop_length in the output
210
+ y = torch.nn.functional.conv_transpose1d(
211
+ ytmp, eye, stride=hop_length, padding=0
212
+ ) # size (channel, 1, expected_signal_len)
213
+
214
+ # do the same for the window function
215
+ window_sq = (
216
+ window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)
217
+ ) # size (1, n_fft, n_frames)
218
+ window_envelop = torch.nn.functional.conv_transpose1d(
219
+ window_sq, eye, stride=hop_length, padding=0
220
+ ) # size (1, 1, expected_signal_len)
221
+
222
+ expected_signal_len = n_fft + hop_length * (n_frames - 1)
223
+ assert y.size(2) == expected_signal_len
224
+ assert window_envelop.size(2) == expected_signal_len
225
+
226
+ half_n_fft = n_fft // 2
227
+ # we need to trim the front padding away if center
228
+ start = half_n_fft if center else 0
229
+ end = -half_n_fft if length is None else start + length
230
+
231
+ y = y[:, :, start:end]
232
+ window_envelop = window_envelop[:, :, start:end]
233
+
234
+ # check NOLA non-zero overlap condition
235
+ window_envelop_lowest = window_envelop.abs().min()
236
+ assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
237
+ window_envelop_lowest
238
+ )
239
+
240
+ y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
241
+
242
+ if stft_matrix_dim == 3: # remove the channel dimension
243
+ y = y.squeeze(0)
244
+ return y
voicefixer/tools/io.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pickle
3
+
4
+
5
+ def read_list(fname):
6
+ result = []
7
+ with open(fname, "r") as f:
8
+ for each in f.readlines():
9
+ each = each.strip("\n")
10
+ result.append(each)
11
+ return result
12
+
13
+
14
+ def write_list(list, fname):
15
+ with open(fname, "w") as f:
16
+ for word in list:
17
+ f.write(word)
18
+ f.write("\n")
19
+
20
+
21
+ def write_json(my_dict, fname):
22
+ # print("Save json file at "+fname)
23
+ json_str = json.dumps(my_dict)
24
+ with open(fname, "w") as json_file:
25
+ json_file.write(json_str)
26
+
27
+
28
+ def load_json(fname):
29
+ with open(fname, "r") as f:
30
+ data = json.load(f)
31
+ return data
32
+
33
+
34
+ def save_pickle(obj, fname):
35
+ # print("Save pickle at "+fname)
36
+ with open(fname, "wb") as f:
37
+ pickle.dump(obj, f)
38
+
39
+
40
+ def load_pickle(fname):
41
+ # print("Load pickle at "+fname)
42
+ with open(fname, "rb") as f:
43
+ res = pickle.load(f)
44
+ return res
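A quick, hypothetical round trip through the JSON and pickle helpers above; the file names are made up for illustration and the import path assumes the module location added in this commit:

from voicefixer.tools.io import write_json, load_json, save_pickle, load_pickle

cfg = {"sample_rate": 44100, "mode": 0}
write_json(cfg, "demo.json")           # hypothetical file name
assert load_json("demo.json") == cfg
save_pickle(cfg, "demo.pkl")           # hypothetical file name
assert load_pickle("demo.pkl") == cfg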
voicefixer/tools/mel_scale.py ADDED
@@ -0,0 +1,238 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from typing import Optional
4
+ import math
5
+
6
+ import warnings
7
+
8
+
9
+ class MelScale(torch.nn.Module):
10
+ r"""Turn a normal STFT into a mel frequency STFT, using a conversion
11
+ matrix. This uses triangular filter banks.
12
+
13
+ The user can control which device the filter bank (`fb`) is on (e.g. fb.to(spec_f.device)).
14
+
15
+ Args:
16
+ n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
17
+ sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
18
+ f_min (float, optional): Minimum frequency. (Default: ``0.``)
19
+ f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
20
+ n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
21
+ norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
22
+ (area normalization). (Default: ``None``)
23
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
24
+
25
+ See also:
26
+ :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
27
+ generate the filter banks.
28
+ """
29
+ __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]
30
+
31
+ def __init__(
32
+ self,
33
+ n_mels: int = 128,
34
+ sample_rate: int = 16000,
35
+ f_min: float = 0.0,
36
+ f_max: Optional[float] = None,
37
+ n_stft: int = 201,
38
+ norm: Optional[str] = None,
39
+ mel_scale: str = "htk",
40
+ ) -> None:
41
+ super(MelScale, self).__init__()
42
+ self.n_mels = n_mels
43
+ self.sample_rate = sample_rate
44
+ self.f_max = f_max if f_max is not None else float(sample_rate // 2)
45
+ self.f_min = f_min
46
+ self.norm = norm
47
+ self.mel_scale = mel_scale
48
+
49
+ assert f_min <= self.f_max, "Require f_min: {} <= f_max: {}".format(
50
+ f_min, self.f_max
51
+ )
52
+ fb = melscale_fbanks(
53
+ n_stft,
54
+ self.f_min,
55
+ self.f_max,
56
+ self.n_mels,
57
+ self.sample_rate,
58
+ self.norm,
59
+ self.mel_scale,
60
+ )
61
+ self.register_buffer("fb", fb)
62
+
63
+ def forward(self, specgram: Tensor) -> Tensor:
64
+ r"""
65
+ Args:
66
+ specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).
67
+
68
+ Returns:
69
+ Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
70
+ """
71
+
72
+ # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
73
+ mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(
74
+ -1, -2
75
+ )
76
+
77
+ return mel_specgram
78
+
79
+
80
+ def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
81
+ r"""Convert Hz to Mels.
82
+
83
+ Args:
84
+ freq (float): Frequency in Hz
85
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
86
+
87
+ Returns:
88
+ mels (float): Frequency in Mels
89
+ """
90
+
91
+ if mel_scale not in ["slaney", "htk"]:
92
+ raise ValueError('mel_scale should be one of "htk" or "slaney".')
93
+
94
+ if mel_scale == "htk":
95
+ return 2595.0 * math.log10(1.0 + (freq / 700.0))
96
+
97
+ # Fill in the linear part
98
+ f_min = 0.0
99
+ f_sp = 200.0 / 3
100
+
101
+ mels = (freq - f_min) / f_sp
102
+
103
+ # Fill in the log-scale part
104
+ min_log_hz = 1000.0
105
+ min_log_mel = (min_log_hz - f_min) / f_sp
106
+ logstep = math.log(6.4) / 27.0
107
+
108
+ if freq >= min_log_hz:
109
+ mels = min_log_mel + math.log(freq / min_log_hz) / logstep
110
+
111
+ return mels
112
+
113
+
114
+ def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
115
+ """Convert mel bin numbers to frequencies.
116
+
117
+ Args:
118
+ mels (Tensor): Mel frequencies
119
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
120
+
121
+ Returns:
122
+ freqs (Tensor): Mels converted in Hz
123
+ """
124
+
125
+ if mel_scale not in ["slaney", "htk"]:
126
+ raise ValueError('mel_scale should be one of "htk" or "slaney".')
127
+
128
+ if mel_scale == "htk":
129
+ return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
130
+
131
+ # Fill in the linear scale
132
+ f_min = 0.0
133
+ f_sp = 200.0 / 3
134
+ freqs = f_min + f_sp * mels
135
+
136
+ # And now the nonlinear scale
137
+ min_log_hz = 1000.0
138
+ min_log_mel = (min_log_hz - f_min) / f_sp
139
+ logstep = math.log(6.4) / 27.0
140
+
141
+ log_t = mels >= min_log_mel
142
+ freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
143
+
144
+ return freqs
145
+
146
+
147
+ def _create_triangular_filterbank(
148
+ all_freqs: Tensor,
149
+ f_pts: Tensor,
150
+ ) -> Tensor:
151
+ """Create a triangular filter bank.
152
+
153
+ Args:
154
+ all_freqs (Tensor): STFT freq points of size (`n_freqs`).
155
+ f_pts (Tensor): Filter mid points of size (`n_filter`).
156
+
157
+ Returns:
158
+ fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
159
+ """
160
+ # Adopted from Librosa
161
+ # calculate the difference between each filter mid point and each stft freq point in hertz
162
+ f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
163
+ slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
164
+ # create overlapping triangles
165
+ zero = torch.zeros(1)
166
+ down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
167
+ up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
168
+ fb = torch.max(zero, torch.min(down_slopes, up_slopes))
169
+
170
+ return fb
171
+
172
+
173
+ def melscale_fbanks(
174
+ n_freqs: int,
175
+ f_min: float,
176
+ f_max: float,
177
+ n_mels: int,
178
+ sample_rate: int,
179
+ norm: Optional[str] = None,
180
+ mel_scale: str = "htk",
181
+ ) -> Tensor:
182
+ r"""Create a frequency bin conversion matrix.
183
+
184
+ Note:
185
+ For the sake of the numerical compatibility with librosa, not all the coefficients
186
+ in the resulting filter bank have a magnitude of 1.
187
+
188
+ .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
189
+ :alt: Visualization of generated filter bank
190
+
191
+ Args:
192
+ n_freqs (int): Number of frequencies to highlight/apply
193
+ f_min (float): Minimum frequency (Hz)
194
+ f_max (float): Maximum frequency (Hz)
195
+ n_mels (int): Number of mel filterbanks
196
+ sample_rate (int): Sample rate of the audio waveform
197
+ norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
198
+ (area normalization). (Default: ``None``)
199
+ mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
200
+
201
+ Returns:
202
+ Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
203
+ meaning number of frequencies to highlight/apply to x the number of filterbanks.
204
+ Each column is a filterbank so that assuming there is a matrix A of
205
+ size (..., ``n_freqs``), the applied result would be
206
+ ``A * melscale_fbanks(A.size(-1), ...)``.
207
+
208
+ """
209
+
210
+ if norm is not None and norm != "slaney":
211
+ raise ValueError("norm must be one of None or 'slaney'")
212
+
213
+ # freq bins
214
+ all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
215
+
216
+ # calculate mel freq bins
217
+ m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
218
+ m_max = _hz_to_mel(f_max, mel_scale=mel_scale)
219
+
220
+ m_pts = torch.linspace(m_min, m_max, n_mels + 2)
221
+ f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
222
+
223
+ # create filterbank
224
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
225
+
226
+ if norm is not None and norm == "slaney":
227
+ # Slaney-style mel is scaled to be approx constant energy per channel
228
+ enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
229
+ fb *= enorm.unsqueeze(0)
230
+
231
+ if (fb.max(dim=0).values == 0.0).any():
232
+ warnings.warn(
233
+ "At least one mel filterbank has all zero values. "
234
+ f"The value for `n_mels` ({n_mels}) may be set too high. "
235
+ f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
236
+ )
237
+
238
+ return fb
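A small sketch of how MelScale above is meant to be applied to a magnitude STFT; n_fft=400 gives 201 frequency bins, matching the default n_stft, and the signal is random noise purely for shape checking:

import torch

signal = torch.randn(1, 16000)
spec = torch.stft(
    signal, n_fft=400, window=torch.hann_window(400), return_complex=True
).abs()                                    # (1, 201, time)
mel_scale = MelScale(n_mels=128, sample_rate=16000, n_stft=201)
mel_spec = mel_scale(spec)                 # (1, 128, time)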
voicefixer/tools/modules/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @File : __init__.py.py
5
+ @Contact : [email protected]
6
+ @License : (C)Copyright 2020-2100
7
+
8
+ @Modify Time @Author @Version @Description
9
+ ------------ ------- -------- -----------
10
+ 9/14/21 12:29 AM Haohe Liu 1.0 None
11
+ """
voicefixer/tools/modules/fDomainHelper.py ADDED
@@ -0,0 +1,234 @@
1
+ from torchlibrosa.stft import STFT, ISTFT, magphase
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from voicefixer.tools.modules.pqmf import PQMF
6
+
7
+ class FDomainHelper(nn.Module):
8
+ def __init__(
9
+ self,
10
+ window_size=2048,
11
+ hop_size=441,
12
+ center=True,
13
+ pad_mode="reflect",
14
+ window="hann",
15
+ freeze_parameters=True,
16
+ subband=None,
17
+ root="/Users/admin/Documents/projects/",
18
+ ):
19
+ super(FDomainHelper, self).__init__()
20
+ self.subband = subband
21
+ # assert torchlibrosa.__version__ == "0.0.7", "Error: Found torchlibrosa version %s. Please install 0.0.7 version of torchlibrosa by: pip install torchlibrosa==0.0.7." % torchlibrosa.__version__
22
+ if self.subband is None:
23
+ self.stft = STFT(
24
+ n_fft=window_size,
25
+ hop_length=hop_size,
26
+ win_length=window_size,
27
+ window=window,
28
+ center=center,
29
+ pad_mode=pad_mode,
30
+ freeze_parameters=freeze_parameters,
31
+ )
32
+
33
+ self.istft = ISTFT(
34
+ n_fft=window_size,
35
+ hop_length=hop_size,
36
+ win_length=window_size,
37
+ window=window,
38
+ center=center,
39
+ pad_mode=pad_mode,
40
+ freeze_parameters=freeze_parameters,
41
+ )
42
+ else:
43
+ self.stft = STFT(
44
+ n_fft=window_size // self.subband,
45
+ hop_length=hop_size // self.subband,
46
+ win_length=window_size // self.subband,
47
+ window=window,
48
+ center=center,
49
+ pad_mode=pad_mode,
50
+ freeze_parameters=freeze_parameters,
51
+ )
52
+
53
+ self.istft = ISTFT(
54
+ n_fft=window_size // self.subband,
55
+ hop_length=hop_size // self.subband,
56
+ win_length=window_size // self.subband,
57
+ window=window,
58
+ center=center,
59
+ pad_mode=pad_mode,
60
+ freeze_parameters=freeze_parameters,
61
+ )
62
+
63
+ if subband is not None and root is not None:
64
+ self.qmf = PQMF(subband, 64, root)
65
+
66
+ def complex_spectrogram(self, input, eps=0.0):
67
+ # [batchsize, samples]
68
+ # return [batchsize, 2, t-steps, f-bins]
69
+ real, imag = self.stft(input)
70
+ return torch.cat([real, imag], dim=1)
71
+
72
+ def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
73
+ # [batchsize, 2[real,imag], t-steps, f-bins]
74
+ wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)
75
+ return wav
76
+
77
+ def spectrogram(self, input, eps=0.0):
78
+ (real, imag) = self.stft(input.float())
79
+ return torch.clamp(real**2 + imag**2, eps, np.inf) ** 0.5
80
+
81
+ def spectrogram_phase(self, input, eps=0.0):
82
+ (real, imag) = self.stft(input.float())
83
+ mag = torch.clamp(real**2 + imag**2, eps, np.inf) ** 0.5
84
+ cos = real / mag
85
+ sin = imag / mag
86
+ return mag, cos, sin
87
+
88
+ def wav_to_spectrogram_phase(self, input, eps=1e-8):
89
+ """Waveform to spectrogram.
90
+
91
+ Args:
92
+ input: (batch_size, channels_num, segment_samples)
93
+
94
+ Outputs:
95
+ output: (batch_size, channels_num, time_steps, freq_bins)
96
+ """
97
+ sp_list = []
98
+ cos_list = []
99
+ sin_list = []
100
+ channels_num = input.shape[1]
101
+ for channel in range(channels_num):
102
+ mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
103
+ sp_list.append(mag)
104
+ cos_list.append(cos)
105
+ sin_list.append(sin)
106
+
107
+ sps = torch.cat(sp_list, dim=1)
108
+ coss = torch.cat(cos_list, dim=1)
109
+ sins = torch.cat(sin_list, dim=1)
110
+ return sps, coss, sins
111
+
112
+ def spectrogram_phase_to_wav(self, sps, coss, sins, length):
113
+ channels_num = sps.size()[1]
114
+ res = []
115
+ for i in range(channels_num):
116
+ res.append(
117
+ self.istft(
118
+ sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
119
+ sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
120
+ length,
121
+ )
122
+ )
123
+ res[-1] = res[-1].unsqueeze(1)
124
+ return torch.cat(res, dim=1)
125
+
126
+ def wav_to_spectrogram(self, input, eps=1e-8):
127
+ """Waveform to spectrogram.
128
+
129
+ Args:
130
+ input: (batch_size,channels_num, segment_samples)
131
+
132
+ Outputs:
133
+ output: (batch_size, channels_num, time_steps, freq_bins)
134
+ """
135
+ sp_list = []
136
+ channels_num = input.shape[1]
137
+ for channel in range(channels_num):
138
+ sp_list.append(self.spectrogram(input[:, channel, :], eps=eps))
139
+ output = torch.cat(sp_list, dim=1)
140
+ return output
141
+
142
+ def spectrogram_to_wav(self, input, spectrogram, length=None):
143
+ """Spectrogram to waveform.
144
+ Args:
145
+ input: (batch_size, segment_samples, channels_num)
146
+ spectrogram: (batch_size, channels_num, time_steps, freq_bins)
147
+
148
+ Outputs:
149
+ output: (batch_size, segment_samples, channels_num)
150
+ """
151
+ channels_num = input.shape[1]
152
+ wav_list = []
153
+ for channel in range(channels_num):
154
+ (real, imag) = self.stft(input[:, channel, :])
155
+ (_, cos, sin) = magphase(real, imag)
156
+ wav_list.append(
157
+ self.istft(
158
+ spectrogram[:, channel : channel + 1, :, :] * cos,
159
+ spectrogram[:, channel : channel + 1, :, :] * sin,
160
+ length,
161
+ )
162
+ )
163
+
164
+ output = torch.stack(wav_list, dim=1)
165
+ return output
166
+
167
+ # todo the following code is not bug free!
168
+ def wav_to_complex_spectrogram(self, input, eps=0.0):
169
+ # [batchsize , channels, samples]
170
+ # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
171
+ res = []
172
+ channels_num = input.shape[1]
173
+ for channel in range(channels_num):
174
+ res.append(self.complex_spectrogram(input[:, channel, :], eps=eps))
175
+ return torch.cat(res, dim=1)
176
+
177
+ def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
178
+ # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
179
+ # return [batchsize, channels, samples]
180
+ channels = input.size()[1] // 2
181
+ wavs = []
182
+ for i in range(channels):
183
+ wavs.append(
184
+ self.reverse_complex_spectrogram(
185
+ input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
186
+ )
187
+ )
188
+ wavs[-1] = wavs[-1].unsqueeze(1)
189
+ return torch.cat(wavs, dim=1)
190
+
191
+ def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
192
+ # [batchsize, channels, samples]
193
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
194
+ subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
195
+ subspec = self.wav_to_complex_spectrogram(subwav)
196
+ return subspec
197
+
198
+ def complex_subband_spectrogram_to_wav(self, input, eps=0.0):
199
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
200
+ # [batchsize, channels, samples]
201
+ subwav = self.complex_spectrogram_to_wav(input)
202
+ data = self.qmf.synthesis(subwav)
203
+ return data
204
+
205
+ def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
206
+ """
207
+ :param input:
208
+ :param eps:
209
+ :return:
210
+ loss = torch.nn.L1Loss()
211
+ models = FDomainHelper(subband=4)
212
+ data = torch.randn((3,1, 44100*3))
213
+
214
+ sps, coss, sins = models.wav_to_mag_phase_subband_spectrogram(data)
215
+ wav = models.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4)
216
+
217
+ print(loss(data,wav))
218
+ print(torch.max(torch.abs(data-wav)))
219
+
220
+ """
221
+ # [batchsize, channels, samples]
222
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
223
+ subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
224
+ sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps)
225
+ return sps, coss, sins
226
+
227
+ def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
228
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
229
+ # [batchsize, channels, samples]
230
+ subwav = self.spectrogram_phase_to_wav(
231
+ sps, coss, sins, length + self.qmf.pad_samples // self.qmf.N
232
+ )
233
+ data = self.qmf.synthesis(subwav)
234
+ return data
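A hedged sketch of the full-band magnitude/phase round trip through FDomainHelper above (subband=None, so the PQMF filter files are never touched; shapes follow the docstrings):

import torch

fd = FDomainHelper(window_size=2048, hop_size=441, subband=None)
wav = torch.randn(2, 1, 44100)                      # (batch, channels, samples)
sps, coss, sins = fd.wav_to_spectrogram_phase(wav)
rec = fd.spectrogram_phase_to_wav(sps, coss, sins, length=44100)
print((wav - rec).abs().max())                      # should be close to zero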
voicefixer/tools/modules/filters/f_2_64.mat ADDED
File without changes
voicefixer/tools/modules/filters/f_4_64.mat ADDED
File without changes
voicefixer/tools/modules/filters/f_8_64.mat ADDED
File without changes
voicefixer/tools/modules/filters/h_2_64.mat ADDED
File without changes
voicefixer/tools/modules/filters/h_4_64.mat ADDED
File without changes
voicefixer/tools/modules/filters/h_8_64.mat ADDED
File without changes
voicefixer/tools/modules/pqmf.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ @File : subband_util.py
3
+ @Contact : [email protected]
4
+ @License : (C)Copyright 2020-2021
5
+ @Modify Time @Author @Version @Description
6
+ ------------ ------- -------- -----------
7
+ 2020/4/3 4:54 PM Haohe Liu 1.0 None
8
+ """
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.nn as nn
13
+ import numpy as np
14
+ import os.path as op
15
+ from scipy.io import loadmat
16
+
17
+
18
+ def load_mat2numpy(fname=""):
19
+ if len(fname) == 0:
20
+ return None
21
+ else:
22
+ return loadmat(fname)
23
+
24
+
25
+ class PQMF(nn.Module):
26
+ def __init__(self, N, M, project_root):
27
+ super().__init__()
28
+ self.N = N # nsubband
29
+ self.M = M # nfilter
30
+ try:
31
+ assert (N, M) in [(8, 64), (4, 64), (2, 64)]
32
+ except:
33
+ print("Warning:", N, "subbandand ", M, " filter is not supported")
34
+ self.pad_samples = 64
35
+ self.name = str(N) + "_" + str(M) + ".mat"
36
+ self.ana_conv_filter = nn.Conv1d(
37
+ 1, out_channels=N, kernel_size=M, stride=N, bias=False
38
+ )
39
+ data = load_mat2numpy(
40
+ op.join(
41
+ project_root,
42
+ "arnold_workspace/restorer/tools/pytorch/modules/filters/f_"
43
+ + self.name,
44
+ )
45
+ )
46
+ data = data["f"].astype(np.float32) / N
47
+ data = np.flipud(data.T).T
48
+ data = np.reshape(data, (N, 1, M)).copy()
49
+ dict_new = self.ana_conv_filter.state_dict().copy()
50
+ dict_new["weight"] = torch.from_numpy(data)
51
+ self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
52
+ self.ana_conv_filter.load_state_dict(dict_new)
53
+
54
+ self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
55
+ self.syn_conv_filter = nn.Conv1d(
56
+ N, out_channels=N, kernel_size=M // N, stride=1, bias=False
57
+ )
58
+ gk = load_mat2numpy(
59
+ op.join(
60
+ project_root,
61
+ "arnold_workspace/restorer/tools/pytorch/modules/filters/h_"
62
+ + self.name,
63
+ )
64
+ )
65
+ gk = gk["h"].astype(np.float32)
66
+ gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
67
+ gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
68
+ dict_new = self.syn_conv_filter.state_dict().copy()
69
+ dict_new["weight"] = torch.from_numpy(gk)
70
+ self.syn_conv_filter.load_state_dict(dict_new)
71
+
72
+ for param in self.parameters():
73
+ param.requires_grad = False
74
+
75
+ def __analysis_channel(self, inputs):
76
+ return self.ana_conv_filter(self.ana_pad(inputs))
77
+
78
+ def __systhesis_channel(self, inputs):
79
+ ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
80
+ return torch.reshape(ret, (ret.shape[0], 1, -1))
81
+
82
+ def analysis(self, inputs):
83
+ """
84
+ :param inputs: [batchsize,channel,raw_wav],value:[0,1]
85
+ :return:
86
+ """
87
+ inputs = F.pad(inputs, ((0, self.pad_samples)))
88
+ ret = None
89
+ for i in range(inputs.size()[1]): # channels
90
+ if ret is None:
91
+ ret = self.__analysis_channel(inputs[:, i : i + 1, :])
92
+ else:
93
+ ret = torch.cat(
94
+ (ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1
95
+ )
96
+ return ret
97
+
98
+ def synthesis(self, data):
99
+ """
100
+ :param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1]
101
+ :return:
102
+ """
103
+ ret = None
104
+ # data = F.pad(data,((0,self.pad_samples//self.N)))
105
+ for i in range(data.size()[1]): # channels
106
+ if i % self.N == 0:
107
+ if ret is None:
108
+ ret = self.__systhesis_channel(data[:, i : i + self.N, :])
109
+ else:
110
+ new = self.__systhesis_channel(data[:, i : i + self.N, :])
111
+ ret = torch.cat((ret, new), dim=1)
112
+ ret = ret[..., : -self.pad_samples]
113
+ return ret
114
+
115
+ def forward(self, inputs):
116
+ return self.ana_conv_filter(self.ana_pad(inputs))
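Note that this PQMF loads its analysis/synthesis filters from a hard-coded arnold_workspace/... path under project_root rather than from the filters/ directory added in this commit, so the sketch below only runs if those .mat files exist at that location (the path is a placeholder):

import torch

# assumes <project_root>/arnold_workspace/restorer/tools/pytorch/modules/filters/{f,h}_4_64.mat exist
pqmf = PQMF(N=4, M=64, project_root="/path/to/project_root")    # placeholder path
wav = torch.randn(1, 1, 44100)                                   # (batch, channels, samples)
sub = pqmf.analysis(wav)                                         # (1, 4, (44100 + 64) // 4)
rec = pqmf.synthesis(sub)                                        # (1, 1, 44100), near-perfect reconstruction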
voicefixer/tools/path.py ADDED
@@ -0,0 +1,13 @@
1
+ import os
2
+
3
+
4
+ def find_and_build(root, path):
5
+ path = os.path.join(root, path)
6
+ if not os.path.exists(path):
7
+ os.makedirs(path, exist_ok=True)
8
+ return path
9
+
10
+
11
+ def root_path(repo_name="voicefixer"):
12
+ path = os.path.abspath(__file__)
13
+ return path.split(repo_name)[0]
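Tiny usage sketch for the path helpers above ("checkpoints" is a made-up directory name):

root = root_path()                              # path prefix up to the "voicefixer" package
ckpt_dir = find_and_build(root, "checkpoints")  # creates <root>/checkpoints if it does not exist
print(ckpt_dir)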
voicefixer/tools/pytorch_util.py ADDED
@@ -0,0 +1,180 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ def check_cuda_availability(cuda):
7
+ if cuda and not torch.cuda.is_available():
8
+ raise RuntimeError("Error: You set cuda=True but no cuda device found.")
9
+
10
+
11
+ def try_tensor_cuda(tensor, cuda):
12
+ if cuda and torch.cuda.is_available():
13
+ return tensor.cuda()
14
+ else:
15
+ return tensor.cpu()
16
+
17
+
18
+ def to_log(input):
19
+ assert torch.sum(input < 0) == 0, (
20
+ str(input) + " has negative values counts " + str(torch.sum(input < 0))
21
+ )
22
+ return torch.log10(torch.clip(input, min=1e-8))
23
+
24
+
25
+ def from_log(input):
26
+ input = torch.clip(input, min=-np.inf, max=5)
27
+ return 10**input
28
+
29
+
30
+ def move_data_to_device(x, device):
31
+ if "float" in str(x.dtype):
32
+ x = torch.Tensor(x)
33
+ elif "int" in str(x.dtype):
34
+ x = torch.LongTensor(x)
35
+ else:
36
+ return x
37
+ return x.to(device)
38
+
39
+
40
+ def tensor2numpy(tensor):
41
+ if "cuda" in str(tensor.device):
42
+ return tensor.detach().cpu().numpy()
43
+ else:
44
+ return tensor.detach().numpy()
45
+
46
+
47
+ def count_parameters(model):
48
+ for p in model.parameters():
49
+ if p.requires_grad:
50
+ print(p.shape)
51
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
52
+
53
+
54
+ def count_flops(model, audio_length):
55
+ multiply_adds = False
56
+ list_conv2d = []
57
+
58
+ def conv2d_hook(self, input, output):
59
+ batch_size, input_channels, input_height, input_width = input[0].size()
60
+ output_channels, output_height, output_width = output[0].size()
61
+
62
+ kernel_ops = (
63
+ self.kernel_size[0]
64
+ * self.kernel_size[1]
65
+ * (self.in_channels / self.groups)
66
+ * (2 if multiply_adds else 1)
67
+ )
68
+ bias_ops = 1 if self.bias is not None else 0
69
+
70
+ params = output_channels * (kernel_ops + bias_ops)
71
+ flops = batch_size * params * output_height * output_width
72
+
73
+ list_conv2d.append(flops)
74
+
75
+ list_conv1d = []
76
+
77
+ def conv1d_hook(self, input, output):
78
+ batch_size, input_channels, input_length = input[0].size()
79
+ output_channels, output_length = output[0].size()
80
+
81
+ kernel_ops = (
82
+ self.kernel_size[0]
83
+ * (self.in_channels / self.groups)
84
+ * (2 if multiply_adds else 1)
85
+ )
86
+ bias_ops = 1 if self.bias is not None else 0
87
+
88
+ params = output_channels * (kernel_ops + bias_ops)
89
+ flops = batch_size * params * output_length
90
+
91
+ list_conv1d.append(flops)
92
+
93
+ list_linear = []
94
+
95
+ def linear_hook(self, input, output):
96
+ batch_size = input[0].size(0) if input[0].dim() == 2 else 1
97
+
98
+ weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
99
+ bias_ops = self.bias.nelement()
100
+
101
+ flops = batch_size * (weight_ops + bias_ops)
102
+ list_linear.append(flops)
103
+
104
+ list_bn = []
105
+
106
+ def bn_hook(self, input, output):
107
+ list_bn.append(input[0].nelement())
108
+
109
+ list_relu = []
110
+
111
+ def relu_hook(self, input, output):
112
+ list_relu.append(input[0].nelement())
113
+
114
+ list_pooling2d = []
115
+
116
+ def pooling2d_hook(self, input, output):
117
+ batch_size, input_channels, input_height, input_width = input[0].size()
118
+ output_channels, output_height, output_width = output[0].size()
119
+
120
+ kernel_ops = self.kernel_size * self.kernel_size
121
+ bias_ops = 0
122
+ params = output_channels * (kernel_ops + bias_ops)
123
+ flops = batch_size * params * output_height * output_width
124
+
125
+ list_pooling2d.append(flops)
126
+
127
+ list_pooling1d = []
128
+
129
+ def pooling1d_hook(self, input, output):
130
+ batch_size, input_channels, input_length = input[0].size()
131
+ output_channels, output_length = output[0].size()
132
+
133
+ kernel_ops = self.kernel_size
134
+ bias_ops = 0
135
+ params = output_channels * (kernel_ops + bias_ops)
136
+ flops = batch_size * params * output_length
137
+
138
+ list_pooling2d.append(flops)
139
+
140
+ def foo(net):
141
+ childrens = list(net.children())
142
+ if not childrens:
143
+ if isinstance(net, nn.Conv2d):
144
+ net.register_forward_hook(conv2d_hook)
145
+ elif isinstance(net, nn.ConvTranspose2d):
146
+ net.register_forward_hook(conv2d_hook)
147
+ elif isinstance(net, nn.Conv1d):
148
+ net.register_forward_hook(conv1d_hook)
149
+ elif isinstance(net, nn.Linear):
150
+ net.register_forward_hook(linear_hook)
151
+ elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
152
+ net.register_forward_hook(bn_hook)
153
+ elif isinstance(net, nn.ReLU):
154
+ net.register_forward_hook(relu_hook)
155
+ elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
156
+ net.register_forward_hook(pooling2d_hook)
157
+ elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
158
+ net.register_forward_hook(pooling1d_hook)
159
+ else:
160
+ print("Warning: flop of module {} is not counted!".format(net))
161
+ return
162
+ for c in childrens:
163
+ foo(c)
164
+
165
+ foo(model)
166
+
167
+ input = torch.rand(1, audio_length, 2)
168
+ out = model(input)
169
+
170
+ total_flops = (
171
+ sum(list_conv2d)
172
+ + sum(list_conv1d)
173
+ + sum(list_linear)
174
+ + sum(list_bn)
175
+ + sum(list_relu)
176
+ + sum(list_pooling2d)
177
+ + sum(list_pooling1d)
178
+ )
179
+
180
+ return total_flops
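A short sanity check for the log helpers and parameter counter above (the Linear layer is arbitrary):

import torch
import torch.nn as nn

x = torch.rand(4, 10) + 0.1
assert torch.allclose(from_log(to_log(x)), x, atol=1e-6)    # log10 / 10**x round trip
print(count_parameters(nn.Linear(10, 2)))                   # prints each trainable shape, then 22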
voicefixer/tools/random_.py ADDED
@@ -0,0 +1,52 @@
1
+ import random
2
+ import torch
3
+
4
+ RANDOM_RESOLUTION = 2**31
5
+
6
+
7
+ def random_torch(high, to_int=True):
8
+ if to_int:
9
+ return int((torch.rand(1)) * high) # do not use numpy.random.random
10
+ else:
11
+ return (torch.rand(1)) * high # do not use numpy.random.random
12
+
13
+
14
+ def shuffle_torch(list):
15
+ length = len(list)
16
+ res = []
17
+ order = torch.randperm(length)
18
+ for each in order:
19
+ res.append(list[each])
20
+ assert len(list) == len(res)
21
+ return res
22
+
23
+
24
+ def random_choose_list(list):
25
+ num = int(uniform_torch(0, len(list)))
26
+ return list[num]
27
+
28
+
29
+ def normal_torch(mean=0, segma=1):
30
+ return float(torch.normal(mean=mean, std=torch.Tensor([segma]))[0])
31
+
32
+
33
+ def uniform_torch(lower, upper):
34
+ if abs(lower - upper) < 1e-5:
35
+ return upper
36
+ return (upper - lower) * torch.rand(1) + lower
37
+
38
+
39
+ def random_key(keys: list, weights: list):
40
+ return random.choices(keys, weights=weights)[0]
41
+
42
+
43
+ def random_select(probs):
44
+ res = []
45
+ chance = random_torch(RANDOM_RESOLUTION)
46
+ threshold = None
47
+ for prob in probs:
48
+ # if(threshold is None):threshold=prob
49
+ # else:threshold*=prob
50
+ threshold = prob
51
+ res.append(chance < threshold * RANDOM_RESOLUTION)
52
+ return res, chance
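A sketch of the torch-backed sampling helpers above; random_select compares one shared draw against each probability:

mask, chance = random_select([0.5, 0.9])            # e.g. ([True, True], 1234567890)
key = random_key(["clean", "noisy"], weights=[0.2, 0.8])
offset = uniform_torch(0.0, 1.0)                    # tensor scalar in [0, 1)
print(mask, key, float(offset))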
voicefixer/tools/wav.py ADDED
@@ -0,0 +1,242 @@
1
+ import wave
2
+ import os
3
+ import numpy as np
4
+ import scipy.signal as signal
5
+ import soundfile as sf
6
+ import librosa
7
+
8
+
9
+ def save_wave(frames: np.ndarray, fname, sample_rate=44100):
10
+ shape = list(frames.shape)
11
+ if len(shape) == 1:
12
+ frames = frames[..., None]
13
+ in_samples, in_channels = shape[-2], shape[-1]
14
+ if in_channels >= 3:
15
+ if len(shape) == 2:
16
+ frames = np.transpose(frames, (1, 0))
17
+ elif len(shape) == 3:
18
+ frames = np.transpose(frames, (0, 2, 1))
19
+ msg = (
20
+ "Warning: Save audio with "
21
+ + str(in_channels)
22
+ + " channels, save permute audio with shape "
23
+ + str(list(frames.shape))
24
+ + " please check if it's correct."
25
+ )
26
+ # print(msg)
27
+ if np.max(frames) <= 1 and frames.dtype in (
+ np.float32,
+ np.float16,
+ np.float64,
+ ):
33
+ frames *= 2**15
34
+ frames = frames.astype(np.short)
35
+ if len(frames.shape) >= 3:
36
+ frames = frames[0, ...]
37
+ sf.write(fname, frames, samplerate=sample_rate)
38
+
39
+
40
+ def constrain_length(chunk, length):
41
+ frames_length = chunk.shape[0]
42
+ if frames_length == length:
43
+ return chunk
44
+ elif frames_length < length:
45
+ return np.pad(chunk, ((0, int(length - frames_length)), (0, 0)), "constant")
46
+ else:
47
+ return chunk[:length, ...]
48
+
49
+
50
+ def random_chunk_wav_file(fname, chunk_length):
51
+ """
52
+ fname: path to wav file
53
+ chunk_length: frame length in seconds
54
+ """
55
+ with wave.open(fname) as f:
56
+ params = f.getparams()
57
+ duration = params[3] / params[2]
58
+ sample_rate = params[2]
59
+ sample_length = params[3]
60
+ if duration < chunk_length or abs(duration - chunk_length) < 1e-4:
61
+ frames = read_wave(fname, sample_rate)
62
+ return frames, duration, sample_rate # [-1,1]
63
+ else:
64
+ # Random chunk
65
+ random_starts = np.random.randint(
66
+ 0, sample_length - sample_rate * chunk_length
67
+ )
68
+ random_end = random_starts + sample_rate * chunk_length
69
+ random_starts, random_end = (
70
+ random_starts / sample_rate,
71
+ random_end / sample_rate,
72
+ )
73
+ random_starts, random_end = random_starts / duration, random_end / duration
74
+ frames = read_wave(
75
+ fname, sample_rate, portion_start=random_starts, portion_end=random_end
76
+ )
77
+ frames = constrain_length(frames, length=int(chunk_length * sample_rate))
78
+ return frames, chunk_length, sample_rate
79
+
80
+
81
+ def random_chunk_wav_file_v2(fname, chunk_length, random_starts=None, random_end=None):
82
+ """
83
+ fname: path to wav file
84
+ chunk_length: frame length in seconds
85
+ """
86
+ with wave.open(fname) as f:
87
+ params = f.getparams()
88
+ duration = params[3] / params[2]
89
+ sample_rate = params[2]
90
+ sample_length = params[3]
91
+ if duration < chunk_length or abs(duration - chunk_length) < 1e-4:
92
+ frames = read_wave(fname, sample_rate)
93
+ return frames, duration, sample_rate # [-1,1]
94
+ else:
95
+ # Random chunk
96
+ if random_starts is None and random_end is None:
97
+ random_starts = np.random.randint(
98
+ 0, sample_length - sample_rate * chunk_length
99
+ )
100
+ random_end = random_starts + sample_rate * chunk_length
101
+ random_starts, random_end = (
102
+ random_starts / sample_rate,
103
+ random_end / sample_rate,
104
+ )
105
+ random_starts, random_end = (
106
+ random_starts / duration,
107
+ random_end / duration,
108
+ )
109
+ frames = read_wave(
110
+ fname, sample_rate, portion_start=random_starts, portion_end=random_end
111
+ )
112
+ frames = constrain_length(frames, length=int(chunk_length * sample_rate))
113
+ return frames, chunk_length, sample_rate, random_starts, random_end
114
+
115
+
116
+ def read_wave(
117
+ fname,
118
+ sample_rate,
119
+ portion_start=0,
120
+ portion_end=1,
121
+ ): # Whether you want raw bytes
122
+ """
123
+ :param fname: wav file path
124
+ :param sample_rate:
125
+ :param portion_start:
126
+ :param portion_end:
127
+ :return: [sample, channels]
128
+ """
129
+ # sr = get_sample_rate(fname)
130
+ # if(sr != sample_rate):
131
+ # print("Warning: Sample rate not match, may lead to unexpected behavior.")
132
+ if portion_end > 1 and portion_end < 1.1:
133
+ portion_end = 1
134
+ if portion_end != 1:
135
+ duration = get_duration(fname)
136
+ wav, _ = librosa.load(
137
+ fname,
138
+ sr=sample_rate,
139
+ offset=portion_start * duration,
140
+ duration=(portion_end - portion_start) * duration,
141
+ mono=False,
142
+ )
143
+ else:
144
+ wav, _ = librosa.load(fname, sr=sample_rate, mono=False)
145
+ if len(list(wav.shape)) == 1:
146
+ wav = wav[..., None]
147
+ else:
148
+ wav = np.transpose(wav, (1, 0))
149
+ return wav
150
+
151
+
152
+ def get_channels_sampwidth_and_sample_rate(fname):
153
+ with wave.open(fname) as f:
154
+ params = f.getparams()
155
+ return (
156
+ params[0],
157
+ params[1],
158
+ params[2],
159
+ ) # == (2,2,44100),(params[0],params[1],params[2])
160
+
161
+
162
+ def get_channels(fname):
163
+ with wave.open(fname) as f:
164
+ params = f.getparams()
165
+ return params[0]
166
+
167
+
168
+ def get_sample_rate(fname):
169
+ with wave.open(fname) as f:
170
+ params = f.getparams()
171
+ return params[2]
172
+
173
+
174
+ def get_duration(fname):
175
+ with wave.open(fname) as f:
176
+ params = f.getparams()
177
+ return params[3] / params[2]
178
+
179
+
180
+ def get_framesLength(fname):
181
+ with wave.open(fname) as f:
182
+ params = f.getparams()
183
+ return params[3]
184
+
185
+
186
+ def restore_wave(zxx):
187
+ _, w = signal.istft(zxx)
188
+ return w
189
+
190
+
191
+ def calculate_total_times(dir):
192
+ total = 0
193
+ for each in os.listdir(dir):
194
+ fname = os.path.join(dir, each)
195
+ try:
196
+ duration = get_duration(fname)
197
+ except Exception:
+ print(fname)
+ continue  # skip files whose duration cannot be read
+ total += duration
200
+ return total
201
+
202
+
203
+ def filter(pth):
204
+ global dic
205
+ temp = []
206
+ for each in os.listdir(pth):
207
+ temp.append(os.path.join(pth, each))
208
+ for each in temp:
209
+ sr = get_sample_rate(each)
210
+ if sr not in dic.keys():
211
+ dic[sr] = []
212
+ dic[sr].append(each)
213
+ for each in dic[16000]:
214
+ # print(each)
215
+ pass
216
+ print(dic.keys())
217
+ for each in list(dic.keys()):
218
+ print(each, len(dic[each]))
219
+
220
+
221
+ if __name__ == "__main__":
222
+ path = "/Users/admin/Desktop/p376_025.wav"
223
+ stereo = "/Users/admin/Desktop/vocals.wav"
224
+ path_16 = "/Users/admin/Desktop/SI869.WAV.wav"
225
+ import time
226
+
227
+ start = time.time()
228
+ for i in range(1000):
229
+ frames, duration, sample_rate = random_chunk_wav_file(stereo, chunk_length=3.0)
230
+ print(frames.shape, np.max(frames))
231
+ save_wave(frames, "stero.wav", sample_rate=44100)
232
+ frames, duration, sample_rate = random_chunk_wav_file(path, chunk_length=3.0)
233
+ print(frames.shape, np.max(frames))
234
+ save_wave(frames, "mono.wav", sample_rate=44100)
235
+ frames, duration, sample_rate = random_chunk_wav_file(path_16, chunk_length=3.0)
236
+ print(frames.shape, np.max(frames))
237
+ save_wave(frames, "16.wav", sample_rate=16000)
238
+ print(time.time() - start)
239
+ # frames = read_wave(stereo,sample_rate=44100)
240
+ print(frames.shape)
241
+
242
+ print(frames)
voicefixer/vocoder/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @File : __init__.py.py
5
+ @Contact : [email protected]
6
+ @License : (C)Copyright 2020-2100
7
+
8
+ @Modify Time @Author @Version @Description
9
+ ------------ ------- -------- -----------
10
+ 9/14/21 1:00 AM Haohe Liu 1.0 None
11
+ """
12
+
13
+ import os
14
+ from voicefixer.vocoder.config import Config
15
+ import urllib.request
16
+
17
+ if not os.path.exists(Config.ckpt):
18
+ os.makedirs(os.path.dirname(Config.ckpt), exist_ok=True)
19
+ print("Downloading the weight of neural vocoder: TFGAN")
20
+ urllib.request.urlretrieve(
21
+ "https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1",
22
+ Config.ckpt,
23
+ )
24
+ print(
25
+ "Weights downloaded in: {} Size: {}".format(
26
+ Config.ckpt, os.path.getsize(Config.ckpt)
27
+ )
28
+ )
29
+ # cmd = "wget https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1 -O " + Config.ckpt
30
+ # os.system(cmd)
voicefixer/vocoder/base.py ADDED
@@ -0,0 +1,86 @@
1
+ from voicefixer.vocoder.model.generator import Generator
2
+ from voicefixer.tools.wav import read_wave, save_wave
3
+ from voicefixer.tools.pytorch_util import *
4
+ from voicefixer.vocoder.model.util import *
5
+ from voicefixer.vocoder.config import Config
6
+ import os
7
+ import numpy as np
8
+
9
+
10
+ class Vocoder(nn.Module):
11
+ def __init__(self, sample_rate):
12
+ super(Vocoder, self).__init__()
13
+ Config.refresh(sample_rate)
14
+ self.rate = sample_rate
15
+ if(not os.path.exists(Config.ckpt)):
16
+ raise RuntimeError("Error 1: The checkpoint for synthesis module / vocoder (model.ckpt-1490000_trimed) is not found in ~/.cache/voicefixer/synthesis_module/44100. \
17
+ By default the checkpoint should be downloaded automatically by this program. Something bad may have happened. Apologies for the inconvenience.\
18
+ But don't worry! Alternatively you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/model.ckpt-1490000_trimed.pt?download=1")
19
+ self._load_pretrain(Config.ckpt)
20
+ self.weight_torch = Config.get_mel_weight_torch(percent=1.0)[
21
+ None, None, None, ...
22
+ ]
23
+
24
+ def _load_pretrain(self, pth):
25
+ self.model = Generator(Config.cin_channels)
26
+ checkpoint = load_checkpoint(pth, torch.device("cpu"))
27
+ load_try(checkpoint["generator"], self.model)
28
+ self.model.eval()
29
+ self.model.remove_weight_norm()
30
+ self.model.remove_weight_norm()
31
+ for p in self.model.parameters():
32
+ p.requires_grad = False
33
+
34
+ # def vocoder_mel_npy(self, mel, save_dir, sample_rate, gain):
35
+ # mel = mel / Config.get_mel_weight(percent=gain)[...,None]
36
+ # mel = normalize(amp_to_db(np.abs(mel)) - 20)
37
+ # mel = pre(np.transpose(mel, (1, 0)))
38
+ # with torch.no_grad():
39
+ # wav_re = self.model(mel) # torch.Size([1, 1, 104076])
40
+ # save_wave(tensor2numpy(wav_re)*2**15,save_dir,sample_rate=sample_rate)
41
+
42
+ def forward(self, mel, cuda=False):
43
+ """
44
+ :param mel: non-normalized mel spectrogram, [batchsize, 1, t-steps, n_mel]
45
+ :return: [batchsize, 1, samples]
46
+ """
47
+ assert mel.size()[-1] == 128
48
+ check_cuda_availability(cuda=cuda)
49
+ self.model = try_tensor_cuda(self.model, cuda=cuda)
50
+ mel = try_tensor_cuda(mel, cuda=cuda)
51
+ self.weight_torch = self.weight_torch.type_as(mel)
52
+ mel = mel / self.weight_torch
53
+ mel = tr_normalize(tr_amp_to_db(torch.abs(mel)) - 20.0)
54
+ mel = tr_pre(mel[:, 0, ...])
55
+ wav_re = self.model(mel)
56
+ return wav_re
57
+
58
+ def oracle(self, fpath, out_path, cuda=False):
59
+ check_cuda_availability(cuda=cuda)
60
+ self.model = try_tensor_cuda(self.model, cuda=cuda)
61
+ wav = read_wave(fpath, sample_rate=self.rate)[..., 0]
62
+ wav = wav / np.max(np.abs(wav))
63
+ stft = np.abs(
64
+ librosa.stft(
65
+ wav,
66
+ hop_length=Config.hop_length,
67
+ win_length=Config.win_size,
68
+ n_fft=Config.n_fft,
69
+ )
70
+ )
71
+ mel = linear_to_mel(stft)
72
+ mel = normalize(amp_to_db(np.abs(mel)) - 20)
73
+ mel = pre(np.transpose(mel, (1, 0)))
74
+ mel = try_tensor_cuda(mel, cuda=cuda)
75
+ with torch.no_grad():
76
+ wav_re = self.model(mel)
77
+ save_wave(tensor2numpy(wav_re * 2**15), out_path, sample_rate=self.rate)
78
+
79
+
80
+ if __name__ == "__main__":
81
+ model = Vocoder(sample_rate=44100)
82
+ print(next(model.parameters()).device)  # nn.Module itself has no .device attribute
83
+ # model.load_pretrain(Config.ckpt)
84
+ # model.oracle(path="/Users/liuhaohe/Desktop/test.wav",
85
+ # sample_rate=44100,
86
+ # save_dir="/Users/liuhaohe/Desktop/test_vocoder.wav")
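A minimal sketch of driving the Vocoder above with a random mel tensor, just to show the expected shapes (the values are meaningless as audio, and the constructor triggers the checkpoint lookup described in __init__):

import torch

vocoder = Vocoder(sample_rate=44100)        # expects the TFGAN checkpoint under ~/.cache/voicefixer/
mel = torch.rand(1, 1, 100, 128)            # non-normalized mel: (batch, 1, t-steps, n_mel=128)
with torch.no_grad():
    wav = vocoder(mel, cuda=False)          # (batch, 1, samples)
print(wav.shape)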
voicefixer/vocoder/config.py ADDED
@@ -0,0 +1,316 @@
1
+ import torch
2
+ import numpy as np
3
+ import os
4
+ from voicefixer.tools.path import root_path
5
+
6
+
7
+ class Config:
8
+ @classmethod
9
+ def refresh(cls, sr):
10
+ if sr == 44100:
11
+ Config.ckpt = os.path.join(
12
+ os.path.expanduser("~"),
13
+ ".cache/voicefixer/synthesis_module/44100/model.ckpt-1490000_trimed.pt",
14
+ )
15
+ Config.cond_channels = 512
16
+ Config.m_channels = 768
17
+ Config.resstack_depth = [8, 8, 8, 8]
18
+ Config.channels = 1024
19
+ Config.cin_channels = 128
20
+ Config.upsample_scales = [7, 7, 3, 3]
21
+ Config.num_mels = 128
22
+ Config.n_fft = 2048
23
+ Config.hop_length = 441
24
+ Config.sample_rate = 44100
25
+ Config.fmax = 22000
26
+ Config.mel_win = 128
27
+ Config.local_condition_dim = 128
28
+ else:
29
+ raise RuntimeError(
30
+ "Error: Vocoder currently only support 44100 samplerate."
31
+ )
32
+
33
+ ckpt = os.path.join(
34
+ os.path.expanduser("~"),
35
+ ".cache/voicefixer/synthesis_module/44100/model.ckpt-1490000_trimed.pt",
36
+ )
37
+ m_channels = 384
38
+ bits = 10
39
+ opt = "Ralamb"
40
+ cond_channels = 256
41
+ clip = 0.5
42
+ num_bands = 1
43
+ cin_channels = 128
44
+ upsample_scales = [7, 7, 3, 3]
45
+ filterbands = "test/filterbanks_4bands.dat"
46
+ ##For inference
47
+ tag = ""
48
+ min_db = -115
49
+ num_mels = 128
50
+ n_fft = 2048
51
+ hop_length = 441
52
+ win_size = None
53
+ sample_rate = 44100
54
+ frame_shift_ms = None
55
+
56
+ trim_fft_size = 512
57
+ trim_hop_size = 128
58
+ trim_top_db = 23
59
+
60
+ signal_normalization = True
61
+ allow_clipping_in_normalization = True
62
+ symmetric_mels = True
63
+ max_abs_value = 4.0
64
+
65
+ preemphasis = 0.85
66
+ min_level_db = -100
67
+ ref_level_db = 20
68
+ fmin = 50
69
+ fmax = 22000
70
+ power = 1.5
71
+ griffin_lim_iters = 60
72
+ rescale = False
73
+ rescaling_max = 0.95
74
+ trim_silence = False
75
+ clip_mels_length = True
76
+ max_mel_frames = 2000
77
+
78
+ mel_win = 128
79
+ batch_size = 24
80
+ g_learning_rate = 0.001
81
+ d_learning_rate = 0.001
82
+ warmup_steps = 100000
83
+ decay_learning_rate = 0.5
84
+ exponential_moving_average = True
85
+ ema_decay = 0.99
86
+
87
+ reset_opt = False
88
+ reset_g_opt = False
89
+ reset_d_opt = False
90
+
91
+ local_condition_dim = 128
92
+ lambda_update_G = 1
93
+ multiscale_D = 3
94
+
95
+ lambda_adv = 4.0
96
+ lambda_fm_loss = 0.0
97
+ lambda_sc_loss = 5.0
98
+ lambda_mag_loss = 5.0
99
+ lambda_mel_loss = 50.0
100
+ use_mle_loss = False
101
+ lambda_mle_loss = 5.0
102
+
103
+ lambda_freq_loss = 2.0
104
+ lambda_energy_loss = 100.0
105
+ lambda_t_loss = 200.0
106
+ lambda_phase_loss = 100.0
107
+ lambda_f0_loss = 1.0
108
+ use_elu = False
109
+ de_preem = False # train
110
+ up_org = False
111
+ use_one = True
112
+ use_small_D = False
113
+ use_condnet = True
114
+ use_depreem = False # inference
115
+ use_msd = False
116
+ model_type = "tfgan" # or bytewave, frame level vocoder using istft
117
+ use_hjcud = False
118
+ no_skip = False
119
+ out_channels = 1
120
+ use_postnet = False # wn in postnet
121
+ use_wn = False # wn in resstack
122
+ up_type = "transpose"
123
+ use_smooth = False
124
+ use_drop = False
125
+ use_shift_scale = False
126
+ use_gcnn = False
127
+ resstack_depth = [6, 6, 6, 6]
128
+ kernel_size = [3, 3, 3, 3]
129
+ channels = 512
130
+ use_f0_loss = False
131
+ use_sine = False
132
+ use_cond_rnn = False
133
+ use_rnn = False
134
+
135
+ f0_step = 120
136
+ use_lowfreq_loss = False
137
+ lambda_lowfreq_loss = 1.0
138
+ use_film = False
139
+ use_mb_mr_gan = False
140
+
141
+ use_mssl = False
142
+ use_ml_gan = False
143
+ use_mb_gan = True
144
+ use_mpd = False
145
+ use_spec_gan = True
146
+ use_rwd = False
147
+ use_mr_gan = True
148
+ use_pqmf_rwd = False
149
+ no_sine = False
150
+ use_frame_mask = False
151
+
152
+ lambda_var_loss = 0.0
153
+ discriminator_train_start_steps = 40000 # 80k
154
+ aux_d_train_start_steps = 40000 # 100k
155
+ rescale_out = 0.40
156
+ use_dist = True
157
+ dist_backend = "nccl"
158
+ dist_url = "tcp://localhost:12345"
159
+ world_size = 1
160
+
161
+ mel_weight_torch = torch.tensor(
162
+ [
163
+ 19.40951426,
164
+ 19.94047336,
165
+ 20.4859038,
166
+ 21.04629067,
167
+ 21.62194148,
168
+ 22.21335214,
169
+ 22.8210215,
170
+ 23.44529231,
171
+ 24.08660962,
172
+ 24.74541882,
173
+ 25.42234287,
174
+ 26.11770576,
175
+ 26.83212784,
176
+ 27.56615283,
177
+ 28.32007747,
178
+ 29.0947679,
179
+ 29.89060111,
180
+ 30.70832636,
181
+ 31.54828121,
182
+ 32.41121487,
183
+ 33.29780773,
184
+ 34.20865341,
185
+ 35.14437675,
186
+ 36.1056621,
187
+ 37.09332763,
188
+ 38.10795802,
189
+ 39.15039691,
190
+ 40.22119881,
191
+ 41.32154931,
192
+ 42.45172373,
193
+ 43.61293329,
194
+ 44.80609379,
195
+ 46.031602,
196
+ 47.29070223,
197
+ 48.58427549,
198
+ 49.91327905,
199
+ 51.27863232,
200
+ 52.68119708,
201
+ 54.1222372,
202
+ 55.60274206,
203
+ 57.12364703,
204
+ 58.68617876,
205
+ 60.29148652,
206
+ 61.94081306,
207
+ 63.63501986,
208
+ 65.37562658,
209
+ 67.16408954,
210
+ 69.00109084,
211
+ 70.88850318,
212
+ 72.82736101,
213
+ 74.81985537,
214
+ 76.86654792,
215
+ 78.96885475,
216
+ 81.12900906,
217
+ 83.34840929,
218
+ 85.62810662,
219
+ 87.97005418,
220
+ 90.37689804,
221
+ 92.84887686,
222
+ 95.38872881,
223
+ 97.99777002,
224
+ 100.67862715,
225
+ 103.43232942,
226
+ 106.26140638,
227
+ 109.16827015,
228
+ 112.15470471,
229
+ 115.22184756,
230
+ 118.37439245,
231
+ 121.6122689,
232
+ 124.93877158,
233
+ 128.35661454,
234
+ 131.86761321,
235
+ 135.47417938,
236
+ 139.18059494,
237
+ 142.98713744,
238
+ 146.89771854,
239
+ 150.91684347,
240
+ 155.0446638,
241
+ 159.28614648,
242
+ 163.64270198,
243
+ 168.12035831,
244
+ 172.71749158,
245
+ 177.44220154,
246
+ 182.29556933,
247
+ 187.28286676,
248
+ 192.40502126,
249
+ 197.6682721,
250
+ 203.07516896,
251
+ 208.63088733,
252
+ 214.33770931,
253
+ 220.19910108,
254
+ 226.22363072,
255
+ 232.41087124,
256
+ 238.76803591,
257
+ 245.30079083,
258
+ 252.01064464,
259
+ 258.90261676,
260
+ 265.98474,
261
+ 273.26010248,
262
+ 280.73496362,
263
+ 288.41440094,
264
+ 296.30489752,
265
+ 304.41180337,
266
+ 312.7377183,
267
+ 321.28877878,
268
+ 330.07870237,
269
+ 339.10812951,
270
+ 348.38276173,
271
+ 357.91393924,
272
+ 367.70513992,
273
+ 377.76413924,
274
+ 388.09467408,
275
+ 398.70920178,
276
+ 409.61813793,
277
+ 420.81980127,
278
+ 432.33215467,
279
+ 444.16083117,
280
+ 456.30919947,
281
+ 468.78589276,
282
+ 481.61325588,
283
+ 494.78824596,
284
+ 508.31969844,
285
+ 522.2238331,
286
+ 536.51163441,
287
+ 551.18859414,
288
+ 566.26142988,
289
+ 581.75006061,
290
+ 597.66210737,
291
+ ]
292
+ )
293
+
294
+ x_orig = np.linspace(1, mel_weight_torch.shape[0], num=mel_weight_torch.shape[0])
295
+
296
+ x_orig_torch = torch.linspace(
297
+ 1, mel_weight_torch.shape[0], steps=mel_weight_torch.shape[0]
298
+ )
299
+
300
+ @classmethod
301
+ def get_mel_weight(cls, percent=1, a=18.8927416350036, b=0.0269863588184314):
302
+ b = percent * b
303
+
304
+ def func(a, b, x):
305
+ return a * np.exp(b * x)
306
+
307
+ return func(a, b, Config.x_orig)
308
+
309
+ @classmethod
310
+ def get_mel_weight_torch(cls, percent=1, a=18.8927416350036, b=0.0269863588184314):
311
+ b = percent * b
312
+
313
+ def func(a, b, x):
314
+ return a * torch.exp(b * x)
315
+
316
+ return func(a, b, Config.x_orig_torch)
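The two get_mel_weight helpers are a fitted exponential a * exp(b * x) over the mel-bin index x; with percent=1 they reproduce the tabulated mel_weight_torch values above (roughly 19.41 for the first bin and 597.66 for the last). A quick check, assuming only what the class defines (note that importing the package also runs the checkpoint check in vocoder/__init__.py):

import torch
from voicefixer.vocoder.config import Config

Config.refresh(44100)                              # select the 44.1 kHz vocoder settings
w = Config.get_mel_weight_torch(percent=1.0)       # one weight per mel bin
print(w.shape, float(w[0]), float(w[-1]))          # ~19.41 and ~597.66
print(torch.allclose(w, Config.mel_weight_torch, rtol=1e-3))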
voicefixer/vocoder/model/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @File : __init__.py.py
5
+ @Contact : [email protected]
6
+ @License : (C)Copyright 2020-2100
7
+
8
+ @Modify Time @Author @Version @Description
9
+ ------------ ------- -------- -----------
10
+ 9/14/21 1:00 AM Haohe Liu 1.0 None
11
+ """
voicefixer/vocoder/model/generator.py ADDED
@@ -0,0 +1,168 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from voicefixer.vocoder.model.modules import UpsampleNet, ResStack
5
+ from voicefixer.vocoder.config import Config
6
+ from voicefixer.vocoder.model.pqmf import PQMF
7
+ import os
8
+
9
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
10
+
11
+
12
+ class Generator(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_channels=128,
16
+ use_elu=False,
17
+ use_gcnn=False,
18
+ up_org=False,
19
+ group=1,
20
+ hp=None,
21
+ ):
22
+ super(Generator, self).__init__()
23
+ self.hp = hp
24
+ channels = Config.channels
25
+ self.upsample_scales = Config.upsample_scales
26
+ self.use_condnet = Config.use_condnet
27
+ self.out_channels = Config.out_channels
28
+ self.resstack_depth = Config.resstack_depth
29
+ self.use_postnet = Config.use_postnet
30
+ self.use_cond_rnn = Config.use_cond_rnn
31
+ if self.use_condnet:
32
+ cond_channels = Config.cond_channels
33
+ self.condnet = nn.Sequential(
34
+ nn.utils.weight_norm(
35
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
36
+ ),
37
+ nn.ELU(),
38
+ nn.utils.weight_norm(
39
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
40
+ ),
41
+ nn.ELU(),
42
+ nn.utils.weight_norm(
43
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
44
+ ),
45
+ nn.ELU(),
46
+ nn.utils.weight_norm(
47
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
48
+ ),
49
+ nn.ELU(),
50
+ nn.utils.weight_norm(
51
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
52
+ ),
53
+ nn.ELU(),
54
+ )
55
+ in_channels = cond_channels
56
+ if self.use_cond_rnn:
57
+ self.rnn = nn.GRU(
58
+ cond_channels,
59
+ cond_channels // 2,
60
+ num_layers=1,
61
+ batch_first=True,
62
+ bidirectional=True,
63
+ )
64
+
65
+ if use_elu:
66
+ act = nn.ELU()
67
+ else:
68
+ act = nn.LeakyReLU(0.2, True)
69
+
70
+ kernel_size = Config.kernel_size
71
+
72
+ if self.out_channels == 1:
73
+ self.generator = nn.Sequential(
74
+ nn.ReflectionPad1d(3),
75
+ nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)),
76
+ act,
77
+ UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp, 0),
78
+ ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp),
79
+ act,
80
+ UpsampleNet(
81
+ channels // 2, channels // 4, self.upsample_scales[1], hp, 1
82
+ ),
83
+ ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp),
84
+ act,
85
+ UpsampleNet(
86
+ channels // 4, channels // 8, self.upsample_scales[2], hp, 2
87
+ ),
88
+ ResStack(channels // 8, kernel_size[2], self.resstack_depth[2], hp),
89
+ act,
90
+ UpsampleNet(
91
+ channels // 8, channels // 16, self.upsample_scales[3], hp, 3
92
+ ),
93
+ ResStack(channels // 16, kernel_size[3], self.resstack_depth[3], hp),
94
+ act,
95
+ nn.ReflectionPad1d(3),
96
+ nn.utils.weight_norm(
97
+ nn.Conv1d(channels // 16, self.out_channels, kernel_size=7)
98
+ ),
99
+ nn.Tanh(),
100
+ )
101
+ else:
102
+ channels = Config.m_channels
103
+ self.generator = nn.Sequential(
104
+ nn.ReflectionPad1d(3),
105
+ nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)),
106
+ act,
107
+ UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp),
108
+ ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp),
109
+ act,
110
+ UpsampleNet(channels // 2, channels // 4, self.upsample_scales[1], hp),
111
+ ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp),
112
+ act,
113
+ UpsampleNet(channels // 4, channels // 8, self.upsample_scales[3], hp),
114
+ ResStack(channels // 8, kernel_size[3], self.resstack_depth[2], hp),
115
+ act,
116
+ nn.ReflectionPad1d(3),
117
+ nn.utils.weight_norm(
118
+ nn.Conv1d(channels // 8, self.out_channels, kernel_size=7)
119
+ ),
120
+ nn.Tanh(),
121
+ )
122
+ if self.out_channels > 1:
123
+ self.pqmf = PQMF(4, 64)
124
+
125
+ self.num_params()
126
+
127
+ def forward(self, conditions, use_res=False, f0=None):
128
+ res = conditions
129
+ if self.use_condnet:
130
+ conditions = self.condnet(conditions)
131
+ if self.use_cond_rnn:
132
+ conditions, _ = self.rnn(conditions.transpose(1, 2))
133
+ conditions = conditions.transpose(1, 2)
134
+
135
+ wav = self.generator(conditions)
136
+ if self.out_channels > 1:
137
+ B = wav.size(0)
138
+ f_wav = (
139
+ self.pqmf.synthesis(wav)
140
+ .transpose(1, 2)
141
+ .reshape(B, 1, -1)
142
+ .clamp(-0.99, 0.99)
143
+ )
144
+ return f_wav, wav
145
+ return wav
146
+
147
+ def num_params(self):
148
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
149
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
150
+ return parameters
151
+ # print('Trainable Parameters: %.3f million' % parameters)
152
+
153
+ def remove_weight_norm(self):
154
+ def _remove_weight_norm(m):
155
+ try:
156
+ torch.nn.utils.remove_weight_norm(m)
157
+ except ValueError: # this module didn't have weight norm
158
+ return
159
+
160
+ self.apply(_remove_weight_norm)
161
+
162
+
163
+ if __name__ == "__main__":
164
+ model = Generator(128)
165
+ x = torch.randn(3, 128, 13)
166
+ print(x.shape)
167
+ y = model(x)
168
+ print(y.shape)
voicefixer/vocoder/model/modules.py ADDED
@@ -0,0 +1,947 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from voicefixer.vocoder.config import Config
7
+
8
+ # Adapted from Xin Wang (NII)
9
+ class SineGen(torch.nn.Module):
10
+ """Definition of sine generator
11
+ SineGen(samp_rate, harmonic_num = 0,
12
+ sine_amp = 0.1, noise_std = 0.003,
13
+ voiced_threshold = 0,
14
+ flag_for_pulse=False)
15
+
16
+ samp_rate: sampling rate in Hz
17
+ harmonic_num: number of harmonic overtones (default 0)
18
+ sine_amp: amplitude of sine-waveform (default 0.1)
19
+ noise_std: std of Gaussian noise (default 0.003)
20
+ voiced_threshold: F0 threshold for U/V classification (default 0)
21
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
22
+
23
+ Note: when flag_for_pulse is True, the first time step of a voiced
24
+ segment is always sin(np.pi) or cos(0)
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ samp_rate=24000,
30
+ harmonic_num=0,
31
+ sine_amp=0.1,
32
+ noise_std=0.003,
33
+ voiced_threshold=0,
34
+ flag_for_pulse=False,
35
+ ):
36
+ super(SineGen, self).__init__()
37
+ self.sine_amp = sine_amp
38
+ self.noise_std = noise_std
39
+ self.harmonic_num = harmonic_num
40
+ self.dim = self.harmonic_num + 1
41
+ self.sampling_rate = samp_rate
42
+ self.voiced_threshold = voiced_threshold
43
+ self.flag_for_pulse = flag_for_pulse
44
+
45
+ def _f02uv(self, f0):
46
+ # generate uv signal
47
+ uv = torch.ones_like(f0)
48
+ uv = uv * (f0 > self.voiced_threshold)
49
+ return uv
50
+
51
+ def _f02sine(self, f0_values):
52
+ """f0_values: (batchsize, length, dim)
53
+ where dim indicates fundamental tone and overtones
54
+ """
55
+ # convert to F0 in rad. The integer part n can be ignored
56
+ # because 2 * np.pi * n doesn't affect phase
57
+ rad_values = (f0_values / self.sampling_rate) % 1
58
+
59
+ # initial phase noise (no noise for fundamental component)
60
+ rand_ini = torch.rand(
61
+ f0_values.shape[0], f0_values.shape[2], device=f0_values.device
62
+ )
63
+ rand_ini[:, 0] = 0
64
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
65
+
66
+ # instantaneous phase: sine[t] = sin(2*pi * \sum_{i=1}^{t} rad_i)
67
+ if not self.flag_for_pulse:
68
+ # for normal case
69
+
70
+ # To prevent torch.cumsum numerical overflow,
71
+ # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
72
+ # Buffer tmp_over_one_idx indicates the time step to add -1.
73
+ # This does not change the F0 of the sine because (x-1)*2*pi and x*2*pi are equal modulo 2*pi
74
+ tmp_over_one = torch.cumsum(rad_values, 1) % 1
75
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
76
+ cumsum_shift = torch.zeros_like(rad_values)
77
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
78
+
79
+ sines = torch.sin(
80
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
81
+ )
82
+ else:
83
+ # If necessary, make sure that the first time step of every
84
+ # voiced segment is sin(pi) or cos(0)
85
+ # This is used for pulse-train generation
86
+
87
+ # identify the last time step in unvoiced segments
88
+ uv = self._f02uv(f0_values)
89
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
90
+ uv_1[:, -1, :] = 1
91
+ u_loc = (uv < 1) * (uv_1 > 0)
92
+
93
+ # get the instantaneous phase
94
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
95
+ # each item in the batch needs to be processed separately
96
+ for idx in range(f0_values.shape[0]):
97
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
98
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
99
+ # stores the accumulation of i.phase within
100
+ # each voiced segment
101
+ tmp_cumsum[idx, :, :] = 0
102
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
103
+
104
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
105
+ # within the previous voiced segment.
106
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
107
+
108
+ # get the sines
109
+ sines = torch.cos(i_phase * 2 * np.pi)
110
+ return sines
111
+
112
+ def forward(self, f0):
113
+ """sine_tensor, uv = forward(f0)
114
+ input F0: tensor(batchsize=1, length, dim=1)
115
+ f0 for unvoiced steps should be 0
116
+ output sine_tensor: tensor(batchsize=1, length, dim)
117
+ output uv: tensor(batchsize=1, length, 1)
118
+ """
119
+
120
+ with torch.no_grad():
121
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
122
+ # fundamental component
123
+ f0_buf[:, :, 0] = f0[:, :, 0]
124
+ for idx in np.arange(self.harmonic_num):
125
+ # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
126
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
127
+
128
+ # generate sine waveforms
129
+ sine_waves = self._f02sine(f0_buf) * self.sine_amp
130
+
131
+ # generate uv signal
132
+ # uv = torch.ones(f0.shape)
133
+ # uv = uv * (f0 > self.voiced_threshold)
134
+ uv = self._f02uv(f0)
135
+
136
+ # noise: for unvoiced should be similar to sine_amp
137
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
138
+ # . for voiced regions is self.noise_std
139
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
140
+ noise = noise_amp * torch.randn_like(sine_waves)
141
+
142
+ # first: set the unvoiced part to 0 by uv
143
+ # then: additive noise
144
+ sine_waves = sine_waves * uv + noise
145
+ return sine_waves, uv, noise
146
+
147
+
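Illustrative usage (editor's sketch, not part of the uploaded file): driving SineGen with a toy constant F0 contour, matching the tensor shapes described in the docstring above.

# editor's sketch -- hypothetical smoke test for SineGen
sine_gen = SineGen(samp_rate=24000, harmonic_num=8)
f0 = torch.full((1, 24000, 1), 220.0)    # one second of a voiced 220 Hz contour
sine, uv, noise = sine_gen(f0)
print(sine.shape)                        # torch.Size([1, 24000, 9]): fundamental + 8 overtones
print(uv.unique())                       # tensor([1.]) -- every frame is above voiced_threshold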
148
+ class LowpassBlur(nn.Module):
149
+ """perform low pass filter after upsampling for anti-aliasing"""
150
+
151
+ def __init__(self, channels=128, filt_size=3, pad_type="reflect", pad_off=0):
152
+ super(LowpassBlur, self).__init__()
153
+ self.filt_size = filt_size
154
+ self.pad_off = pad_off
155
+ self.pad_sizes = [
156
+ int(1.0 * (filt_size - 1) / 2),
157
+ int(np.ceil(1.0 * (filt_size - 1) / 2)),
158
+ ]
159
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
160
+ self.off = 0
161
+ self.channels = channels
162
+
163
+ if self.filt_size == 1:
164
+ a = np.array(
165
+ [
166
+ 1.0,
167
+ ]
168
+ )
169
+ elif self.filt_size == 2:
170
+ a = np.array([1.0, 1.0])
171
+ elif self.filt_size == 3:
172
+ a = np.array([1.0, 2.0, 1.0])
173
+ elif self.filt_size == 4:
174
+ a = np.array([1.0, 3.0, 3.0, 1.0])
175
+ elif self.filt_size == 5:
176
+ a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
177
+ elif self.filt_size == 6:
178
+ a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
179
+ elif self.filt_size == 7:
180
+ a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
181
+
182
+ filt = torch.Tensor(a)
183
+ filt = filt / torch.sum(filt)
184
+ self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))
185
+
186
+ self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
187
+
188
+ def forward(self, inp):
189
+ if self.filt_size == 1:
190
+ return inp
191
+ return F.conv1d(self.pad(inp), self.filt, groups=inp.shape[1])
192
+
193
+
194
+ def get_pad_layer_1d(pad_type):
195
+ if pad_type in ["refl", "reflect"]:
196
+ PadLayer = nn.ReflectionPad1d
197
+ elif pad_type in ["repl", "replicate"]:
198
+ PadLayer = nn.ReplicationPad1d
199
+ elif pad_type == "zero":
200
+ PadLayer = nn.ZeroPad1d
201
+ else:
202
+ print("Pad type [%s] not recognized" % pad_type)
203
+ return PadLayer
204
+
205
+
206
+ class MovingAverageSmooth(torch.nn.Conv1d):
207
+ def __init__(self, channels, window_len=3):
208
+ """Initialize Conv1d module."""
209
+ super(MovingAverageSmooth, self).__init__(
210
+ in_channels=channels,
211
+ out_channels=channels,
212
+ kernel_size=1,
213
+ groups=channels,
214
+ bias=False,
215
+ )
216
+
217
+ torch.nn.init.constant_(self.weight, 1.0 / window_len)
218
+ for p in self.parameters():
219
+ p.requires_grad = False
220
+
221
+ def forward(self, data):
222
+ return super(MovingAverageSmooth, self).forward(data)
223
+
224
+
225
+ class Conv1d(torch.nn.Conv1d):
226
+ """Conv1d module with customized initialization."""
227
+
228
+ def __init__(self, *args, **kwargs):
229
+ """Initialize Conv1d module."""
230
+ super(Conv1d, self).__init__(*args, **kwargs)
231
+
232
+ def reset_parameters(self):
233
+ """Reset parameters."""
234
+ torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
235
+ if self.bias is not None:
236
+ torch.nn.init.constant_(self.bias, 0.0)
237
+
238
+
239
+ class Stretch2d(torch.nn.Module):
240
+ """Stretch2d module."""
241
+
242
+ def __init__(self, x_scale, y_scale, mode="nearest"):
243
+ """Initialize Stretch2d module.
244
+ Args:
245
+ x_scale (int): X scaling factor (Time axis in spectrogram).
246
+ y_scale (int): Y scaling factor (Frequency axis in spectrogram).
247
+ mode (str): Interpolation mode.
248
+ """
249
+ super(Stretch2d, self).__init__()
250
+ self.x_scale = x_scale
251
+ self.y_scale = y_scale
252
+ self.mode = mode
253
+
254
+ def forward(self, x):
255
+ """Calculate forward propagation.
256
+ Args:
257
+ x (Tensor): Input tensor (B, C, F, T).
258
+ Returns:
259
+ Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
260
+ """
261
+ return F.interpolate(
262
+ x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
263
+ )
264
+
265
+
266
+ class Conv2d(torch.nn.Conv2d):
267
+ """Conv2d module with customized initialization."""
268
+
269
+ def __init__(self, *args, **kwargs):
270
+ """Initialize Conv2d module."""
271
+ super(Conv2d, self).__init__(*args, **kwargs)
272
+
273
+ def reset_parameters(self):
274
+ """Reset parameters."""
275
+ self.weight.data.fill_(1.0 / np.prod(self.kernel_size))
276
+ if self.bias is not None:
277
+ torch.nn.init.constant_(self.bias, 0.0)
278
+
279
+
280
+ class UpsampleNetwork(torch.nn.Module):
281
+ """Upsampling network module."""
282
+
283
+ def __init__(
284
+ self,
285
+ upsample_scales,
286
+ nonlinear_activation=None,
287
+ nonlinear_activation_params={},
288
+ interpolate_mode="nearest",
289
+ freq_axis_kernel_size=1,
290
+ use_causal_conv=False,
291
+ ):
292
+ """Initialize upsampling network module.
293
+ Args:
294
+ upsample_scales (list): List of upsampling scales.
295
+ nonlinear_activation (str): Activation function name.
296
+ nonlinear_activation_params (dict): Arguments for specified activation function.
297
+ interpolate_mode (str): Interpolation mode.
298
+ freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
299
+ """
300
+ super(UpsampleNetwork, self).__init__()
301
+ self.use_causal_conv = use_causal_conv
302
+ self.up_layers = torch.nn.ModuleList()
303
+ for scale in upsample_scales:
304
+ # interpolation layer
305
+ stretch = Stretch2d(scale, 1, interpolate_mode)
306
+ self.up_layers += [stretch]
307
+
308
+ # conv layer
309
+ assert (
310
+ freq_axis_kernel_size - 1
311
+ ) % 2 == 0, "Not support even number freq axis kernel size."
312
+ freq_axis_padding = (freq_axis_kernel_size - 1) // 2
313
+ kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
314
+ if use_causal_conv:
315
+ padding = (freq_axis_padding, scale * 2)
316
+ else:
317
+ padding = (freq_axis_padding, scale)
318
+ conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
319
+ self.up_layers += [conv]
320
+
321
+ # nonlinear
322
+ if nonlinear_activation is not None:
323
+ nonlinear = getattr(torch.nn, nonlinear_activation)(
324
+ **nonlinear_activation_params
325
+ )
326
+ self.up_layers += [nonlinear]
327
+
328
+ def forward(self, c):
329
+ """Calculate forward propagation.
330
+ Args:
331
+ c : Input tensor (B, C, T).
332
+ Returns:
333
+ Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).
334
+ """
335
+ c = c.unsqueeze(1) # (B, 1, C, T)
336
+ for f in self.up_layers:
337
+ if self.use_causal_conv and isinstance(f, Conv2d):
338
+ c = f(c)[..., : c.size(-1)]
339
+ else:
340
+ c = f(c)
341
+ return c.squeeze(1) # (B, C, T')
342
+
343
+
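Worked example (editor's sketch, not part of the uploaded file): the time axis grows by the product of upsample_scales, so [3, 4, 5, 5] expands each conditioning frame into 300 waveform samples.

# editor's sketch -- checking the advertised T' = T * prod(upsample_scales)
upsampler = UpsampleNetwork(upsample_scales=[3, 4, 5, 5])   # 3 * 4 * 5 * 5 = 300
c = torch.randn(2, 80, 10)                                  # (B, C, T) conditioning frames
print(upsampler(c).shape)                                   # torch.Size([2, 80, 3000])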
344
+ class ConvInUpsampleNetwork(torch.nn.Module):
345
+ """Convolution + upsampling network module."""
346
+
347
+ def __init__(
348
+ self,
349
+ upsample_scales=[3, 4, 5, 5],
350
+ nonlinear_activation="ReLU",
351
+ nonlinear_activation_params={},
352
+ interpolate_mode="nearest",
353
+ freq_axis_kernel_size=1,
354
+ aux_channels=80,
355
+ aux_context_window=0,
356
+ use_causal_conv=False,
357
+ ):
358
+ """Initialize convolution + upsampling network module.
359
+ Args:
360
+ upsample_scales (list): List of upsampling scales.
361
+ nonlinear_activation (str): Activation function name.
362
+ nonlinear_activation_params (dict): Arguments for specified activation function.
363
+ interpolate_mode (str): Interpolation mode.
364
+ freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
365
+ aux_channels (int): Number of channels of pre-convolutional layer.
366
+ aux_context_window (int): Context window size of the pre-convolutional layer.
367
+ use_causal_conv (bool): Whether to use causal structure.
368
+ """
369
+ super(ConvInUpsampleNetwork, self).__init__()
370
+ self.aux_context_window = aux_context_window
371
+ self.use_causal_conv = use_causal_conv and aux_context_window > 0
372
+ # To capture wide-context information in conditional features
373
+ kernel_size = (
374
+ aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
375
+ )
376
+ # NOTE(kan-bayashi): Here do not use padding because the input is already padded
377
+ self.conv_in = Conv1d(
378
+ aux_channels, aux_channels, kernel_size=kernel_size, bias=False
379
+ )
380
+ self.upsample = UpsampleNetwork(
381
+ upsample_scales=upsample_scales,
382
+ nonlinear_activation=nonlinear_activation,
383
+ nonlinear_activation_params=nonlinear_activation_params,
384
+ interpolate_mode=interpolate_mode,
385
+ freq_axis_kernel_size=freq_axis_kernel_size,
386
+ use_causal_conv=use_causal_conv,
387
+ )
388
+
389
+ def forward(self, c):
390
+ """Calculate forward propagation.
391
+ Args:
392
+ c : Input tensor (B, C, T').
393
+ Returns:
394
+ Tensor: Upsampled tensor (B, C, T),
395
+ where T = (T' - aux_context_window * 2) * prod(upsample_scales).
396
+ Note:
397
+ The length of inputs considers the context window size.
398
+ """
399
+ c_ = self.conv_in(c)
400
+ c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
401
+ return self.upsample(c)
402
+
403
+
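Worked example (editor's sketch, not part of the uploaded file): because conv_in uses no padding, 2 * aux_context_window frames are consumed before upsampling, as the forward docstring states.

# editor's sketch -- T = (T' - 2 * aux_context_window) * prod(upsample_scales)
net = ConvInUpsampleNetwork(upsample_scales=[3, 4, 5, 5], aux_channels=80, aux_context_window=2)
c = torch.randn(1, 80, 14)          # assumed to be pre-padded with the context window
print(net(c).shape)                 # torch.Size([1, 80, 3000]) = (14 - 4) * 300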
404
+ class DownsampleNet(nn.Module):
405
+ def __init__(self, input_size, output_size, upsample_factor, hp=None, index=0):
406
+ super(DownsampleNet, self).__init__()
407
+ self.input_size = input_size
408
+ self.output_size = output_size
409
+ self.upsample_factor = upsample_factor
410
+ self.skip_conv = nn.Conv1d(input_size, output_size, kernel_size=1)
411
+ self.index = index
412
+ layer = nn.Conv1d(
413
+ input_size,
414
+ output_size,
415
+ kernel_size=upsample_factor * 2,
416
+ stride=upsample_factor,
417
+ padding=upsample_factor // 2 + upsample_factor % 2,
418
+ )
419
+
420
+ self.layer = nn.utils.weight_norm(layer)
421
+
422
+ def forward(self, inputs):
423
+ B, C, T = inputs.size()
424
+ res = inputs[:, :, :: self.upsample_factor]
425
+ skip = self.skip_conv(res)
426
+
427
+ outputs = self.layer(inputs)
428
+ outputs = outputs + skip
429
+
430
+ return outputs
431
+
432
+
433
+ class UpsampleNet(nn.Module):
434
+ def __init__(self, input_size, output_size, upsample_factor, hp=None, index=0):
435
+
436
+ super(UpsampleNet, self).__init__()
437
+ self.up_type = Config.up_type
438
+ self.use_smooth = Config.use_smooth
439
+ self.use_drop = Config.use_drop
440
+ self.input_size = input_size
441
+ self.output_size = output_size
442
+ self.upsample_factor = upsample_factor
443
+ self.skip_conv = nn.Conv1d(input_size, output_size, kernel_size=1)
444
+ self.index = index
445
+ if self.use_smooth:
446
+ window_lens = [5, 5, 4, 3]
447
+ self.window_len = window_lens[index]
448
+
449
+ if self.up_type != "pn" or self.index < 3:
450
+ # if self.up_type != "pn":
451
+ layer = nn.ConvTranspose1d(
452
+ input_size,
453
+ output_size,
454
+ upsample_factor * 2,
455
+ upsample_factor,
456
+ padding=upsample_factor // 2 + upsample_factor % 2,
457
+ output_padding=upsample_factor % 2,
458
+ )
459
+ self.layer = nn.utils.weight_norm(layer)
460
+ else:
461
+ self.layer = nn.Sequential(
462
+ nn.ReflectionPad1d(1),
463
+ nn.utils.weight_norm(
464
+ nn.Conv1d(input_size, output_size * upsample_factor, kernel_size=3)
465
+ ),
466
+ nn.LeakyReLU(),
467
+ nn.ReflectionPad1d(1),
468
+ nn.utils.weight_norm(
469
+ nn.Conv1d(
470
+ output_size * upsample_factor,
471
+ output_size * upsample_factor,
472
+ kernel_size=3,
473
+ )
474
+ ),
475
+ nn.LeakyReLU(),
476
+ nn.ReflectionPad1d(1),
477
+ nn.utils.weight_norm(
478
+ nn.Conv1d(
479
+ output_size * upsample_factor,
480
+ output_size * upsample_factor,
481
+ kernel_size=3,
482
+ )
483
+ ),
484
+ nn.LeakyReLU(),
485
+ )
486
+
487
+ if hp is not None:
488
+ self.org = Config.up_org
489
+ self.no_skip = Config.no_skip
490
+ else:
491
+ self.org = False
492
+ self.no_skip = True
493
+
494
+ if self.use_smooth:
495
+ self.mas = nn.Sequential(
496
+ # LowpassBlur(output_size, self.window_len),
497
+ MovingAverageSmooth(output_size, self.window_len),
498
+ # MovingAverageSmooth(output_size, self.window_len),
499
+ )
500
+
501
+ def forward(self, inputs):
502
+
503
+ if not self.org:
504
+ inputs = inputs + torch.sin(inputs)
505
+ B, C, T = inputs.size()
506
+ res = inputs.repeat(1, self.upsample_factor, 1).view(B, C, -1)
507
+ skip = self.skip_conv(res)
508
+ if self.up_type == "repeat":
509
+ return skip
510
+
511
+ outputs = self.layer(inputs)
512
+ if self.up_type == "pn" and self.index > 2:
513
+ B, c, l = outputs.size()
514
+ outputs = outputs.view(B, -1, l * self.upsample_factor)
515
+
516
+ if self.no_skip:
517
+ return outputs
518
+
519
+ if not self.org:
520
+ outputs = outputs + skip
521
+
522
+ if self.use_smooth:
523
+ outputs = self.mas(outputs)
524
+
525
+ if self.use_drop:
526
+ outputs = F.dropout(outputs, p=0.05)
527
+
528
+ return outputs
529
+
530
+
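Editor's note (sketch, not part of the uploaded file): when up_type == "pn" and index > 2, UpsampleNet acts as a 1-D sub-pixel (pixel-shuffle style) upsampler -- a convolution produces output_size * upsample_factor channels, and forward reshapes those extra channels into extra time steps.

# editor's sketch of that reshape in isolation
B, C, L, r = 1, 2, 3, 4
x = torch.arange(B * C * r * L).float().view(B, C * r, L)   # conv output: (B, C*r, L)
y = x.view(B, C, L * r)                                      # waveform-rate features: (B, C, L*r)
print(y.shape)                                               # torch.Size([1, 2, 12])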
531
+ class ResStack(nn.Module):
532
+ def __init__(self, channel, kernel_size=3, resstack_depth=4, hp=None):
533
+ super(ResStack, self).__init__()
534
+
535
+ self.use_wn = Config.use_wn
536
+ self.use_shift_scale = Config.use_shift_scale
537
+ self.channel = channel
538
+
539
+ def get_padding(kernel_size, dilation=1):
540
+ return int((kernel_size * dilation - dilation) / 2)
541
+
542
+ if self.use_shift_scale:
543
+ self.scale_conv = nn.utils.weight_norm(
544
+ nn.Conv1d(
545
+ channel, 2 * channel, kernel_size=kernel_size, dilation=1, padding=1
546
+ )
547
+ )
548
+
549
+ if not self.use_wn:
550
+ self.layers = nn.ModuleList(
551
+ [
552
+ nn.Sequential(
553
+ nn.LeakyReLU(),
554
+ nn.utils.weight_norm(
555
+ nn.Conv1d(
556
+ channel,
557
+ channel,
558
+ kernel_size=kernel_size,
559
+ dilation=3 ** (i % 10),
560
+ padding=get_padding(kernel_size, 3 ** (i % 10)),
561
+ )
562
+ ),
563
+ nn.LeakyReLU(),
564
+ nn.utils.weight_norm(
565
+ nn.Conv1d(
566
+ channel,
567
+ channel,
568
+ kernel_size=kernel_size,
569
+ dilation=1,
570
+ padding=get_padding(kernel_size, 1),
571
+ )
572
+ ),
573
+ )
574
+ for i in range(resstack_depth)
575
+ ]
576
+ )
577
+ else:
578
+ self.wn = WaveNet(
579
+ in_channels=channel,
580
+ out_channels=channel,
581
+ cin_channels=-1,
582
+ num_layers=resstack_depth,
583
+ residual_channels=channel,
584
+ gate_channels=channel,
585
+ skip_channels=channel,
586
+ # kernel_size=5,
587
+ # dilation_rate=3,
588
+ causal=False,
589
+ use_downup=False,
590
+ )
591
+
592
+ def forward(self, x):
593
+ if not self.use_wn:
594
+ for layer in self.layers:
595
+ x = x + layer(x)
596
+ else:
597
+ x = self.wn(x)
598
+
599
+ if self.use_shift_scale:
600
+ m_s = self.scale_conv(x)
601
+ m_s = m_s[:, :, :-1]
602
+
603
+ m, s = torch.split(m_s, self.channel, dim=1)
604
+ s = F.softplus(s)
605
+
606
+ x = m + s * x[:, :, 1:]  # shift-and-scale the time-shifted activations
607
+ x = F.pad(x, pad=(1, 0), mode="constant", value=0)
608
+
609
+ return x
610
+
611
+
612
+ class WaveNet(nn.Module):
613
+ def __init__(
614
+ self,
615
+ in_channels=1,
616
+ out_channels=1,
617
+ num_layers=10,
618
+ residual_channels=64,
619
+ gate_channels=64,
620
+ skip_channels=64,
621
+ kernel_size=3,
622
+ dilation_rate=2,
623
+ cin_channels=80,
624
+ hp=None,
625
+ causal=False,
626
+ use_downup=False,
627
+ ):
628
+ super(WaveNet, self).__init__()
629
+
630
+ self.in_channels = in_channels
631
+ self.causal = causal
632
+ self.num_layers = num_layers
633
+ self.out_channels = out_channels
634
+ self.gate_channels = gate_channels
635
+ self.residual_channels = residual_channels
636
+ self.skip_channels = skip_channels
637
+ self.cin_channels = cin_channels
638
+ self.kernel_size = kernel_size
639
+ self.use_downup = use_downup
640
+
641
+ self.front_conv = nn.Sequential(
642
+ nn.Conv1d(
643
+ in_channels=self.in_channels,
644
+ out_channels=self.residual_channels,
645
+ kernel_size=3,
646
+ padding=1,
647
+ ),
648
+ nn.ReLU(),
649
+ )
650
+ if self.use_downup:
651
+ self.downup_conv = nn.Sequential(
652
+ nn.Conv1d(
653
+ in_channels=self.residual_channels,
654
+ out_channels=self.residual_channels,
655
+ kernel_size=3,
656
+ stride=2,
657
+ padding=1,
658
+ ),
659
+ nn.ReLU(),
660
+ nn.Conv1d(
661
+ in_channels=self.residual_channels,
662
+ out_channels=self.residual_channels,
663
+ kernel_size=3,
664
+ stride=2,
665
+ padding=1,
666
+ ),
667
+ nn.ReLU(),
668
+ UpsampleNet(self.residual_channels, self.residual_channels, 4, hp),
669
+ )
670
+
671
+ self.res_blocks = nn.ModuleList()
672
+ for n in range(self.num_layers):
673
+ self.res_blocks.append(
674
+ ResBlock(
675
+ self.residual_channels,
676
+ self.gate_channels,
677
+ self.skip_channels,
678
+ self.kernel_size,
679
+ dilation=dilation_rate**n,
680
+ cin_channels=self.cin_channels,
681
+ local_conditioning=(self.cin_channels > 0),
682
+ causal=self.causal,
683
+ mode="SAME",
684
+ )
685
+ )
686
+ self.final_conv = nn.Sequential(
687
+ nn.ReLU(),
688
+ Conv(self.skip_channels, self.skip_channels, 1, causal=self.causal),
689
+ nn.ReLU(),
690
+ Conv(self.skip_channels, self.out_channels, 1, causal=self.causal),
691
+ )
692
+
693
+ def forward(self, x, c=None):
694
+ return self.wavenet(x, c)
695
+
696
+ def wavenet(self, tensor, c=None):
697
+
698
+ h = self.front_conv(tensor)
699
+ if self.use_downup:
700
+ h = self.downup_conv(h)
701
+ skip = 0
702
+ for i, f in enumerate(self.res_blocks):
703
+ h, s = f(h, c)
704
+ skip += s
705
+ out = self.final_conv(skip)
706
+ return out
707
+
708
+ def receptive_field_size(self):
709
+ num_dir = 1 if self.causal else 2
710
+ dilations = [2 ** (i % self.num_layers) for i in range(self.num_layers)]
711
+ return (
712
+ num_dir * (self.kernel_size - 1) * sum(dilations)
713
+ + 1
714
+ + (3 - 1)  # contribution of front_conv (kernel_size=3)
715
+ )
716
+
717
+ def remove_weight_norm(self):
718
+ for f in self.res_blocks:
719
+ f.remove_weight_norm()
720
+
721
+
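Worked example (editor's sketch, not part of the uploaded file): receptive field of the default non-causal stack, assuming dilation_rate=2 as receptive_field_size() implies. With num_layers=10 and kernel_size=3 the dilations are [1, 2, ..., 512], summing to 1023, so the field is 2 * (3 - 1) * 1023 + 1 + 2 = 4095 samples.

# editor's sketch -- hypothetical check of the arithmetic above
wn = WaveNet(in_channels=64, out_channels=64, cin_channels=-1, causal=False)
print(wn.receptive_field_size())   # 4095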
722
+ class Conv(nn.Module):
723
+ def __init__(
724
+ self,
725
+ in_channels,
726
+ out_channels,
727
+ kernel_size,
728
+ dilation=1,
729
+ causal=False,
730
+ mode="SAME",
731
+ ):
732
+ super(Conv, self).__init__()
733
+
734
+ self.causal = causal
735
+ self.mode = mode
736
+ if self.causal and self.mode == "SAME":
737
+ self.padding = dilation * (kernel_size - 1)
738
+ elif self.mode == "SAME":
739
+ self.padding = dilation * (kernel_size - 1) // 2
740
+ else:
741
+ self.padding = 0
742
+ self.conv = nn.Conv1d(
743
+ in_channels,
744
+ out_channels,
745
+ kernel_size,
746
+ dilation=dilation,
747
+ padding=self.padding,
748
+ )
749
+ self.conv = nn.utils.weight_norm(self.conv)
750
+ nn.init.kaiming_normal_(self.conv.weight)
751
+
752
+ def forward(self, tensor):
753
+ out = self.conv(tensor)
754
+ if self.causal and self.padding != 0:
755
+ out = out[:, :, : -self.padding]
756
+ return out
757
+
758
+ def remove_weight_norm(self):
759
+ nn.utils.remove_weight_norm(self.conv)
760
+
761
+
762
+ class ResBlock(nn.Module):
763
+ def __init__(
764
+ self,
765
+ in_channels,
766
+ out_channels,
767
+ skip_channels,
768
+ kernel_size,
769
+ dilation,
770
+ cin_channels=None,
771
+ local_conditioning=True,
772
+ causal=False,
773
+ mode="SAME",
774
+ ):
775
+ super(ResBlock, self).__init__()
776
+ self.causal = causal
777
+ self.local_conditioning = local_conditioning
778
+ self.cin_channels = cin_channels
779
+ self.mode = mode
780
+
781
+ self.filter_conv = Conv(
782
+ in_channels, out_channels, kernel_size, dilation, causal, mode
783
+ )
784
+ self.gate_conv = Conv(
785
+ in_channels, out_channels, kernel_size, dilation, causal, mode
786
+ )
787
+ self.res_conv = nn.Conv1d(out_channels, in_channels, kernel_size=1)
788
+ self.skip_conv = nn.Conv1d(out_channels, skip_channels, kernel_size=1)
789
+ self.res_conv = nn.utils.weight_norm(self.res_conv)
790
+ self.skip_conv = nn.utils.weight_norm(self.skip_conv)
791
+
792
+ if self.local_conditioning:
793
+ self.filter_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
794
+ self.gate_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
795
+ self.filter_conv_c = nn.utils.weight_norm(self.filter_conv_c)
796
+ self.gate_conv_c = nn.utils.weight_norm(self.gate_conv_c)
797
+
798
+ def forward(self, tensor, c=None):
799
+ h_filter = self.filter_conv(tensor)
800
+ h_gate = self.gate_conv(tensor)
801
+
802
+ if self.local_conditioning:
803
+ h_filter += self.filter_conv_c(c)
804
+ h_gate += self.gate_conv_c(c)
805
+
806
+ out = torch.tanh(h_filter) * torch.sigmoid(h_gate)
807
+
808
+ res = self.res_conv(out)
809
+ skip = self.skip_conv(out)
810
+ if self.mode == "SAME":
811
+ return (tensor + res) * math.sqrt(0.5), skip
812
+ else:
813
+ return (tensor[:, :, 1:] + res) * math.sqrt(0.5), skip
814
+
815
+ def remove_weight_norm(self):
816
+ self.filter_conv.remove_weight_norm()
817
+ self.gate_conv.remove_weight_norm()
818
+ nn.utils.remove_weight_norm(self.res_conv)
819
+ nn.utils.remove_weight_norm(self.skip_conv)
820
+ nn.utils.remove_weight_norm(self.filter_conv_c)
821
+ nn.utils.remove_weight_norm(self.gate_conv_c)
822
+
823
+
824
+ @torch.jit.script
825
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
826
+ n_channels_int = n_channels[0]
827
+ in_act = input_a + input_b
828
+ t_act = torch.tanh(in_act[:, :n_channels_int])
829
+ s_act = torch.sigmoid(in_act[:, n_channels_int:])
830
+ acts = t_act * s_act
831
+ return acts
832
+
833
+
834
+ @torch.jit.script
835
+ def fused_res_skip(tensor, res_skip, n_channels):
836
+ n_channels_int = n_channels[0]
837
+ res = res_skip[:, :n_channels_int]
838
+ skip = res_skip[:, n_channels_int:]
839
+ return (tensor + res), skip
840
+
841
+
842
+ class ResStack2D(nn.Module):
843
+ def __init__(self, channels=16, kernel_size=3, resstack_depth=4, hp=None):
844
+ super(ResStack2D, self).__init__()
845
+ channels = 16  # note: the constructor arguments above are overridden with fixed values
846
+ kernel_size = 3
847
+ resstack_depth = 2
848
+ self.channels = channels
849
+
850
+ def get_padding(kernel_size, dilation=1):
851
+ return int((kernel_size * dilation - dilation) / 2)
852
+
853
+ self.layers = nn.ModuleList(
854
+ [
855
+ nn.Sequential(
856
+ nn.LeakyReLU(),
857
+ nn.utils.weight_norm(
858
+ nn.Conv2d(
859
+ 1,
860
+ self.channels,
861
+ kernel_size,
862
+ dilation=(1, 3 ** (i)),
863
+ padding=(1, get_padding(kernel_size, 3 ** (i))),
864
+ )
865
+ ),
866
+ nn.LeakyReLU(),
867
+ nn.utils.weight_norm(
868
+ nn.Conv2d(
869
+ self.channels,
870
+ self.channels,
871
+ kernel_size,
872
+ dilation=(1, 3 ** (i)),
873
+ padding=(1, get_padding(kernel_size, 3 ** (i))),
874
+ )
875
+ ),
876
+ nn.LeakyReLU(),
877
+ nn.utils.weight_norm(nn.Conv2d(self.channels, 1, kernel_size=1)),
878
+ )
879
+ for i in range(resstack_depth)
880
+ ]
881
+ )
882
+
883
+ def forward(self, tensor):
884
+ x = tensor.unsqueeze(1)
885
+ for layer in self.layers:
886
+ x = x + layer(x)
887
+ x = x.squeeze(1)
888
+
889
+ return x
890
+
891
+
892
+ class FiLM(nn.Module):
893
+ """
894
+ Feature-wise linear modulation (FiLM).
895
+ """
896
+
897
+ def __init__(self, input_dim, attribute_dim):
898
+ super().__init__()
899
+ self.input_dim = input_dim
900
+ self.generator = nn.Conv1d(
901
+ attribute_dim, input_dim * 2, kernel_size=3, padding=1
902
+ )
903
+
904
+ def forward(self, x, c):
905
+ """
906
+ x: (B, input_dim, seq)
907
+ c: (B, attribute_dim, seq)
908
+ """
909
+ c = self.generator(c)
910
+ m, s = torch.split(c, self.input_dim, dim=1)
911
+
912
+ return x * s + m
913
+
914
+
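Illustrative usage (editor's sketch, not part of the uploaded file): FiLM predicts a per-channel scale s and shift m from the conditioning stream and returns x * s + m, preserving the input shape.

# editor's sketch -- hypothetical shapes
film = FiLM(input_dim=32, attribute_dim=16)
x = torch.randn(4, 32, 100)    # (B, input_dim, seq)
c = torch.randn(4, 16, 100)    # (B, attribute_dim, seq)
print(film(x, c).shape)        # torch.Size([4, 32, 100])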
915
+ class FiLMConv1d(nn.Module):
916
+ """
917
+ Conv1d with FiLMs in between
918
+ """
919
+
920
+ def __init__(self, in_size, out_size, attribute_dim, ins_norm=True, loop=1):
921
+ super().__init__()
922
+ self.loop = loop
923
+ self.mlps = nn.ModuleList(
924
+ [nn.Conv1d(in_size, out_size, kernel_size=3, padding=1)]
925
+ + [
926
+ nn.Conv1d(out_size, out_size, kernel_size=3, padding=1)
927
+ for i in range(loop - 1)
928
+ ]
929
+ )
930
+ self.films = nn.ModuleList([FiLM(out_size, attribute_dim) for i in range(loop)])
931
+ self.ins_norm = ins_norm
932
+ if self.ins_norm:
933
+ self.norm = nn.InstanceNorm1d(attribute_dim)
934
+
935
+ def forward(self, x, c):
936
+ """
937
+ x: (B, input_dim, seq)
938
+ c: (B, attribute_dim, seq)
939
+ """
940
+ if self.ins_norm:
941
+ c = self.norm(c)
942
+ for i in range(self.loop):
943
+ x = self.mlps[i](x)
944
+ x = F.relu(x)
945
+ x = self.films[i](x, c)
946
+
947
+ return x
voicefixer/vocoder/model/pqmf.py ADDED
@@ -0,0 +1,61 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import scipy.io.wavfile
7
+
8
+
9
+ class PQMF(nn.Module):
10
+ def __init__(self, N, M, file_path="utils/pqmf_hk_4_64.dat"):
11
+ super().__init__()
12
+ self.N = N # nsubband
13
+ self.M = M # nfilter
14
+ self.ana_conv_filter = nn.Conv1d(
15
+ 1, out_channels=N, kernel_size=M, stride=N, bias=False
16
+ )
17
+ data = np.reshape(np.fromfile(file_path, dtype=np.float32), (N, M))
18
+ data = np.flipud(data.T).T
19
+ gk = data.copy()
20
+ data = np.reshape(data, (N, 1, M)).copy()
21
+ dict_new = self.ana_conv_filter.state_dict().copy()
22
+ dict_new["weight"] = torch.from_numpy(data)
23
+ self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
24
+ self.ana_conv_filter.load_state_dict(dict_new)
25
+
26
+ self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
27
+ self.syn_conv_filter = nn.Conv1d(
28
+ N, out_channels=N, kernel_size=M // N, stride=1, bias=False
29
+ )
30
+ gk = np.transpose(np.reshape(gk, (4, 16, 4)), (1, 0, 2)) * N  # hard-coded for N=4, M=64
31
+ gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
32
+ dict_new = self.syn_conv_filter.state_dict().copy()
33
+ dict_new["weight"] = torch.from_numpy(gk)
34
+ self.syn_conv_filter.load_state_dict(dict_new)
35
+
36
+ for param in self.parameters():
37
+ param.requires_grad = False
38
+
39
+ def analysis(self, inputs):
40
+ return self.ana_conv_filter(self.ana_pad(inputs))
41
+
42
+ def synthesis(self, inputs):
43
+ return self.syn_conv_filter(self.syn_pad(inputs))
44
+
45
+ def forward(self, inputs):
46
+ return self.ana_conv_filter(self.ana_pad(inputs))
47
+
48
+
49
+ if __name__ == "__main__":
50
+ a = PQMF(4, 64)
51
+ # x = np.load('data/train/audio/010000.npy')
52
+ x = np.zeros([8, 24000], np.float32)
53
+ x = np.reshape(x, (8, 1, -1))
54
+ x = torch.from_numpy(x)
55
+ b = a.analysis(x)
56
+ c = a.synthesis(b)
57
+ print(x.shape, b.shape, c.shape)
58
+ b = (b * 32768).numpy()
59
+ b = np.reshape(np.transpose(b, (0, 2, 1)), (-1, 1)).astype(np.int16)
60
+ # b.tofile('1.pcm')
61
+ # np.reshape(np.transpose(c.numpy()*32768, (0, 2, 1)), (-1,1)).astype(np.int16).tofile('2.pcm')
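Editor's note (sketch, not part of the uploaded file): with N=4 subbands and M=64 taps, analysis maps (B, 1, T) to (B, 4, T/4); synthesis stays at the subband rate, and it is the Generator's reshape that interleaves the four bands back into a (B, 1, T) waveform.

# editor's sketch -- shape round trip (requires utils/pqmf_hk_4_64.dat)
pqmf = PQMF(4, 64)
x = torch.zeros(8, 1, 24000)
sub = pqmf.analysis(x)
print(sub.shape)                     # torch.Size([8, 4, 6000])
print(pqmf.synthesis(sub).shape)     # torch.Size([8, 4, 6000])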
voicefixer/vocoder/model/res_msd.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+
7
+ LRELU_SLOPE = 0.1
8
+
9
+
10
+ def init_weights(m, mean=0.0, std=0.01):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Conv") != -1:
13
+ m.weight.data.normal_(mean, std)
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ class ResStack(nn.Module):
21
+ def __init__(self, channels=384, kernel_size=3, resstack_depth=3, hp=None):
22
+ super(ResStack, self).__init__()
23
+ dilation = [2 * i + 1 for i in range(resstack_depth)] # [1, 3, 5]
24
+ self.convs1 = nn.ModuleList(
25
+ [
26
+ weight_norm(
27
+ Conv1d(
28
+ channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[i],
33
+ padding=get_padding(kernel_size, dilation[i]),
34
+ )
35
+ )
36
+ for i in range(resstack_depth)
37
+ ]
38
+ )
39
+ self.convs1.apply(init_weights)
40
+
41
+ self.convs2 = nn.ModuleList(
42
+ [
43
+ weight_norm(
44
+ Conv1d(
45
+ channels,
46
+ channels,
47
+ kernel_size,
48
+ 1,
49
+ dilation=1,
50
+ padding=get_padding(kernel_size, 1),
51
+ )
52
+ )
53
+ for i in range(resstack_depth)
54
+ ]
55
+ )
56
+ self.convs2.apply(init_weights)
57
+
58
+ def forward(self, x):
59
+ for c1, c2 in zip(self.convs1, self.convs2):
60
+ xt = F.leaky_relu(x, LRELU_SLOPE)
61
+ xt = c1(xt)
62
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
63
+ xt = c2(xt)
64
+ x = xt + x
65
+ return x
66
+
67
+ def remove_weight_norm(self):
68
+ for l in self.convs1:
69
+ remove_weight_norm(l)
70
+ for l in self.convs2:
71
+ remove_weight_norm(l)
voicefixer/vocoder/model/util.py ADDED
@@ -0,0 +1,135 @@
1
+ from voicefixer.vocoder.config import Config
2
+ from voicefixer.tools.pytorch_util import try_tensor_cuda, check_cuda_availability
3
+ import torch
4
+ import librosa
5
+ import numpy as np
6
+
7
+
8
+ def tr_normalize(S):
9
+ if Config.allow_clipping_in_normalization:
10
+ if Config.symmetric_mels:
11
+ return torch.clip(
12
+ (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db))
13
+ - Config.max_abs_value,
14
+ -Config.max_abs_value,
15
+ Config.max_abs_value,
16
+ )
17
+ else:
18
+ return torch.clip(
19
+ Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)),
20
+ 0,
21
+ Config.max_abs_value,
22
+ )
23
+
24
+ assert S.max() <= 0 and S.min() - Config.min_db >= 0
25
+ if Config.symmetric_mels:
26
+ return (2 * Config.max_abs_value) * (
27
+ (S - Config.min_db) / (-Config.min_db)
28
+ ) - Config.max_abs_value
29
+ else:
30
+ return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db))
31
+
32
+
33
+ def tr_amp_to_db(x):
34
+ min_level = torch.exp(Config.min_level_db / 20 * torch.log(torch.tensor(10.0)))
35
+ min_level = min_level.type_as(x)
36
+ return 20 * torch.log10(torch.maximum(min_level, x))
37
+
38
+
39
+ def normalize(S):
40
+ if Config.allow_clipping_in_normalization:
41
+ if Config.symmetric_mels:
42
+ return np.clip(
43
+ (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db))
44
+ - Config.max_abs_value,
45
+ -Config.max_abs_value,
46
+ Config.max_abs_value,
47
+ )
48
+ else:
49
+ return np.clip(
50
+ Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)),
51
+ 0,
52
+ Config.max_abs_value,
53
+ )
54
+
55
+ assert S.max() <= 0 and S.min() - Config.min_db >= 0
56
+ if Config.symmetric_mels:
57
+ return (2 * Config.max_abs_value) * (
58
+ (S - Config.min_db) / (-Config.min_db)
59
+ ) - Config.max_abs_value
60
+ else:
61
+ return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db))
62
+
63
+
64
+ def amp_to_db(x):
65
+ min_level = np.exp(Config.min_level_db / 20 * np.log(10))
66
+ return 20 * np.log10(np.maximum(min_level, x))
67
+
68
+
69
+ def tr_pre(npy):
70
+ # conditions = torch.FloatTensor(npy).type_as(npy) # to(device)
71
+ conditions = npy.transpose(1, 2)
72
+ l = conditions.size(-1)
73
+ pad_tail = l % 2 + 4
74
+ zeros = (
75
+ torch.zeros([conditions.size()[0], Config.num_mels, pad_tail]).type_as(
76
+ conditions
77
+ )
78
+ + -4.0
79
+ )
80
+ return torch.cat([conditions, zeros], dim=-1)
81
+
82
+
83
+ def pre(npy):
84
+ conditions = npy
85
+ ## padding tail
86
+ if type(conditions) == np.ndarray:
87
+ conditions = torch.FloatTensor(conditions).unsqueeze(0)
88
+ else:
89
+ conditions = torch.FloatTensor(conditions.float()).unsqueeze(0)
90
+ conditions = conditions.transpose(1, 2)
91
+ l = conditions.size(-1)
92
+ pad_tail = l % 2 + 4
93
+ zeros = torch.zeros([1, Config.num_mels, pad_tail]) + -4.0
94
+ return torch.cat([conditions, zeros], dim=-1)
95
+
96
+
97
+ def load_try(state, model):
98
+ model_dict = model.state_dict()
99
+ try:
100
+ model_dict.update(state)
101
+ model.load_state_dict(model_dict)
102
+ except RuntimeError as e:
103
+ print(str(e))
104
+ model_dict = model.state_dict()
105
+ for k, v in state.items():
106
+ model_dict[k] = v
107
+ model.load_state_dict(model_dict)
108
+
109
+
110
+ def load_checkpoint(checkpoint_path, device):
111
+ checkpoint = torch.load(checkpoint_path, map_location=device)
112
+ return checkpoint
113
+
114
+
115
+ def build_mel_basis():
116
+ return librosa.filters.mel(
117
+ Config.sample_rate,
118
+ Config.n_fft,
119
+ htk=True,
120
+ n_mels=Config.num_mels,
121
+ fmin=0,
122
+ fmax=int(Config.sample_rate // 2),
123
+ )
124
+
125
+
126
+ def linear_to_mel(spectrogram):
127
+ _mel_basis = build_mel_basis()
128
+ return np.dot(_mel_basis, spectrogram)
129
+
130
+
131
+ if __name__ == "__main__":
132
+ data = torch.randn((3, 5, 100))
133
+ b = normalize(amp_to_db(data.numpy()))
134
+ a = tr_normalize(tr_amp_to_db(data)).numpy()
135
+ print(a - b)
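Editor's note (sketch, not part of the uploaded file): the printout above is a manual parity check between the torch and numpy normalization paths; they are expected to agree up to float32 rounding, which could be asserted explicitly as below.

# editor's sketch -- hypothetical explicit parity check
assert np.allclose(a, b, atol=1e-4)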