Spaces: Runtime error
Commit 5ec3488
Parent(s): 446d1e4
Upload 36 files
- main.py +37 -0
- packages.txt +2 -0
- requirements.txt +11 -0
- voicefixer/__init__.py +14 -0
- voicefixer/__main__.py +170 -0
- voicefixer/base.py +145 -0
- voicefixer/restorer/__init__.py +44 -0
- voicefixer/restorer/model.py +680 -0
- voicefixer/restorer/model_kqq_bn.py +186 -0
- voicefixer/restorer/modules.py +217 -0
- voicefixer/tools/__init__.py +11 -0
- voicefixer/tools/base.py +244 -0
- voicefixer/tools/io.py +44 -0
- voicefixer/tools/mel_scale.py +238 -0
- voicefixer/tools/modules/__init__.py +11 -0
- voicefixer/tools/modules/fDomainHelper.py +234 -0
- voicefixer/tools/modules/filters/f_2_64.mat +0 -0
- voicefixer/tools/modules/filters/f_4_64.mat +0 -0
- voicefixer/tools/modules/filters/f_8_64.mat +0 -0
- voicefixer/tools/modules/filters/h_2_64.mat +0 -0
- voicefixer/tools/modules/filters/h_4_64.mat +0 -0
- voicefixer/tools/modules/filters/h_8_64.mat +0 -0
- voicefixer/tools/modules/pqmf.py +116 -0
- voicefixer/tools/path.py +13 -0
- voicefixer/tools/pytorch_util.py +180 -0
- voicefixer/tools/random_.py +52 -0
- voicefixer/tools/wav.py +242 -0
- voicefixer/vocoder/__init__.py +30 -0
- voicefixer/vocoder/base.py +86 -0
- voicefixer/vocoder/config.py +316 -0
- voicefixer/vocoder/model/__init__.py +11 -0
- voicefixer/vocoder/model/generator.py +168 -0
- voicefixer/vocoder/model/modules.py +947 -0
- voicefixer/vocoder/model/pqmf.py +61 -0
- voicefixer/vocoder/model/res_msd.py +71 -0
- voicefixer/vocoder/model/util.py +135 -0
main.py
ADDED
@@ -0,0 +1,37 @@
+from voicefixer.base import VoiceFixer
+import streamlit as st
+from audio_recorder_streamlit import audio_recorder
+from io import BytesIO
+import soundfile as sf
+
+st.set_page_config(page_title="VoiceFixer app", page_icon=":notes:")
+st.title("Voice Fixer App :notes:")
+st.write(
+    """
+    This app combines the [VoiceFixer model](https://github.com/haoheliu/voicefixer) with a custom
+    Streamlit component that [records audio](https://github.com/Joooohan/audio-recorder-streamlit) online.
+    Currently the app shows great results when removing background noise, but
+    speech improvements aren't as obvious.
+    """)
+# Config files are loaded in voicefixer/base and voicefixer/vocoder/config.
+# They were uploaded to Hugging Face.
+voicefixer = VoiceFixer()
+audio_bytes = audio_recorder(
+    pause_threshold=1.5
+)
+try:
+    data, samplerate = sf.read(BytesIO(audio_bytes))
+    print(samplerate)
+    sf.write("original.wav", data, samplerate)
+    st.audio(audio_bytes, format="audio/wav")
+    if data.shape[0] >= 10000:
+        voicefixer.restore(input="original.wav",  # low quality .wav/.flac file
+                           output="enhanced_output.wav",
+                           cuda=False,  # GPU acceleration
+                           mode=0)
+        st.write("The audio without background noise and a little enhancement :ocean:")
+        st.audio("enhanced_output.wav")
+
+    else: st.warning("Recorded audio is too short, try again :relieved:")
+except Exception:
+    st.info("Try to record some audio :relieved:")
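
For reference, a minimal sketch of the same restoration flow outside Streamlit, assuming the checkpoint download in voicefixer/base.py succeeds. The mode values mirror the choices exposed by voicefixer/__main__.py below: modes 0 and 1 run the model in eval mode (mode 1 additionally removes the highest-frequency bins first), while mode 2 keeps train-mode statistics, which base.py notes is more effective on seriously damaged speech.

    from voicefixer.base import VoiceFixer

    vf = VoiceFixer()  # fetches vf.ckpt from the Hugging Face Hub on first use
    vf.restore(
        input="original.wav",          # low-quality .wav input
        output="enhanced_output.wav",
        cuda=False,                    # set True only when a GPU is available
        mode=0,
    )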
packages.txt
ADDED
@@ -0,0 +1,2 @@
+ffmpeg
+libsndfile1
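
These apt packages are installed by the Space at build time: libsndfile1 is the system library behind the soundfile Python package, and ffmpeg lets librosa (via audioread) decode compressed audio formats.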
requirements.txt
ADDED
@@ -0,0 +1,11 @@
+audio_recorder_streamlit>=0.0.7
+soundfile>=0.9.0
+huggingface-hub>=0.11.1
+librosa>=0.8.1,<0.9.0
+torch>=1.7.0
+matplotlib
+progressbar
+torchlibrosa==0.0.7
+GitPython
+streamlit>=1.12.0
+pyyaml
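
The librosa pin below 0.9.0 is presumably there to keep the 0.8-series API this codebase was written against; librosa 0.9 made a number of arguments keyword-only and changed several defaults.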
voicefixer/__init__.py
ADDED
@@ -0,0 +1,14 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+"""
+@File    : __init__.py
+@Contact : [email protected]
+@License : (C)Copyright 2020-2100
+
+@Modify Time      @Author    @Version    @Description
+------------      -------    --------    -----------
+9/14/21 12:31 AM  Haohe Liu  1.0         None
+"""
+
+from voicefixer.vocoder.base import Vocoder
+from voicefixer.base import VoiceFixer
voicefixer/__main__.py
ADDED
@@ -0,0 +1,170 @@
+#!/usr/bin/python3
+from genericpath import exists
+import os.path
+import argparse
+from voicefixer import VoiceFixer
+import torch
+import os
+
+
+def writefile(infile, outfile, mode, append_mode, cuda, verbose=False):
+    if append_mode is True:
+        outbasename, outext = os.path.splitext(os.path.basename(outfile))
+        outfile = os.path.join(
+            os.path.dirname(outfile), "{}-mode{}{}".format(outbasename, mode, outext)
+        )
+    if verbose:
+        print("Processing {}, mode={}".format(infile, mode))
+    voicefixer.restore(input=infile, output=outfile, cuda=cuda, mode=int(mode))
+
+def check_arguments(args):
+    process_file, process_folder = len(args.infile) != 0, len(args.infolder) != 0
+    # assert len(args.infile) == 0 and len(args.outfile) == 0 or process_file, \
+    #     "Error: You should give the input and output file path at the same time. The input and output file path we receive is %s and %s" % (args.infile, args.outfile)
+    # assert len(args.infolder) == 0 and len(args.outfolder) == 0 or process_folder, \
+    #     "Error: You should give the input and output folder path at the same time. The input and output folder path we receive is %s and %s" % (args.infolder, args.outfolder)
+    assert (
+        process_file or process_folder
+    ), "Error: You need to specify an input file path (--infile) or an input folder path (--infolder) to proceed. For more information please run: voicefixer -h"
+
+    # if(args.cuda and not torch.cuda.is_available()):
+    #     print("Warning: You set --cuda while no cuda device found on your machine. We will use CPU instead.")
+    if process_file:
+        assert os.path.exists(args.infile), (
+            "Error: The input file %s is not found." % args.infile
+        )
+        output_dirname = os.path.dirname(args.outfile)
+        if len(output_dirname) > 1:
+            os.makedirs(output_dirname, exist_ok=True)
+    if process_folder:
+        assert os.path.exists(args.infolder), (
+            "Error: The input folder %s is not found." % args.infolder
+        )
+        output_dirname = args.outfolder
+        if len(output_dirname) > 1:
+            os.makedirs(args.outfolder, exist_ok=True)
+
+    return process_file, process_folder
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="VoiceFixer - restores degraded speech"
+    )
+    parser.add_argument(
+        "-i",
+        "--infile",
+        type=str,
+        default="",
+        help="An input file to be processed by VoiceFixer.",
+    )
+    parser.add_argument(
+        "-o",
+        "--outfile",
+        type=str,
+        default="outfile.wav",
+        help="An output file to store the result.",
+    )
+
+    parser.add_argument(
+        "-ifdr",
+        "--infolder",
+        type=str,
+        default="",
+        help="Input folder. Place all your wav files that need processing in this folder.",
+    )
+    parser.add_argument(
+        "-ofdr",
+        "--outfolder",
+        type=str,
+        default="outfolder",
+        help="Output folder. The processed files will be stored in this folder.",
+    )
+
+    parser.add_argument(
+        "--mode", help="mode", choices=["0", "1", "2", "all"], default="0"
+    )
+    parser.add_argument(
+        "--disable-cuda",
+        help="Set this flag if you do not want to use your GPU.",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--silent",
+        help="Set this flag if you do not want to see any messages.",
+        default=False,
+        action="store_true",
+    )
+
+    args = parser.parse_args()
+
+    if torch.cuda.is_available() and not args.disable_cuda:
+        cuda = True
+    else:
+        cuda = False
+
+    process_file, process_folder = check_arguments(args)
+
+    if not args.silent:
+        print("Initializing VoiceFixer")
+    voicefixer = VoiceFixer()
+
+    if not args.silent:
+        print("Start processing the input file %s." % args.infile)
+
+    if process_file:
+        audioext = os.path.splitext(os.path.basename(args.infile))[-1]
+        if audioext != ".wav":
+            raise ValueError(
+                "Error processing the input file. We only support the .wav format currently. Please convert your %s file to .wav."
+                % audioext
+            )
+        if args.mode == "all":
+            for file_mode in range(3):
+                writefile(
+                    args.infile,
+                    args.outfile,
+                    file_mode,
+                    True,
+                    cuda,
+                    verbose=not args.silent,
+                )
+        else:
+            writefile(
+                args.infile,
+                args.outfile,
+                args.mode,
+                False,
+                cuda,
+                verbose=not args.silent,
+            )
+
+    if process_folder:
+        if not args.silent:
+            files = [
+                file
+                for file in os.listdir(args.infolder)
+                if (os.path.splitext(os.path.basename(file))[-1] == ".wav")
+            ]
+            print(
+                "Found %s .wav files in the input folder %s. Start processing."
+                % (len(files), args.infolder)
+            )
+        for file in os.listdir(args.infolder):
+            outbasename, outext = os.path.splitext(os.path.basename(file))
+            in_file = os.path.join(args.infolder, file)
+            out_file = os.path.join(args.outfolder, file)
+
+            if args.mode == "all":
+                for file_mode in range(3):
+                    writefile(
+                        in_file,
+                        out_file,
+                        file_mode,
+                        True,
+                        cuda,
+                        verbose=not args.silent,
+                    )
+            else:
+                writefile(
+                    in_file, out_file, args.mode, False, cuda, verbose=not args.silent
+                )
+
+    if not args.silent:
+        print("Done")
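
Hypothetical invocations of this entry point, assuming the package is installed so that python -m voicefixer resolves to this file:

    python -m voicefixer -i noisy.wav -o clean.wav --mode 0
    python -m voicefixer -ifdr in_wavs -ofdr out_wavs --mode all --disable-cuda

With --mode all, writefile() appends a "-mode{N}" suffix to the output basename, so a single input yields one output file per mode.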
voicefixer/base.py
ADDED
@@ -0,0 +1,145 @@
+import librosa.display
+from voicefixer.tools.pytorch_util import *
+from voicefixer.tools.wav import *
+from voicefixer.restorer.model import VoiceFixer as voicefixer_fe
+import os
+from huggingface_hub import hf_hub_download
+
+path_to_ckpt = hf_hub_download(repo_id="jlmarrugom/voice_fixer", filename="vf.ckpt")
+
+
+EPS = 1e-8
+
+
+class VoiceFixer(nn.Module):
+    def __init__(self):
+        super(VoiceFixer, self).__init__()
+        self._model = voicefixer_fe(channels=2, sample_rate=44100)
+        # print(os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/epoch=15_trimed_bn.ckpt"))
+        self.analysis_module_ckpt = path_to_ckpt  # "models/vf.ckpt"
+        if not os.path.exists(self.analysis_module_ckpt):
+            raise RuntimeError("Error 0: The checkpoint for the analysis module (vf.ckpt) was not found in ~/.cache/voicefixer/analysis_module/checkpoints. \
+                                By default the checkpoint should be downloaded automatically by this program, so something may have gone wrong. \
+                                Alternatively, you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/vf.ckpt?download=1.")
+        self._model.load_state_dict(
+            torch.load(
+                self.analysis_module_ckpt
+            )
+        )
+        self._model.eval()
+
+    def _load_wav_energy(self, path, sample_rate, threshold=0.95):
+        wav_10k, _ = librosa.load(path, sr=sample_rate)
+        stft = np.log10(np.abs(librosa.stft(wav_10k)) + 1.0)
+        fbins = stft.shape[0]
+        e_stft = np.sum(stft, axis=1)
+        for i in range(e_stft.shape[0]):
+            e_stft[-i - 1] = np.sum(e_stft[: -i - 1])
+        total = e_stft[-1]
+        for i in range(e_stft.shape[0]):
+            if e_stft[i] < total * threshold:
+                continue
+            else:
+                break
+        return wav_10k, int((sample_rate // 2) * (i / fbins))
+
+    def _load_wav(self, path, sample_rate, threshold=0.95):
+        wav_10k, _ = librosa.load(path, sr=sample_rate)
+        return wav_10k
+
+    def _amp_to_original_f(self, mel_sp_est, mel_sp_target, cutoff=0.2):
+        freq_dim = mel_sp_target.size()[-1]
+        mel_sp_est_low, mel_sp_target_low = (
+            mel_sp_est[..., 5 : int(freq_dim * cutoff)],
+            mel_sp_target[..., 5 : int(freq_dim * cutoff)],
+        )
+        energy_est, energy_target = torch.mean(mel_sp_est_low, dim=(2, 3)), torch.mean(
+            mel_sp_target_low, dim=(2, 3)
+        )
+        amp_ratio = energy_target / energy_est
+        return mel_sp_est * amp_ratio[..., None, None], mel_sp_target
+
+    def _trim_center(self, est, ref):
+        diff = np.abs(est.shape[-1] - ref.shape[-1])
+        if est.shape[-1] == ref.shape[-1]:
+            return est, ref
+        elif est.shape[-1] > ref.shape[-1]:
+            min_len = min(est.shape[-1], ref.shape[-1])
+            est, ref = est[..., int(diff // 2) : -int(diff // 2)], ref
+            est, ref = est[..., :min_len], ref[..., :min_len]
+            return est, ref
+        else:
+            min_len = min(est.shape[-1], ref.shape[-1])
+            est, ref = est, ref[..., int(diff // 2) : -int(diff // 2)]
+            est, ref = est[..., :min_len], ref[..., :min_len]
+            return est, ref
+
+    def _pre(self, model, input, cuda):
+        input = input[None, None, ...]
+        input = torch.tensor(input)
+        input = try_tensor_cuda(input, cuda=cuda)
+        sp, _, _ = model.f_helper.wav_to_spectrogram_phase(input)
+        mel_orig = model.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
+        # return models.to_log(sp), models.to_log(mel_orig)
+        return sp, mel_orig
+
+    def remove_higher_frequency(self, wav, ratio=0.95):
+        stft = librosa.stft(wav)
+        real, img = np.real(stft), np.imag(stft)
+        mag = (real**2 + img**2) ** 0.5
+        cos, sin = real / (mag + EPS), img / (mag + EPS)
+        spec = np.abs(stft)  # [1025, T]
+        feature = spec.copy()
+        feature = np.log10(feature + EPS)
+        feature[feature < 0] = 0
+        energy_level = np.sum(feature, axis=1)
+        threshold = np.sum(energy_level) * ratio
+        current_level, i = energy_level[0], 0
+        while i < energy_level.shape[0] and current_level < threshold:
+            current_level += energy_level[i + 1, ...]
+            i += 1
+        spec[i:, ...] = np.zeros_like(spec[i:, ...])
+        stft = spec * cos + 1j * spec * sin
+        return librosa.istft(stft)
+
+    @torch.no_grad()
+    def restore_inmem(self, wav_10k, cuda=False, mode=0, your_vocoder_func=None):
+        check_cuda_availability(cuda=cuda)
+        self._model = try_tensor_cuda(self._model, cuda=cuda)
+        if mode == 0:
+            self._model.eval()
+        elif mode == 1:
+            self._model.eval()
+        elif mode == 2:
+            self._model.train()  # More effective on seriously damaged speech
+        res = []
+        seg_length = 44100 * 30
+        break_point = seg_length
+        while break_point < wav_10k.shape[0] + seg_length:
+            segment = wav_10k[break_point - seg_length : break_point]
+            if mode == 1:
+                segment = self.remove_higher_frequency(segment)
+            sp, mel_noisy = self._pre(self._model, segment, cuda)
+            out_model = self._model(sp, mel_noisy)
+            denoised_mel = from_log(out_model["mel"])
+            if your_vocoder_func is None:
+                out = self._model.vocoder(denoised_mel, cuda=cuda)
+            else:
+                out = your_vocoder_func(denoised_mel)
+            # unify energy
+            if torch.max(torch.abs(out)) > 1.0:
+                out = out / torch.max(torch.abs(out))
+                print("Warning: energy limit exceeded, rescaling the segment.")
+            # frame alignment
+            out, _ = self._trim_center(out, segment)
+            res.append(out)
+            break_point += seg_length
+        out = torch.cat(res, -1)
+        return tensor2numpy(out.squeeze(0))
+
+    def restore(self, input, output, cuda=False, mode=0, your_vocoder_func=None):
+        wav_10k = self._load_wav(input, sample_rate=44100)
+        out_np_wav = self.restore_inmem(
+            wav_10k, cuda=cuda, mode=mode, your_vocoder_func=your_vocoder_func
+        )
+        save_wave(out_np_wav, fname=output, sample_rate=44100)
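
The restore_inmem loop above processes audio in fixed 30-second windows; a standalone sketch of just that segmentation logic (input length chosen for illustration):

    import numpy as np

    seg_length = 44100 * 30                  # 30 s at 44.1 kHz
    wav = np.zeros(44100 * 70)               # e.g. a 70 s recording
    break_point, segments = seg_length, []
    while break_point < wav.shape[0] + seg_length:
        segments.append(wav[break_point - seg_length : break_point])
        break_point += seg_length
    print([len(s) / 44100 for s in segments])  # [30.0, 30.0, 10.0]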
voicefixer/restorer/__init__.py
ADDED
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+"""
+@File    : __init__.py
+@Contact : [email protected]
+@License : (C)Copyright 2020-2100
+
+@Modify Time      @Author    @Version    @Description
+------------      -------    --------    -----------
+9/14/21 12:31 AM  Haohe Liu  1.0         None
+"""
+
+import os
+import torch
+import urllib.request
+
+meta = {
+    "voicefixer_fe": {
+        "path": os.path.join(
+            os.path.expanduser("~"),
+            ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt",
+        ),
+        "url": "https://zenodo.org/record/5600188/files/vf.ckpt?download=1",
+    },
+}
+
+if not os.path.exists(meta["voicefixer_fe"]["path"]):
+    os.makedirs(os.path.dirname(meta["voicefixer_fe"]["path"]), exist_ok=True)
+    print("Downloading the main structure of voicefixer")
+
+    urllib.request.urlretrieve(
+        meta["voicefixer_fe"]["url"], meta["voicefixer_fe"]["path"]
+    )
+    print(
+        "Weights downloaded in: {} Size: {}".format(
+            meta["voicefixer_fe"]["path"],
+            os.path.getsize(meta["voicefixer_fe"]["path"]),
+        )
+    )
+
+# cmd = "wget " + meta["voicefixer_fe"]['url'] + " -O " + meta["voicefixer_fe"]['path']
+# os.system(cmd)
+# temp = torch.load(meta["voicefixer_fe"]['path'])
+# torch.save(temp['state_dict'], os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt"))
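
Note that the weights now have two download paths: this module fetches vf.ckpt from Zenodo into ~/.cache at import time, while voicefixer/base.py fetches the same file from the Hugging Face Hub. For comparison, the Hub variant exactly as used in base.py:

    from huggingface_hub import hf_hub_download

    path_to_ckpt = hf_hub_download(repo_id="jlmarrugom/voice_fixer", filename="vf.ckpt")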
voicefixer/restorer/model.py
ADDED
@@ -0,0 +1,680 @@
+# import pytorch_lightning as pl
+
+import torch.utils
+from voicefixer.tools.mel_scale import MelScale
+import torch.utils.data
+import matplotlib.pyplot as plt
+import librosa.display
+from voicefixer.vocoder.base import Vocoder
+from voicefixer.tools.pytorch_util import *
+from voicefixer.restorer.model_kqq_bn import UNetResComplex_100Mb
+from voicefixer.tools.random_ import *
+from voicefixer.tools.wav import *
+from voicefixer.tools.modules.fDomainHelper import FDomainHelper
+
+from voicefixer.tools.io import load_json, write_json
+from matplotlib import cm
+
+os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
+EPS = 1e-8
+
+
+class BN_GRU(torch.nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        hidden_dim,
+        layer=1,
+        bidirectional=False,
+        batchnorm=True,
+        dropout=0.0,
+    ):
+        super(BN_GRU, self).__init__()
+        self.batchnorm = batchnorm
+        if batchnorm:
+            self.bn = nn.BatchNorm2d(1)
+        self.gru = torch.nn.GRU(
+            input_size=input_dim,
+            hidden_size=hidden_dim,
+            num_layers=layer,
+            bidirectional=bidirectional,
+            dropout=dropout,
+            batch_first=True,
+        )
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.modules():
+            if type(m) in [nn.GRU, nn.LSTM, nn.RNN]:
+                for name, param in m.named_parameters():
+                    if "weight_ih" in name:
+                        torch.nn.init.xavier_uniform_(param.data)
+                    elif "weight_hh" in name:
+                        torch.nn.init.orthogonal_(param.data)
+                    elif "bias" in name:
+                        param.data.fill_(0)
+
+    def forward(self, inputs):
+        # (batch, 1, seq, feature)
+        if self.batchnorm:
+            inputs = self.bn(inputs)
+        out, _ = self.gru(inputs.squeeze(1))
+        return out.unsqueeze(1)
+
+
+class Generator(nn.Module):
+    def __init__(self, n_mel, hidden, channels):
+        super(Generator, self).__init__()
+        # todo: the currently running trial doesn't have dropout
+        self.denoiser = nn.Sequential(
+            nn.BatchNorm2d(1),
+            nn.Linear(n_mel, n_mel * 2),
+            nn.ReLU(inplace=True),
+            nn.BatchNorm2d(1),
+            nn.Linear(n_mel * 2, n_mel * 4),
+            nn.Dropout(0.5),
+            nn.ReLU(inplace=True),
+            BN_GRU(
+                input_dim=n_mel * 4,
+                hidden_dim=n_mel * 2,
+                bidirectional=True,
+                layer=2,
+                batchnorm=True,
+            ),
+            BN_GRU(
+                input_dim=n_mel * 4,
+                hidden_dim=n_mel * 2,
+                bidirectional=True,
+                layer=2,
+                batchnorm=True,
+            ),
+            nn.BatchNorm2d(1),
+            nn.ReLU(inplace=True),
+            nn.Linear(n_mel * 4, n_mel * 4),
+            nn.Dropout(0.5),
+            nn.BatchNorm2d(1),
+            nn.ReLU(inplace=True),
+            nn.Linear(n_mel * 4, n_mel),
+            nn.Sigmoid(),
+        )
+
+        self.unet = UNetResComplex_100Mb(channels=channels)
+
+    def forward(self, sp, mel_orig):
+        # Denoising
+        noisy = mel_orig.clone()
+        clean = self.denoiser(noisy) * noisy
+        x = to_log(clean.detach())
+        unet_in = torch.cat([to_log(mel_orig), x], dim=1)
+        # unet_in = lstm_out
+        unet_out = self.unet(unet_in)["mel"]
+        # masks
+        mel = unet_out + x
+        # todo: mel and the addition here are in log scale
+        return {
+            "mel": mel,
+            "lstm_out": unet_out,
+            "unet_out": unet_out,
+            "noisy": noisy,
+            "clean": clean,
+        }
+
+
+class VoiceFixer(nn.Module):
+    def __init__(
+        self,
+        channels,
+        type_target="vocals",
+        nsrc=1,
+        loss="l1",
+        lr=0.002,
+        gamma=0.9,
+        batchsize=None,
+        frame_length=None,
+        sample_rate=None,
+        warm_up_steps=1000,
+        reduce_lr_steps=15000,
+        # datas
+        check_val_every_n_epoch=5,
+    ):
+        super(VoiceFixer, self).__init__()
+
+        if sample_rate == 44100:
+            window_size = 2048
+            hop_size = 441
+            n_mel = 128
+        elif sample_rate == 24000:
+            window_size = 768
+            hop_size = 240
+            n_mel = 80
+        elif sample_rate == 16000:
+            window_size = 512
+            hop_size = 160
+            n_mel = 80
+        else:
+            raise ValueError(
+                "Error: Sample rate " + str(sample_rate) + " not supported"
+            )
+
+        center = (True,)
+        pad_mode = "reflect"
+        window = "hann"
+        freeze_parameters = True
+
+        # self.save_hyperparameters()
+        self.nsrc = nsrc
+        self.type_target = type_target
+        self.channels = channels
+        self.lr = lr
+        self.generated = None
+        self.gamma = gamma
+        self.sample_rate = sample_rate
+        self.batchsize = batchsize
+        self.frame_length = frame_length
+        # self.hparams['channels'] = 2
+
+        # self.am = AudioMetrics()
+        # self.im = ImgMetrics()
+
+        self.vocoder = Vocoder(sample_rate=44100)
+
+        self.valid = None
+        self.fake = None
+
+        self.train_step = 0
+        self.val_step = 0
+        self.val_result_save_dir = None
+        self.val_result_save_dir_step = None
+        self.downsample_ratio = 2**6  # This number equals 2^{#encoder_blocks}
+        self.check_val_every_n_epoch = check_val_every_n_epoch
+
+        self.f_helper = FDomainHelper(
+            window_size=window_size,
+            hop_size=hop_size,
+            center=center,
+            pad_mode=pad_mode,
+            window=window,
+            freeze_parameters=freeze_parameters,
+        )
+
+        hidden = window_size // 2 + 1
+
+        self.mel = MelScale(n_mels=n_mel, sample_rate=sample_rate, n_stft=hidden)
+
+        # masking
+        self.generator = Generator(n_mel, hidden, channels)
+
+        self.lr_lambda = lambda step: self.get_lr_lambda(
+            step,
+            gamma=self.gamma,
+            warm_up_steps=warm_up_steps,
+            reduce_lr_steps=reduce_lr_steps,
+        )
+
+        self.lr_lambda_2 = lambda step: self.get_lr_lambda(
+            step, gamma=self.gamma, warm_up_steps=10, reduce_lr_steps=reduce_lr_steps
+        )
+
+        self.mel_weight_44k_128 = (
+            torch.tensor(
+                [
+                    19.40951426,
+                    19.94047336,
+                    20.4859038,
+                    21.04629067,
+                    21.62194148,
+                    22.21335214,
+                    22.8210215,
+                    23.44529231,
+                    24.08660962,
+                    24.74541882,
+                    25.42234287,
+                    26.11770576,
+                    26.83212784,
+                    27.56615283,
+                    28.32007747,
+                    29.0947679,
+                    29.89060111,
+                    30.70832636,
+                    31.54828121,
+                    32.41121487,
+                    33.29780773,
+                    34.20865341,
+                    35.14437675,
+                    36.1056621,
+                    37.09332763,
+                    38.10795802,
+                    39.15039691,
+                    40.22119881,
+                    41.32154931,
+                    42.45172373,
+                    43.61293329,
+                    44.80609379,
+                    46.031602,
+                    47.29070223,
+                    48.58427549,
+                    49.91327905,
+                    51.27863232,
+                    52.68119708,
+                    54.1222372,
+                    55.60274206,
+                    57.12364703,
+                    58.68617876,
+                    60.29148652,
+                    61.94081306,
+                    63.63501986,
+                    65.37562658,
+                    67.16408954,
+                    69.00109084,
+                    70.88850318,
+                    72.82736101,
+                    74.81985537,
+                    76.86654792,
+                    78.96885475,
+                    81.12900906,
+                    83.34840929,
+                    85.62810662,
+                    87.97005418,
+                    90.37689804,
+                    92.84887686,
+                    95.38872881,
+                    97.99777002,
+                    100.67862715,
+                    103.43232942,
+                    106.26140638,
+                    109.16827015,
+                    112.15470471,
+                    115.22184756,
+                    118.37439245,
+                    121.6122689,
+                    124.93877158,
+                    128.35661454,
+                    131.86761321,
+                    135.47417938,
+                    139.18059494,
+                    142.98713744,
+                    146.89771854,
+                    150.91684347,
+                    155.0446638,
+                    159.28614648,
+                    163.64270198,
+                    168.12035831,
+                    172.71749158,
+                    177.44220154,
+                    182.29556933,
+                    187.28286676,
+                    192.40502126,
+                    197.6682721,
+                    203.07516896,
+                    208.63088733,
+                    214.33770931,
+                    220.19910108,
+                    226.22363072,
+                    232.41087124,
+                    238.76803591,
+                    245.30079083,
+                    252.01064464,
+                    258.90261676,
+                    265.98474,
+                    273.26010248,
+                    280.73496362,
+                    288.41440094,
+                    296.30489752,
+                    304.41180337,
+                    312.7377183,
+                    321.28877878,
+                    330.07870237,
+                    339.10812951,
+                    348.38276173,
+                    357.91393924,
+                    367.70513992,
+                    377.76413924,
+                    388.09467408,
+                    398.70920178,
+                    409.61813793,
+                    420.81980127,
+                    432.33215467,
+                    444.16083117,
+                    456.30919947,
+                    468.78589276,
+                    481.61325588,
+                    494.78824596,
+                    508.31969844,
+                    522.2238331,
+                    536.51163441,
+                    551.18859414,
+                    566.26142988,
+                    581.75006061,
+                    597.66210737,
+                ]
+            )
+            / 19.40951426
+        )
+        self.mel_weight_44k_128 = self.mel_weight_44k_128[None, None, None, ...]
+
+        self.g_loss_weight = 0.01
+        self.d_loss_weight = 1
+
+    def get_vocoder(self):
+        return self.vocoder
+
+    def get_f_helper(self):
+        return self.f_helper
+
+    def get_lr_lambda(self, step, gamma, warm_up_steps, reduce_lr_steps):
+        r"""Get lr_lambda for LambdaLR. E.g.,
+
+        .. code-block: python
+            lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)
+
+            from torch.optim.lr_scheduler import LambdaLR
+            LambdaLR(optimizer, lr_lambda)
+        """
+        if step <= warm_up_steps:
+            return step / warm_up_steps
+        else:
+            return gamma ** (step // reduce_lr_steps)
+
+    def init_weights(self, module: nn.Module):
+        for m in module.modules():
+            if type(m) in [nn.GRU, nn.LSTM, nn.RNN]:
+                for name, param in m.named_parameters():
+                    if "weight_ih" in name:
+                        torch.nn.init.xavier_uniform_(param.data)
+                    elif "weight_hh" in name:
+                        torch.nn.init.orthogonal_(param.data)
+                    elif "bias" in name:
+                        param.data.fill_(0)
+
+    def pre(self, input):
+        sp, _, _ = self.f_helper.wav_to_spectrogram_phase(input)
+        mel_orig = self.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
+        return sp, mel_orig
+
+    def forward(self, sp, mel_orig):
+        """
+        Args:
+            input: (batch_size, channels_num, segment_samples)
+
+        Outputs:
+            output_dict: {
+                'wav': (batch_size, channels_num, segment_samples),
+                'sp': (batch_size, channels_num, time_steps, freq_bins)}
+        """
+        return self.generator(sp, mel_orig)
+
+    # Note: the methods below come from the original pytorch_lightning training
+    # setup (see the commented-out import at the top of this file); self.log,
+    # self.l1loss, self.bce_loss and self.discriminator exist only in that
+    # context, and none of this is used for inference.
+    def configure_optimizers(self):
+        optimizer_g = torch.optim.Adam(
+            [{"params": self.generator.parameters()}],
+            lr=self.lr,
+            amsgrad=True,
+            betas=(0.5, 0.999),
+        )
+        optimizer_d = torch.optim.Adam(
+            [{"params": self.discriminator.parameters()}],
+            lr=self.lr,
+            amsgrad=True,
+            betas=(0.5, 0.999),
+        )
+
+        scheduler_g = {
+            "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer_g, self.lr_lambda),
+            "interval": "step",
+            "frequency": 1,
+        }
+        scheduler_d = {
+            "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer_d, self.lr_lambda),
+            "interval": "step",
+            "frequency": 1,
+        }
+        return [optimizer_g, optimizer_d], [scheduler_g, scheduler_d]
+
+    def preprocess(self, batch, train=False, cutoff=None):
+        if train:
+            vocal = batch[self.type_target]  # final target
+            noise = batch["noise_LR"]  # augmented low resolution audio with noise
+            augLR = batch[
+                self.type_target + "_aug_LR"
+            ]  # augmented low resolution audio
+            LR = batch[self.type_target + "_LR"]
+            # embed()
+            vocal, LR, augLR, noise = (
+                vocal.float().permute(0, 2, 1),
+                LR.float().permute(0, 2, 1),
+                augLR.float().permute(0, 2, 1),
+                noise.float().permute(0, 2, 1),
+            )
+            # LR, noise = self.add_random_noise(LR, noise)
+            snr, scale = [], []
+            for i in range(vocal.size()[0]):
+                (
+                    vocal[i, ...],
+                    LR[i, ...],
+                    augLR[i, ...],
+                    noise[i, ...],
+                    _snr,
+                    _scale,
+                ) = add_noise_and_scale_with_HQ_with_Aug(
+                    vocal[i, ...],
+                    LR[i, ...],
+                    augLR[i, ...],
+                    noise[i, ...],
+                    snr_l=-5,
+                    snr_h=45,
+                    scale_lower=0.6,
+                    scale_upper=1.0,
+                )
+                snr.append(_snr), scale.append(_scale)
+            # vocal, LR = self.amp_to_original_f(vocal, LR)
+            # noise = (noise * 0.0) + 1e-8  # todo
+            return vocal, augLR, LR, noise + augLR
+        else:
+            if cutoff is None:
+                LR_noisy = batch["noisy"]
+                LR = batch["vocals"]
+                vocals = batch["vocals"]
+                vocals, LR, LR_noisy = (
+                    vocals.float().permute(0, 2, 1),
+                    LR.float().permute(0, 2, 1),
+                    LR_noisy.float().permute(0, 2, 1),
+                )
+                return vocals, LR, LR_noisy, batch["fname"][0]
+            else:
+                LR_noisy = batch["noisy" + "LR" + "_" + str(cutoff)]
+                LR = batch["vocals" + "LR" + "_" + str(cutoff)]
+                vocals = batch["vocals"]
+                vocals, LR, LR_noisy = (
+                    vocals.float().permute(0, 2, 1),
+                    LR.float().permute(0, 2, 1),
+                    LR_noisy.float().permute(0, 2, 1),
+                )
+                return vocals, LR, LR_noisy, batch["fname"][0]
+
+    def training_step(self, batch, batch_nb, optimizer_idx):
+        # dict_keys(['vocals', 'vocals_aug', 'vocals_augLR', 'noise'])
+        config = load_json("temp_path.json")
+        if "g_loss_weight" not in config.keys():
+            config["g_loss_weight"] = self.g_loss_weight
+            config["d_loss_weight"] = self.d_loss_weight
+            write_json(config, "temp_path.json")
+        elif (
+            config["g_loss_weight"] != self.g_loss_weight
+            or config["d_loss_weight"] != self.d_loss_weight
+        ):
+            print(
+                "Update d_loss weight, from",
+                self.d_loss_weight,
+                "to",
+                config["d_loss_weight"],
+            )
+            print(
+                "Update g_loss weight, from",
+                self.g_loss_weight,
+                "to",
+                config["g_loss_weight"],
+            )
+            self.g_loss_weight = config["g_loss_weight"]
+            self.d_loss_weight = config["d_loss_weight"]
+
+        if optimizer_idx == 0:
+            self.vocal, self.augLR, _, self.LR_noisy = self.preprocess(
+                batch, train=True
+            )
+
+            for i in range(self.vocal.size()[0]):
+                save_wave(
+                    tensor2numpy(self.vocal[i, ...]),
+                    str(i) + "vocal" + ".wav",
+                    sample_rate=44100,
+                )
+                save_wave(
+                    tensor2numpy(self.LR_noisy[i, ...]),
+                    str(i) + "LR_noisy" + ".wav",
+                    sample_rate=44100,
+                )
+
+            # all_mel_e2e in non-log scale
+            _, self.mel_target = self.pre(self.vocal)
+            self.sp_LR_target, self.mel_LR_target = self.pre(self.augLR)
+            self.sp_LR_target_noisy, self.mel_LR_target_noisy = self.pre(self.LR_noisy)
+
+            if self.valid is None or self.valid.size()[0] != self.mel_target.size()[0]:
+                self.valid = torch.ones(
+                    self.mel_target.size()[0], 1, self.mel_target.size()[2], 1
+                )
+                self.valid = self.valid.type_as(self.mel_target)
+            if self.fake is None or self.fake.size()[0] != self.mel_target.size()[0]:
+                self.fake = torch.zeros(
+                    self.mel_target.size()[0], 1, self.mel_target.size()[2], 1
+                )
+                self.fake = self.fake.type_as(self.mel_target)
+
+            self.generated = self(self.sp_LR_target_noisy, self.mel_LR_target_noisy)
+
+            denoise_loss = self.l1loss(self.generated["clean"], self.mel_LR_target)
+            targ_loss = self.l1loss(self.generated["mel"], to_log(self.mel_target))
+
+            self.log(
+                "targ-l",
+                targ_loss,
+                on_step=True,
+                on_epoch=False,
+                logger=True,
+                sync_dist=True,
+                prog_bar=True,
+            )
+            self.log(
+                "noise-l",
+                denoise_loss,
+                on_step=True,
+                on_epoch=False,
+                logger=True,
+                sync_dist=True,
+                prog_bar=True,
+            )
+
+            loss = targ_loss + denoise_loss
+
+            if self.train_step >= 18000:
+                g_loss = self.bce_loss(
+                    self.discriminator(self.generated["mel"]), self.valid
+                )
+                self.log(
+                    "g_l",
+                    g_loss,
+                    on_step=True,
+                    on_epoch=False,
+                    logger=True,
+                    sync_dist=True,
+                    prog_bar=True,
+                )
+                # print("g_loss", g_loss)
+                all_loss = loss + self.g_loss_weight * g_loss
+                self.log(
+                    "all_loss",
+                    all_loss,
+                    on_step=True,
+                    on_epoch=True,
+                    logger=True,
+                    sync_dist=True,
+                )
+            else:
+                all_loss = loss
+            self.train_step += 0.5
+            return {"loss": all_loss}
+
+        elif optimizer_idx == 1:
+            if self.train_step >= 16000:
+                self.generated = self(self.sp_LR_target_noisy, self.mel_LR_target_noisy)
+                self.train_step += 0.5
+                real_loss = self.bce_loss(
+                    self.discriminator(to_log(self.mel_target)), self.valid
+                )
+                self.log(
+                    "r_l",
+                    real_loss,
+                    on_step=True,
+                    on_epoch=False,
+                    logger=True,
+                    sync_dist=True,
+                    prog_bar=True,
+                )
+                fake_loss = self.bce_loss(
+                    self.discriminator(self.generated["mel"].detach()), self.fake
+                )
+                self.log(
+                    "d_l",
+                    fake_loss,
+                    on_step=True,
+                    on_epoch=False,
+                    logger=True,
+                    sync_dist=True,
+                    prog_bar=True,
+                )
+                d_loss = self.d_loss_weight * (real_loss + fake_loss) / 2
+                self.log(
+                    "discriminator_loss",
+                    d_loss,
+                    on_step=True,
+                    on_epoch=True,
+                    logger=True,
+                    sync_dist=True,
+                )
+                return {"loss": d_loss}
+
+    def draw_and_save(
+        self, mel: torch.Tensor, path, clip_max=None, clip_min=None, needlog=True
+    ):
+        plt.figure(figsize=(15, 5))
+        if clip_min is None:
+            clip_max, clip_min = self.clip(mel)
+        mel = np.transpose(tensor2numpy(mel)[0, 0, ...], (1, 0))
+        # assert np.sum(mel < 0) == 0, str(np.sum(mel < 0)) + str(np.sum(mel < 0))
+
+        if needlog:
+            assert np.sum(mel < 0) == 0, str(np.sum(mel < 0)) + "-" + path
+            mel_log = np.log10(mel + EPS)
+        else:
+            mel_log = mel
+
+        # plt.imshow(mel)
+        librosa.display.specshow(
+            mel_log,
+            sr=44100,
+            x_axis="frames",
+            y_axis="mel",
+            cmap=cm.jet,
+            vmax=clip_max,
+            vmin=clip_min,
+        )
+        plt.colorbar()
+        plt.savefig(path)
+        plt.close()
+
+    def clip(self, *args):
+        val_max, val_min = [], []
+        for each in args:
+            val_max.append(torch.max(each))
+            val_min.append(torch.min(each))
+        return max(val_max), min(val_min)
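
The warm-up/decay schedule defined by get_lr_lambda above can be checked numerically; a standalone sketch using the constructor defaults (warm_up_steps=1000, reduce_lr_steps=15000, gamma=0.9):

    def lr_lambda(step, gamma=0.9, warm_up_steps=1000, reduce_lr_steps=15000):
        if step <= warm_up_steps:
            return step / warm_up_steps           # linear warm-up to 1.0
        return gamma ** (step // reduce_lr_steps)  # stepwise decay afterwards

    assert lr_lambda(500) == 0.5
    assert lr_lambda(30000) == 0.9 ** 2  # two decay steps have elapsed at step 30000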
voicefixer/restorer/model_kqq_bn.py
ADDED
@@ -0,0 +1,186 @@
+from voicefixer.restorer.modules import *
+
+from voicefixer.tools.pytorch_util import *
+
+
+class UNetResComplex_100Mb(nn.Module):
+    def __init__(self, channels, nsrc=1):
+        super(UNetResComplex_100Mb, self).__init__()
+        activation = "relu"
+        momentum = 0.01
+
+        self.nsrc = nsrc
+        self.channels = channels
+        self.downsample_ratio = 2**6  # This number equals 2^{#encoder_blocks}
+
+        self.encoder_block1 = EncoderBlockRes(
+            in_channels=channels * nsrc,
+            out_channels=32,
+            downsample=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.encoder_block2 = EncoderBlockRes(
+            in_channels=32,
+            out_channels=64,
+            downsample=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.encoder_block3 = EncoderBlockRes(
+            in_channels=64,
+            out_channels=128,
+            downsample=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.encoder_block4 = EncoderBlockRes(
+            in_channels=128,
+            out_channels=256,
+            downsample=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.encoder_block5 = EncoderBlockRes(
+            in_channels=256,
+            out_channels=384,
+            downsample=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.encoder_block6 = EncoderBlockRes(
+            in_channels=384,
+            out_channels=384,
+            downsample=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.conv_block7 = ConvBlockRes(
+            in_channels=384,
+            out_channels=384,
+            size=3,
+            activation=activation,
+            momentum=momentum,
+        )
+        self.decoder_block1 = DecoderBlockRes(
+            in_channels=384,
+            out_channels=384,
+            stride=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.decoder_block2 = DecoderBlockRes(
+            in_channels=384,
+            out_channels=384,
+            stride=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.decoder_block3 = DecoderBlockRes(
+            in_channels=384,
+            out_channels=256,
+            stride=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.decoder_block4 = DecoderBlockRes(
+            in_channels=256,
+            out_channels=128,
+            stride=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.decoder_block5 = DecoderBlockRes(
+            in_channels=128,
+            out_channels=64,
+            stride=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+        self.decoder_block6 = DecoderBlockRes(
+            in_channels=64,
+            out_channels=32,
+            stride=(2, 2),
+            activation=activation,
+            momentum=momentum,
+        )
+
+        self.after_conv_block1 = ConvBlockRes(
+            in_channels=32,
+            out_channels=32,
+            size=3,
+            activation=activation,
+            momentum=momentum,
+        )
+
+        self.after_conv2 = nn.Conv2d(
+            in_channels=32,
+            out_channels=1,
+            kernel_size=(1, 1),
+            stride=(1, 1),
+            padding=(0, 0),
+            bias=True,
+        )
+
+        self.init_weights()
+
+    def init_weights(self):
+        init_layer(self.after_conv2)
+
+    def forward(self, sp):
+        """
+        Args:
+            input: (batch_size, channels_num, segment_samples)
+
+        Outputs:
+            output_dict: {
+                'wav': (batch_size, channels_num, segment_samples),
+                'sp': (batch_size, channels_num, time_steps, freq_bins)}
+        """
+
+        # Batch normalization
+        x = sp
+
+        # Pad spectrogram to be evenly divided by downsample ratio.
+        origin_len = x.shape[2]  # time_steps
+        pad_len = (
+            int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
+            - origin_len
+        )
+        x = F.pad(x, pad=(0, 0, 0, pad_len))
+        x = x[..., 0 : x.shape[-1] - 1]  # (bs, channels, T, F)
+
+        # UNet
+        (x1_pool, x1) = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F / 2)
+        (x2_pool, x2) = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F / 4)
+        (x3_pool, x3) = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F / 8)
+        (x4_pool, x4) = self.encoder_block4(
+            x3_pool
+        )  # x4_pool: (bs, 256, T / 16, F / 16)
+        (x5_pool, x5) = self.encoder_block5(
+            x4_pool
+        )  # x5_pool: (bs, 384, T / 32, F / 32)
+        (x6_pool, x6) = self.encoder_block6(
+            x5_pool
+        )  # x6_pool: (bs, 384, T / 64, F / 64)
+        x_center = self.conv_block7(x6_pool)  # (bs, 384, T / 64, F / 64)
+        x7 = self.decoder_block1(x_center, x6)  # (bs, 384, T / 32, F / 32)
+        x8 = self.decoder_block2(x7, x5)  # (bs, 384, T / 16, F / 16)
+        x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F / 8)
+        x10 = self.decoder_block4(x9, x3)  # (bs, 128, T / 4, F / 4)
+        x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F / 2)
+        x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F)
+        x = self.after_conv_block1(x12)  # (bs, 32, T, F)
+        x = self.after_conv2(x)  # (bs, 1, T, F)
+
+        # Recover shape
+        x = F.pad(x, pad=(0, 1))
+        x = x[:, :, 0:origin_len, :]
+
+        output_dict = {"mel": x}
+        return output_dict
+
+
+if __name__ == "__main__":
+    model = UNetResComplex_100Mb(channels=1)
+    print(model(torch.randn((1, 1, 101, 128)))["mel"].size())
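
The forward pass above pads the time axis so it divides evenly by downsample_ratio = 2**6 (one halving per encoder block); a quick numeric check of that padding arithmetic for the smoke test's 101-frame input:

    import numpy as np

    origin_len, downsample_ratio = 101, 2 ** 6
    pad_len = int(np.ceil(origin_len / downsample_ratio)) * downsample_ratio - origin_len
    print(pad_len)  # 27, since 101 + 27 = 128 = 2 * 64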
voicefixer/restorer/modules.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
class ConvBlockRes(nn.Module):
|
8 |
+
def __init__(self, in_channels, out_channels, size, activation, momentum):
|
9 |
+
super(ConvBlockRes, self).__init__()
|
10 |
+
|
11 |
+
self.activation = activation
|
12 |
+
if type(size) == type((3, 4)):
|
13 |
+
pad = size[0] // 2
|
14 |
+
size = size[0]
|
15 |
+
else:
|
16 |
+
pad = size // 2
|
17 |
+
size = size
|
18 |
+
|
19 |
+
self.conv1 = nn.Conv2d(
|
20 |
+
in_channels=in_channels,
|
21 |
+
out_channels=out_channels,
|
22 |
+
kernel_size=(size, size),
|
23 |
+
stride=(1, 1),
|
24 |
+
dilation=(1, 1),
|
25 |
+
padding=(pad, pad),
|
26 |
+
bias=False,
|
27 |
+
)
|
28 |
+
|
29 |
+
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
|
30 |
+
# self.abn1 = InPlaceABN(num_features=in_channels, momentum=momentum, activation='leaky_relu')
|
31 |
+
|
32 |
+
self.conv2 = nn.Conv2d(
|
33 |
+
in_channels=out_channels,
|
34 |
+
out_channels=out_channels,
|
35 |
+
kernel_size=(size, size),
|
36 |
+
stride=(1, 1),
|
37 |
+
dilation=(1, 1),
|
38 |
+
padding=(pad, pad),
|
39 |
+
bias=False,
|
40 |
+
)
|
41 |
+
|
42 |
+
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
43 |
+
|
44 |
+
# self.abn2 = InPlaceABN(num_features=out_channels, momentum=momentum, activation='leaky_relu')
|
45 |
+
|
46 |
+
if in_channels != out_channels:
|
47 |
+
self.shortcut = nn.Conv2d(
|
48 |
+
in_channels=in_channels,
|
49 |
+
out_channels=out_channels,
|
50 |
+
kernel_size=(1, 1),
|
51 |
+
stride=(1, 1),
|
52 |
+
padding=(0, 0),
|
53 |
+
)
|
54 |
+
self.is_shortcut = True
|
55 |
+
else:
|
56 |
+
self.is_shortcut = False
|
57 |
+
|
58 |
+
self.init_weights()
|
59 |
+
|
60 |
+
def init_weights(self):
|
61 |
+
init_bn(self.bn1)
|
62 |
+
init_layer(self.conv1)
|
63 |
+
init_layer(self.conv2)
|
64 |
+
|
65 |
+
if self.is_shortcut:
|
66 |
+
init_layer(self.shortcut)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
origin = x
|
70 |
+
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


class EncoderBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super(EncoderBlockRes, self).__init__()
        size = 3

        self.conv_block1 = ConvBlockRes(
            in_channels, out_channels, size, activation, momentum
        )
        self.conv_block2 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super(DecoderBlockRes, self).__init__()
        size = 3
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(size, size),
            stride=stride,
            padding=(0, 0),
            output_padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockRes(
            out_channels * 2, out_channels, size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block5 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias"):
        if layer.bias is not None:
            layer.bias.data.fill_(0.0)


def init_bn(bn):
    """Initialize a Batchnorm layer."""
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)


def init_gru(rnn):
    """Initialize a GRU layer."""

    def _concat_init(tensor, init_funcs):
        (length, fan_out) = tensor.shape
        fan_in = length // len(init_funcs)

        for (i, init_func) in enumerate(init_funcs):
            init_func(tensor[i * fan_in : (i + 1) * fan_in, :])

    def _inner_uniform(tensor):
        fan_in = nn.init._calculate_correct_fan(tensor, "fan_in")
        nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))

    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, "weight_ih_l{}".format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform],
        )
        torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0)

        _concat_init(
            getattr(rnn, "weight_hh_l{}".format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_],
        )
        torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0)


def act(x, activation):
    if activation == "relu":
        return F.relu_(x)

    elif activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.2)

    elif activation == "swish":
        return x * torch.sigmoid(x)

    else:
        raise Exception("Incorrect activation!")
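Not part of the upload: a minimal shape-check sketch for the blocks above, assuming this file imports as voicefixer.restorer.modules and that ConvBlockRes (defined earlier in the file) takes (in_channels, out_channels, size, activation, momentum); the sizes are illustrative only.

import torch
from voicefixer.restorer.modules import EncoderBlockRes, DecoderBlockRes

enc = EncoderBlockRes(in_channels=1, out_channels=32, downsample=(2, 2),
                      activation="relu", momentum=0.01)
dec = DecoderBlockRes(in_channels=32, out_channels=32, stride=(2, 2),
                      activation="relu", momentum=0.01)

x = torch.randn(2, 1, 64, 64)       # (batch, channels, time, freq)
pooled, skip = enc(x)               # pooled: (2, 32, 32, 32); skip is kept for the decoder
y = dec(pooled, skip, both=True)    # prune both axes so the skip connection lines up
print(pooled.shape, skip.shape, y.shape)  # y: (2, 32, 64, 64)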
voicefixer/tools/__init__.py
ADDED
@@ -0,0 +1,11 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    : __init__.py
@Contact : [email protected]
@License : (C)Copyright 2020-2100

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
9/14/21 12:28 AM   Haohe Liu      1.0         None
"""
voicefixer/tools/base.py
ADDED
@@ -0,0 +1,244 @@
import math

import numpy as np
import torch
import os
import torch.fft

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


def get_window(window_size, window_type, square_root_window=True):
    """Return the window"""
    window = {
        "hamming": torch.hamming_window(window_size),
        "hanning": torch.hann_window(window_size),
    }[window_type]
    if square_root_window:
        window = torch.sqrt(window)
    return window


def fft_point(dim):
    assert dim > 0
    num = math.log(dim, 2)
    num_point = 2 ** (math.ceil(num))
    return num_point


def pre_emphasis(signal, coefficient=0.97):
    """Pre-emphasize the original signal:
    y(n) = x(n) - a*x(n-1)
    """
    return np.append(signal[0], signal[1:] - coefficient * signal[:-1])


def de_emphasis(signal, coefficient=0.97):
    """De-emphasize the original signal (inverse of pre-emphasis):
    y(n) = x(n) + a*y(n-1)
    """
    length = signal.shape[0]
    for i in range(1, length):
        signal[i] = signal[i] + coefficient * signal[i - 1]
    return signal


def seperate_magnitude(magnitude, phase):
    real = torch.cos(phase) * magnitude
    imagine = torch.sin(phase) * magnitude
    expand_dim = len(list(real.size()))
    return torch.stack((real, imagine), expand_dim)


def stft_single(
    signal,
    sample_rate=44100,
    frame_length=46,
    frame_shift=10,
    window_type="hanning",
    device=torch.device("cuda"),
    square_root_window=True,
):
    """Compute the Short Time Fourier Transform.

    Args:
        signal: input speech signal
        sample_rate: sample rate of the waveform (Hz)
        frame_length: frame length in milliseconds
        frame_shift: frame shift in milliseconds
        window_type: type of window
        square_root_window: take the square root of the window
    Return:
        fft: (n/2)+1 dim complex STFT results
    """
    hop_length = int(
        sample_rate * frame_shift / 1000
    )  # the greater the sample_rate, the greater the hop_length
    win_length = int(sample_rate * frame_length / 1000)
    # num_point = fft_point(win_length)
    num_point = win_length
    window = get_window(num_point, window_type, square_root_window)
    if "cuda" in str(device):
        window = window.cuda(device)
    feat = torch.stft(
        signal,
        n_fft=num_point,
        hop_length=hop_length,
        win_length=window.shape[0],
        window=window,
    )
    real, imag = feat[..., 0], feat[..., 1]
    return real.permute(0, 2, 1).unsqueeze(1), imag.permute(0, 2, 1).unsqueeze(1)


def istft(
    real,
    imag,
    length,
    sample_rate=44100,
    frame_length=46,
    frame_shift=10,
    window_type="hanning",
    preemphasis=0.0,
    device=torch.device("cuda"),
    square_root_window=True,
):
    """Convert frames to a signal using overlap-and-add synthesis.
    Args:
        real, imag: real and imaginary STFT parts, [batchsize, 1, t, f]
        length: target waveform length in samples
    Return:
        wav: synthesized output waveform
    """
    real = real.permute(0, 3, 2, 1)
    imag = imag.permute(0, 3, 2, 1)
    spectrum = torch.cat([real, imag], dim=-1)

    hop_length = int(sample_rate * frame_shift / 1000)
    win_length = int(sample_rate * frame_length / 1000)

    # num_point = fft_point(win_length)
    num_point = win_length
    if "cuda" in str(device):
        window = get_window(num_point, window_type, square_root_window).cuda(device)
    else:
        window = get_window(num_point, window_type, square_root_window)

    wav = torch_istft(
        spectrum,
        num_point,
        hop_length=hop_length,
        win_length=window.shape[0],
        window=window,
    )
    return wav[..., :length]


def torch_istft(
    stft_matrix,  # type: Tensor
    n_fft,  # type: int
    hop_length=None,  # type: Optional[int]
    win_length=None,  # type: Optional[int]
    window=None,  # type: Optional[Tensor]
    center=True,  # type: bool
    pad_mode="reflect",  # type: str
    normalized=False,  # type: bool
    onesided=True,  # type: bool
    length=None,  # type: Optional[int]
):
    # type: (...) -> Tensor
    # NOTE: this relies on torch.irfft, which was removed in PyTorch 1.8,
    # so this function targets older PyTorch versions.

    stft_matrix_dim = stft_matrix.dim()
    assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim)

    if stft_matrix_dim == 3:
        # add a channel dimension
        stft_matrix = stft_matrix.unsqueeze(0)

    dtype = stft_matrix.dtype
    device = stft_matrix.device
    fft_size = stft_matrix.size(1)
    assert (onesided and n_fft // 2 + 1 == fft_size) or (
        not onesided and n_fft == fft_size
    ), (
        "one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
        + "Given values were onesided: %s, n_fft: %d, fft_size: %d"
        % ("True" if onesided else False, n_fft, fft_size)
    )

    # use stft defaults for Optionals
    if win_length is None:
        win_length = n_fft

    if hop_length is None:
        hop_length = int(win_length // 4)

    # There must be overlap
    assert 0 < hop_length <= win_length
    assert 0 < win_length <= n_fft

    if window is None:
        window = torch.ones(win_length, requires_grad=False, device=device, dtype=dtype)

    assert window.dim() == 1 and window.size(0) == win_length

    if win_length != n_fft:
        # center window with pad left and right zeros
        left = (n_fft - win_length) // 2
        window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
        assert window.size(0) == n_fft
    # win_length and n_fft are synonymous from here on

    stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frames, fft_size, 2)
    stft_matrix = torch.irfft(
        stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
    )  # size (channel, n_frames, n_fft)

    assert stft_matrix.size(2) == n_fft
    n_frames = stft_matrix.size(1)

    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (channel, n_frames, n_fft)
    # each column of a channel is a frame which needs to be overlap added at the right place
    ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frames)

    eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze(
        1
    )  # size (n_fft, 1, n_fft)

    # this does overlap add where the frames of ytmp are added such that the i'th frame of
    # ytmp is added starting at i*hop_length in the output
    y = torch.nn.functional.conv_transpose1d(
        ytmp, eye, stride=hop_length, padding=0
    )  # size (channel, 1, expected_signal_len)

    # do the same for the window function
    window_sq = (
        window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)
    )  # size (1, n_fft, n_frames)
    window_envelop = torch.nn.functional.conv_transpose1d(
        window_sq, eye, stride=hop_length, padding=0
    )  # size (1, 1, expected_signal_len)

    expected_signal_len = n_fft + hop_length * (n_frames - 1)
    assert y.size(2) == expected_signal_len
    assert window_envelop.size(2) == expected_signal_len

    half_n_fft = n_fft // 2
    # we need to trim the front padding away if center
    start = half_n_fft if center else 0
    end = -half_n_fft if length is None else start + length

    y = y[:, :, start:end]
    window_envelop = window_envelop[:, :, start:end]

    # check NOLA non-zero overlap condition
    window_envelop_lowest = window_envelop.abs().min()
    assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
        window_envelop_lowest
    )

    y = (y / window_envelop).squeeze(1)  # size (channel, expected_signal_len)

    if stft_matrix_dim == 3:  # remove the channel dimension
        y = y.squeeze(0)
    return y
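Not part of the upload: a minimal analysis/synthesis round-trip sketch with the functions above in scope. Since torch_istft relies on torch.irfft, this assumes PyTorch < 1.8.

import torch

sig = torch.randn(1, 44100)  # one second of audio at 44.1 kHz
real, imag = stft_single(sig, sample_rate=44100, device=torch.device("cpu"))
rec = istft(real, imag, length=44100, sample_rate=44100, device=torch.device("cpu"))
print(torch.max(torch.abs(rec - sig)))  # expected to be near zero (sqrt-Hann OLA)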
voicefixer/tools/io.py
ADDED
@@ -0,0 +1,44 @@
import json
import pickle


def read_list(fname):
    result = []
    with open(fname, "r") as f:
        for each in f.readlines():
            each = each.strip("\n")
            result.append(each)
    return result


def write_list(list, fname):
    with open(fname, "w") as f:
        for word in list:
            f.write(word)
            f.write("\n")


def write_json(my_dict, fname):
    # print("Save json file at "+fname)
    json_str = json.dumps(my_dict)
    with open(fname, "w") as json_file:
        json_file.write(json_str)


def load_json(fname):
    with open(fname, "r") as f:
        data = json.load(f)
    return data


def save_pickle(obj, fname):
    # print("Save pickle at "+fname)
    with open(fname, "wb") as f:
        pickle.dump(obj, f)


def load_pickle(fname):
    # print("Load pickle at "+fname)
    with open(fname, "rb") as f:
        res = pickle.load(f)
    return res
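Not part of the upload: a quick round-trip sketch for the helpers above (file names are hypothetical).

write_json({"sr": 44100}, "meta.json")
assert load_json("meta.json")["sr"] == 44100

save_pickle([1, 2, 3], "cache.pkl")
assert load_pickle("cache.pkl") == [1, 2, 3]

write_list(["a.wav", "b.wav"], "files.txt")
assert read_list("files.txt") == ["a.wav", "b.wav"]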
voicefixer/tools/mel_scale.py
ADDED
@@ -0,0 +1,238 @@
import torch
from torch import Tensor
from typing import Optional
import math

import warnings


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix. This uses triangular filter banks.

    User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    See also:
        :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
        generate the filter banks.
    """
    __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]

    def __init__(
        self,
        n_mels: int = 128,
        sample_rate: int = 16000,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_stft: int = 201,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
    ) -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        assert f_min <= self.f_max, "Require f_min: {} <= f_max: {}".format(
            f_min, self.f_max
        )
        fb = melscale_fbanks(
            n_stft,
            self.f_min,
            self.f_max,
            self.n_mels,
            self.sample_rate,
            self.norm,
            self.mel_scale,
        )
        self.register_buffer("fb", fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """

        # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
        mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(
            -1, -2
        )

        return mel_specgram


def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    r"""Convert Hz to Mels.

    Args:
        freq (float): Frequency in Hz
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        mels (float): Frequency in Mels
    """

    if mel_scale not in ["slaney", "htk"]:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3

    mels = (freq - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    if freq >= min_log_hz:
        mels = min_log_mel + math.log(freq / min_log_hz) / logstep

    return mels


def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
    """Convert mel bin numbers to frequencies.

    Args:
        mels (Tensor): Mel frequencies
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        freqs (Tensor): Mels converted in Hz
    """

    if mel_scale not in ["slaney", "htk"]:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels

    # And now the nonlinear scale
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    log_t = mels >= min_log_mel
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

    return freqs


def _create_triangular_filterbank(
    all_freqs: Tensor,
    f_pts: Tensor,
) -> Tensor:
    """Create a triangular filter bank.

    Args:
        all_freqs (Tensor): STFT freq points of size (`n_freqs`).
        f_pts (Tensor): Filter mid points of size (`n_filter`).

    Returns:
        fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
    """
    # Adapted from Librosa
    # calculate the difference between each filter mid point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_filter + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_filter + 2)
    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_filter)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_filter)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    return fb


def melscale_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_mels: int,
    sample_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Note:
        For the sake of numerical compatibility with librosa, not all the coefficients
        in the resulting filter bank have a magnitude of 1.

    .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
        :alt: Visualization of generated filter bank

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``),
        meaning number of frequencies to highlight/apply times the number of filterbanks.
        Each column is a filterbank so that, assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * melscale_fbanks(A.size(-1), ...)``.

    """

    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
    m_max = _hz_to_mel(f_max, mel_scale=mel_scale)

    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    if norm is not None and norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)

    if (fb.max(dim=0).values == 0.0).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb
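Not part of the upload: a small usage sketch of MelScale above; shapes follow the docstring ((..., freq, time) in, (..., n_mels, time) out).

import torch

mel_scale = MelScale(n_mels=128, sample_rate=44100, n_stft=1025)  # n_stft = n_fft // 2 + 1
spec = torch.rand(1, 1025, 100)   # magnitude spectrogram (batch, freq, time)
mel = mel_scale(spec)
print(mel.shape)                  # torch.Size([1, 128, 100])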
voicefixer/tools/modules/__init__.py
ADDED
@@ -0,0 +1,11 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    : __init__.py
@Contact : [email protected]
@License : (C)Copyright 2020-2100

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
9/14/21 12:29 AM   Haohe Liu      1.0         None
"""
voicefixer/tools/modules/fDomainHelper.py
ADDED
@@ -0,0 +1,234 @@
from torchlibrosa.stft import STFT, ISTFT, magphase
import torch
import torch.nn as nn
import numpy as np
from voicefixer.tools.modules.pqmf import PQMF


class FDomainHelper(nn.Module):
    def __init__(
        self,
        window_size=2048,
        hop_size=441,
        center=True,
        pad_mode="reflect",
        window="hann",
        freeze_parameters=True,
        subband=None,
        root="/Users/admin/Documents/projects/",
    ):
        super(FDomainHelper, self).__init__()
        self.subband = subband
        # assert torchlibrosa.__version__ == "0.0.7", "Error: Found torchlibrosa version %s. Please install the 0.0.7 version of torchlibrosa by: pip install torchlibrosa==0.0.7." % torchlibrosa.__version__
        if self.subband is None:
            self.stft = STFT(
                n_fft=window_size,
                hop_length=hop_size,
                win_length=window_size,
                window=window,
                center=center,
                pad_mode=pad_mode,
                freeze_parameters=freeze_parameters,
            )

            self.istft = ISTFT(
                n_fft=window_size,
                hop_length=hop_size,
                win_length=window_size,
                window=window,
                center=center,
                pad_mode=pad_mode,
                freeze_parameters=freeze_parameters,
            )
        else:
            self.stft = STFT(
                n_fft=window_size // self.subband,
                hop_length=hop_size // self.subband,
                win_length=window_size // self.subband,
                window=window,
                center=center,
                pad_mode=pad_mode,
                freeze_parameters=freeze_parameters,
            )

            self.istft = ISTFT(
                n_fft=window_size // self.subband,
                hop_length=hop_size // self.subband,
                win_length=window_size // self.subband,
                window=window,
                center=center,
                pad_mode=pad_mode,
                freeze_parameters=freeze_parameters,
            )

        if subband is not None and root is not None:
            self.qmf = PQMF(subband, 64, root)

    def complex_spectrogram(self, input, eps=0.0):
        # [batchsize, samples]
        # return [batchsize, 2, t-steps, f-bins]
        real, imag = self.stft(input)
        return torch.cat([real, imag], dim=1)

    def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
        # [batchsize, 2[real,imag], t-steps, f-bins]
        wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)
        return wav

    def spectrogram(self, input, eps=0.0):
        (real, imag) = self.stft(input.float())
        return torch.clamp(real**2 + imag**2, eps, np.inf) ** 0.5

    def spectrogram_phase(self, input, eps=0.0):
        (real, imag) = self.stft(input.float())
        mag = torch.clamp(real**2 + imag**2, eps, np.inf) ** 0.5
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(self, input, eps=1e-8):
        """Waveform to spectrogram.

        Args:
            input: (batch_size, channels_num, segment_samples)

        Outputs:
            output: (batch_size, channels_num, time_steps, freq_bins)
        """
        sp_list = []
        cos_list = []
        sin_list = []
        channels_num = input.shape[1]
        for channel in range(channels_num):
            mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
            sp_list.append(mag)
            cos_list.append(cos)
            sin_list.append(sin)

        sps = torch.cat(sp_list, dim=1)
        coss = torch.cat(cos_list, dim=1)
        sins = torch.cat(sin_list, dim=1)
        return sps, coss, sins

    def spectrogram_phase_to_wav(self, sps, coss, sins, length):
        channels_num = sps.size()[1]
        res = []
        for i in range(channels_num):
            res.append(
                self.istft(
                    sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
                    sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
                    length,
                )
            )
            res[-1] = res[-1].unsqueeze(1)
        return torch.cat(res, dim=1)

    def wav_to_spectrogram(self, input, eps=1e-8):
        """Waveform to spectrogram.

        Args:
            input: (batch_size, channels_num, segment_samples)

        Outputs:
            output: (batch_size, channels_num, time_steps, freq_bins)
        """
        sp_list = []
        channels_num = input.shape[1]
        for channel in range(channels_num):
            sp_list.append(self.spectrogram(input[:, channel, :], eps=eps))
        output = torch.cat(sp_list, dim=1)
        return output

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Spectrogram to waveform.
        Args:
            input: (batch_size, segment_samples, channels_num)
            spectrogram: (batch_size, channels_num, time_steps, freq_bins)

        Outputs:
            output: (batch_size, segment_samples, channels_num)
        """
        channels_num = input.shape[1]
        wav_list = []
        for channel in range(channels_num):
            (real, imag) = self.stft(input[:, channel, :])
            (_, cos, sin) = magphase(real, imag)
            wav_list.append(
                self.istft(
                    spectrogram[:, channel : channel + 1, :, :] * cos,
                    spectrogram[:, channel : channel + 1, :, :] * sin,
                    length,
                )
            )

        output = torch.stack(wav_list, dim=1)
        return output

    # todo the following code is not bug free!
    def wav_to_complex_spectrogram(self, input, eps=0.0):
        # [batchsize, channels, samples]
        # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
        res = []
        channels_num = input.shape[1]
        for channel in range(channels_num):
            res.append(self.complex_spectrogram(input[:, channel, :], eps=eps))
        return torch.cat(res, dim=1)

    def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
        # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
        # return [batchsize, channels, samples]
        channels = input.size()[1] // 2
        wavs = []
        for i in range(channels):
            wavs.append(
                self.reverse_complex_spectrogram(
                    input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
                )
            )
            wavs[-1] = wavs[-1].unsqueeze(1)
        return torch.cat(wavs, dim=1)

    def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
        # [batchsize, channels, samples]
        # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
        subwav = self.qmf.analysis(input)  # [batchsize, subband*channels, samples]
        subspec = self.wav_to_complex_spectrogram(subwav)
        return subspec

    def complex_subband_spectrogram_to_wav(self, input, eps=0.0):
        # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
        # [batchsize, channels, samples]
        subwav = self.complex_spectrogram_to_wav(input)
        data = self.qmf.synthesis(subwav)
        return data

    def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
        """
        :param input:
        :param eps:
        :return:
            loss = torch.nn.L1Loss()
            models = FDomainHelper(subband=4)
            data = torch.randn((3, 1, 44100*3))

            sps, coss, sins = models.wav_to_mag_phase_subband_spectrogram(data)
            wav = models.mag_phase_subband_spectrogram_to_wav(sps, coss, sins, 44100*3//4)

            print(loss(data, wav))
            print(torch.max(torch.abs(data - wav)))
        """
        # [batchsize, channels, samples]
        # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
        subwav = self.qmf.analysis(input)  # [batchsize, subband*channels, samples]
        sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps)
        return sps, coss, sins

    def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
        # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
        # [batchsize, channels, samples]
        subwav = self.spectrogram_phase_to_wav(
            sps, coss, sins, length + self.qmf.pad_samples // self.qmf.N
        )
        data = self.qmf.synthesis(subwav)
        return data
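Not part of the upload: the subband path needs the PQMF filter files below, but the fullband magnitude/phase round trip can be sketched on its own.

import torch

fh = FDomainHelper()                    # fullband: subband=None, so PQMF is never loaded
wav = torch.randn(2, 1, 44100)          # (batch, channels, samples)
sps, coss, sins = fh.wav_to_spectrogram_phase(wav)
rec = fh.spectrogram_phase_to_wav(sps, coss, sins, length=44100)
print(torch.max(torch.abs(rec - wav)))  # near-perfect reconstruction is expected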
voicefixer/tools/modules/filters/f_2_64.mat
ADDED
File without changes
voicefixer/tools/modules/filters/f_4_64.mat
ADDED
File without changes
voicefixer/tools/modules/filters/f_8_64.mat
ADDED
File without changes
voicefixer/tools/modules/filters/h_2_64.mat
ADDED
File without changes
voicefixer/tools/modules/filters/h_4_64.mat
ADDED
File without changes
voicefixer/tools/modules/filters/h_8_64.mat
ADDED
File without changes
voicefixer/tools/modules/pqmf.py
ADDED
@@ -0,0 +1,116 @@
"""
@File    : subband_util.py
@Contact : [email protected]
@License : (C)Copyright 2020-2021
@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
2020/4/3 4:54 PM   Haohe Liu      1.0         None
"""

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import os.path as op
from scipy.io import loadmat


def load_mat2numpy(fname=""):
    if len(fname) == 0:
        return None
    else:
        return loadmat(fname)


class PQMF(nn.Module):
    def __init__(self, N, M, project_root):
        super().__init__()
        self.N = N  # nsubband
        self.M = M  # nfilter
        try:
            assert (N, M) in [(8, 64), (4, 64), (2, 64)]
        except AssertionError:
            print("Warning:", N, "subband and", M, "filter is not supported")
        self.pad_samples = 64
        self.name = str(N) + "_" + str(M) + ".mat"
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )
        data = load_mat2numpy(
            op.join(
                project_root,
                "arnold_workspace/restorer/tools/pytorch/modules/filters/f_"
                + self.name,
            )
        )
        data = data["f"].astype(np.float32) / N
        data = np.flipud(data.T).T
        data = np.reshape(data, (N, 1, M)).copy()
        dict_new = self.ana_conv_filter.state_dict().copy()
        dict_new["weight"] = torch.from_numpy(data)
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
        self.ana_conv_filter.load_state_dict(dict_new)

        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        gk = load_mat2numpy(
            op.join(
                project_root,
                "arnold_workspace/restorer/tools/pytorch/modules/filters/h_"
                + self.name,
            )
        )
        gk = gk["h"].astype(np.float32)
        gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
        gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
        dict_new = self.syn_conv_filter.state_dict().copy()
        dict_new["weight"] = torch.from_numpy(gk)
        self.syn_conv_filter.load_state_dict(dict_new)

        for param in self.parameters():
            param.requires_grad = False

    def __analysis_channel(self, inputs):
        return self.ana_conv_filter(self.ana_pad(inputs))

    def __systhesis_channel(self, inputs):
        ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
        return torch.reshape(ret, (ret.shape[0], 1, -1))

    def analysis(self, inputs):
        """
        :param inputs: [batchsize, channel, raw_wav], value: [0, 1]
        :return:
        """
        inputs = F.pad(inputs, (0, self.pad_samples))
        ret = None
        for i in range(inputs.size()[1]):  # channels
            if ret is None:
                ret = self.__analysis_channel(inputs[:, i : i + 1, :])
            else:
                ret = torch.cat(
                    (ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1
                )
        return ret

    def synthesis(self, data):
        """
        :param data: [batchsize, self.N*K, raw_wav_sub], value: [0, 1]
        :return:
        """
        ret = None
        # data = F.pad(data, (0, self.pad_samples // self.N))
        for i in range(data.size()[1]):  # channels
            if i % self.N == 0:
                if ret is None:
                    ret = self.__systhesis_channel(data[:, i : i + self.N, :])
                else:
                    new = self.__systhesis_channel(data[:, i : i + self.N, :])
                    ret = torch.cat((ret, new), dim=1)
        ret = ret[..., : -self.pad_samples]
        return ret

    def forward(self, inputs):
        return self.ana_conv_filter(self.ana_pad(inputs))
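Not part of the upload: a hedged usage sketch for PQMF. The constructor resolves the .mat filters relative to project_root, so the (hypothetical) root below must contain the arnold_workspace/restorer/tools/pytorch/modules/filters directory with the files shipped above.

import torch

pqmf = PQMF(N=4, M=64, project_root="/path/to/project")  # hypothetical root
wav = torch.randn(1, 1, 44100)
sub = pqmf.analysis(wav)    # (1, 4, (44100 + 64) // 4) subband signals
rec = pqmf.synthesis(sub)   # (1, 1, 44100); near-perfect reconstruction
print(torch.max(torch.abs(rec - wav)))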
voicefixer/tools/path.py
ADDED
@@ -0,0 +1,13 @@
import os


def find_and_build(root, path):
    path = os.path.join(root, path)
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path


def root_path(repo_name="voicefixer"):
    path = os.path.abspath(__file__)
    return path.split(repo_name)[0]
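Not part of the upload: usage is straightforward (paths hypothetical).

import os

cache_dir = find_and_build(os.path.expanduser("~"), ".cache/voicefixer")  # created if missing
print(root_path())  # everything on disk before the first "voicefixer" path component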
voicefixer/tools/pytorch_util.py
ADDED
@@ -0,0 +1,180 @@
import torch
import torch.nn as nn
import numpy as np


def check_cuda_availability(cuda):
    if cuda and not torch.cuda.is_available():
        raise RuntimeError("Error: You set cuda=True but no cuda device found.")


def try_tensor_cuda(tensor, cuda):
    if cuda and torch.cuda.is_available():
        return tensor.cuda()
    else:
        return tensor.cpu()


def to_log(input):
    assert torch.sum(input < 0) == 0, (
        str(input) + " has negative values counts " + str(torch.sum(input < 0))
    )
    return torch.log10(torch.clip(input, min=1e-8))


def from_log(input):
    input = torch.clip(input, min=-np.inf, max=5)
    return 10**input


def move_data_to_device(x, device):
    if "float" in str(x.dtype):
        x = torch.Tensor(x)
    elif "int" in str(x.dtype):
        x = torch.LongTensor(x)
    else:
        return x
    return x.to(device)


def tensor2numpy(tensor):
    if "cuda" in str(tensor.device):
        return tensor.detach().cpu().numpy()
    else:
        return tensor.detach().numpy()


def count_parameters(model):
    for p in model.parameters():
        if p.requires_grad:
            print(p.shape)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_flops(model, audio_length):
    multiply_adds = False
    list_conv2d = []

    def conv2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = (
            self.kernel_size[0]
            * self.kernel_size[1]
            * (self.in_channels / self.groups)
            * (2 if multiply_adds else 1)
        )
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width

        list_conv2d.append(flops)

    list_conv1d = []

    def conv1d_hook(self, input, output):
        batch_size, input_channels, input_length = input[0].size()
        output_channels, output_length = output[0].size()

        kernel_ops = (
            self.kernel_size[0]
            * (self.in_channels / self.groups)
            * (2 if multiply_adds else 1)
        )
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_length

        list_conv1d.append(flops)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1

        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement()

        flops = batch_size * (weight_ops + bias_ops)
        list_linear.append(flops)

    list_bn = []

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())

    list_relu = []

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling2d = []

    def pooling2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size * self.kernel_size
        bias_ops = 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width

        list_pooling2d.append(flops)

    list_pooling1d = []

    def pooling1d_hook(self, input, output):
        batch_size, input_channels, input_length = input[0].size()
        output_channels, output_length = output[0].size()

        kernel_ops = self.kernel_size
        bias_ops = 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_length

        list_pooling2d.append(flops)

    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, nn.Conv2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.ConvTranspose2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.Conv1d):
                net.register_forward_hook(conv1d_hook)
            elif isinstance(net, nn.Linear):
                net.register_forward_hook(linear_hook)
            elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
                net.register_forward_hook(bn_hook)
            elif isinstance(net, nn.ReLU):
                net.register_forward_hook(relu_hook)
            elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
                net.register_forward_hook(pooling2d_hook)
            elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
                net.register_forward_hook(pooling1d_hook)
            else:
                print("Warning: flop of module {} is not counted!".format(net))
            return
        for c in childrens:
            foo(c)

    foo(model)

    input = torch.rand(1, audio_length, 2)
    out = model(input)

    total_flops = (
        sum(list_conv2d)
        + sum(list_conv1d)
        + sum(list_linear)
        + sum(list_bn)
        + sum(list_relu)
        + sum(list_pooling2d)
        + sum(list_pooling1d)
    )

    return total_flops
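Not part of the upload: a small sanity sketch for the log helpers above.

import torch

x = torch.rand(4) + 0.1                        # strictly positive input
assert torch.allclose(from_log(to_log(x)), x)  # inverse pair within the clipping bounds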
voicefixer/tools/random_.py
ADDED
@@ -0,0 +1,52 @@
import random
import torch

RANDOM_RESOLUTION = 2**31


def random_torch(high, to_int=True):
    if to_int:
        return int((torch.rand(1)) * high)  # do not use numpy.random.random
    else:
        return (torch.rand(1)) * high  # do not use numpy.random.random


def shuffle_torch(list):
    length = len(list)
    res = []
    order = torch.randperm(length)
    for each in order:
        res.append(list[each])
    assert len(list) == len(res)
    return res


def random_choose_list(list):
    num = int(uniform_torch(0, len(list)))
    return list[num]


def normal_torch(mean=0, segma=1):
    return float(torch.normal(mean=mean, std=torch.Tensor([segma]))[0])


def uniform_torch(lower, upper):
    if abs(lower - upper) < 1e-5:
        return upper
    return (upper - lower) * torch.rand(1) + lower


def random_key(keys: list, weights: list):
    return random.choices(keys, weights=weights)[0]


def random_select(probs):
    res = []
    chance = random_torch(RANDOM_RESOLUTION)
    threshold = None
    for prob in probs:
        # if(threshold is None): threshold = prob
        # else: threshold *= prob
        threshold = prob
        res.append(chance < threshold * RANDOM_RESOLUTION)
    return res, chance
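Not part of the upload: a short sketch of how these helpers are typically combined.

keys = ["clean", "noise", "reverb"]
print(random_key(keys, weights=[0.6, 0.3, 0.1]))  # weighted draw
print(float(uniform_torch(0.5, 1.5)))             # scalar in [0.5, 1.5]
print(random_choose_list(keys))                   # uniform draw from a list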
voicefixer/tools/wav.py
ADDED
@@ -0,0 +1,242 @@
import wave
import os
import numpy as np
import scipy.signal as signal
import soundfile as sf
import librosa


def save_wave(frames: np.ndarray, fname, sample_rate=44100):
    shape = list(frames.shape)
    if len(shape) == 1:
        frames = frames[..., None]
        shape = list(frames.shape)  # refresh after adding the channel axis
    in_samples, in_channels = shape[-2], shape[-1]
    if in_channels >= 3:
        if len(shape) == 2:
            frames = np.transpose(frames, (1, 0))
        elif len(shape) == 3:
            frames = np.transpose(frames, (0, 2, 1))
        msg = (
            "Warning: Save audio with "
            + str(in_channels)
            + " channels, save permuted audio with shape "
            + str(list(frames.shape))
            + ", please check if it's correct."
        )
        # print(msg)
    # scale normalized float audio into the int16 range
    if np.max(frames) <= 1 and frames.dtype in (np.float32, np.float16, np.float64):
        frames *= 2**15
    frames = frames.astype(np.short)
    if len(frames.shape) >= 3:
        frames = frames[0, ...]
    sf.write(fname, frames, samplerate=sample_rate)


def constrain_length(chunk, length):
    frames_length = chunk.shape[0]
    if frames_length == length:
        return chunk
    elif frames_length < length:
        return np.pad(chunk, ((0, int(length - frames_length)), (0, 0)), "constant")
    else:
        return chunk[:length, ...]


def random_chunk_wav_file(fname, chunk_length):
    """
    fname: path to wav file
    chunk_length: chunk length in seconds
    """
    with wave.open(fname) as f:
        params = f.getparams()
        duration = params[3] / params[2]
        sample_rate = params[2]
        sample_length = params[3]
    if duration < chunk_length or abs(duration - chunk_length) < 1e-4:
        frames = read_wave(fname, sample_rate)
        return frames, duration, sample_rate  # [-1,1]
    else:
        # Random chunk
        random_starts = np.random.randint(
            0, sample_length - sample_rate * chunk_length
        )
        random_end = random_starts + sample_rate * chunk_length
        random_starts, random_end = (
            random_starts / sample_rate,
            random_end / sample_rate,
        )
        random_starts, random_end = random_starts / duration, random_end / duration
        frames = read_wave(
            fname, sample_rate, portion_start=random_starts, portion_end=random_end
        )
        frames = constrain_length(frames, length=int(chunk_length * sample_rate))
        return frames, chunk_length, sample_rate


def random_chunk_wav_file_v2(fname, chunk_length, random_starts=None, random_end=None):
    """
    fname: path to wav file
    chunk_length: chunk length in seconds
    """
    with wave.open(fname) as f:
        params = f.getparams()
        duration = params[3] / params[2]
        sample_rate = params[2]
        sample_length = params[3]
    if duration < chunk_length or abs(duration - chunk_length) < 1e-4:
        frames = read_wave(fname, sample_rate)
        return frames, duration, sample_rate  # [-1,1]
    else:
        # Random chunk
        if random_starts is None and random_end is None:
            random_starts = np.random.randint(
                0, sample_length - sample_rate * chunk_length
            )
            random_end = random_starts + sample_rate * chunk_length
            random_starts, random_end = (
                random_starts / sample_rate,
                random_end / sample_rate,
            )
            random_starts, random_end = (
                random_starts / duration,
                random_end / duration,
            )
        frames = read_wave(
            fname, sample_rate, portion_start=random_starts, portion_end=random_end
        )
        frames = constrain_length(frames, length=int(chunk_length * sample_rate))
        return frames, chunk_length, sample_rate, random_starts, random_end


def read_wave(
    fname,
    sample_rate,
    portion_start=0,
    portion_end=1,
):
    """
    :param fname: wav file path
    :param sample_rate:
    :param portion_start:
    :param portion_end:
    :return: [samples, channels]
    """
    # sr = get_sample_rate(fname)
    # if(sr != sample_rate):
    #     print("Warning: Sample rate not match, may lead to unexpected behavior.")
    if portion_end > 1 and portion_end < 1.1:
        portion_end = 1
    if portion_end != 1:
        duration = get_duration(fname)
        wav, _ = librosa.load(
            fname,
            sr=sample_rate,
            offset=portion_start * duration,
            duration=(portion_end - portion_start) * duration,
            mono=False,
        )
    else:
        wav, _ = librosa.load(fname, sr=sample_rate, mono=False)
    if len(list(wav.shape)) == 1:
        wav = wav[..., None]
    else:
        wav = np.transpose(wav, (1, 0))
    return wav


def get_channels_sampwidth_and_sample_rate(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return (
        params[0],
        params[1],
        params[2],
    )  # e.g. == (2, 2, 44100)


def get_channels(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return params[0]


def get_sample_rate(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return params[2]


def get_duration(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return params[3] / params[2]


def get_framesLength(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return params[3]


def restore_wave(zxx):
    _, w = signal.istft(zxx)
    return w


def calculate_total_times(dir):
    total = 0
    for each in os.listdir(dir):
        fname = os.path.join(dir, each)
        try:
            duration = get_duration(fname)
        except Exception:
            print(fname)
            continue  # skip unreadable files instead of reusing a stale duration
        total += duration
    return total


def filter(pth):
    global dic
    temp = []
    for each in os.listdir(pth):
        temp.append(os.path.join(pth, each))
    for each in temp:
        sr = get_sample_rate(each)
        if sr not in dic.keys():
            dic[sr] = []
        dic[sr].append(each)
    for each in dic[16000]:
        # print(each)
        pass
    print(dic.keys())
    for each in list(dic.keys()):
        print(each, len(dic[each]))


if __name__ == "__main__":
    path = "/Users/admin/Desktop/p376_025.wav"
    stereo = "/Users/admin/Desktop/vocals.wav"
    path_16 = "/Users/admin/Desktop/SI869.WAV.wav"
    import time

    start = time.time()
    for i in range(1000):
        frames, duration, sample_rate = random_chunk_wav_file(stereo, chunk_length=3.0)
        print(frames.shape, np.max(frames))
        save_wave(frames, "stero.wav", sample_rate=44100)
        frames, duration, sample_rate = random_chunk_wav_file(path, chunk_length=3.0)
        print(frames.shape, np.max(frames))
        save_wave(frames, "mono.wav", sample_rate=44100)
        frames, duration, sample_rate = random_chunk_wav_file(path_16, chunk_length=3.0)
        print(frames.shape, np.max(frames))
        save_wave(frames, "16.wav", sample_rate=16000)
    print(time.time() - start)
    # frames = read_wave(stereo, sample_rate=44100)
    print(frames.shape)

    print(frames)
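Not part of the upload: a round-trip sketch for the I/O helpers above (file name hypothetical).

import numpy as np

wav = np.sin(2 * np.pi * 440 * np.arange(44100) / 44100).astype(np.float32)[:, None]
save_wave(wav, "tone.wav", sample_rate=44100)      # float in [-1, 1] is scaled to int16
frames = read_wave("tone.wav", sample_rate=44100)  # back as (samples, channels) float
print(frames.shape, get_duration("tone.wav"))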
voicefixer/vocoder/__init__.py
ADDED
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    : __init__.py.py
@Contact : [email protected]
@License : (C)Copyright 2020-2100

@Modify Time        @Author     @Version    @Description
------------        -------     --------    ------------
9/14/21 1:00 AM     Haohe Liu   1.0         None
"""

import os
from voicefixer.vocoder.config import Config
import urllib.request

if not os.path.exists(Config.ckpt):
    os.makedirs(os.path.dirname(Config.ckpt), exist_ok=True)
    print("Downloading the weight of neural vocoder: TFGAN")
    urllib.request.urlretrieve(
        "https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1",
        Config.ckpt,
    )
    print(
        "Weights downloaded in: {} Size: {}".format(
            Config.ckpt, os.path.getsize(Config.ckpt)
        )
    )
    # cmd = "wget https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1 -O " + Config.ckpt
    # os.system(cmd)
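Note that this download runs as a side effect of importing the package. A small sketch (nothing here beyond what the module itself does) to confirm the checkpoint landed where Config expects it:

import os
from voicefixer.vocoder.config import Config
import voicefixer.vocoder  # triggers the download only if Config.ckpt is missing

assert os.path.exists(Config.ckpt)
print("vocoder checkpoint:", Config.ckpt, os.path.getsize(Config.ckpt), "bytes")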
voicefixer/vocoder/base.py
ADDED
@@ -0,0 +1,86 @@
from voicefixer.vocoder.model.generator import Generator
from voicefixer.tools.wav import read_wave, save_wave
from voicefixer.tools.pytorch_util import *
from voicefixer.vocoder.model.util import *
from voicefixer.vocoder.config import Config
import os
import librosa
import torch
import torch.nn as nn
import numpy as np


class Vocoder(nn.Module):
    def __init__(self, sample_rate):
        super(Vocoder, self).__init__()
        Config.refresh(sample_rate)
        self.rate = sample_rate
        if not os.path.exists(Config.ckpt):
            raise RuntimeError(
                "Error 1: The checkpoint for the synthesis module / vocoder "
                "(model.ckpt-1490000_trimed) was not found in "
                "~/.cache/voicefixer/synthesis_module/44100. "
                "By default the checkpoint should be downloaded automatically by this program. "
                "Something bad may have happened. Apologies for the inconvenience. "
                "But don't worry! Alternatively you can download it directly from Zenodo: "
                "https://zenodo.org/record/5600188/files/model.ckpt-1490000_trimed.pt?download=1"
            )
        self._load_pretrain(Config.ckpt)
        self.weight_torch = Config.get_mel_weight_torch(percent=1.0)[
            None, None, None, ...
        ]

    def _load_pretrain(self, pth):
        self.model = Generator(Config.cin_channels)
        checkpoint = load_checkpoint(pth, torch.device("cpu"))
        load_try(checkpoint["generator"], self.model)
        self.model.eval()
        self.model.remove_weight_norm()
        for p in self.model.parameters():
            p.requires_grad = False

    # def vocoder_mel_npy(self, mel, save_dir, sample_rate, gain):
    #     mel = mel / Config.get_mel_weight(percent=gain)[..., None]
    #     mel = normalize(amp_to_db(np.abs(mel)) - 20)
    #     mel = pre(np.transpose(mel, (1, 0)))
    #     with torch.no_grad():
    #         wav_re = self.model(mel)  # torch.Size([1, 1, 104076])
    #     save_wave(tensor2numpy(wav_re) * 2**15, save_dir, sample_rate=sample_rate)

    def forward(self, mel, cuda=False):
        """
        :param mel: non-normalized mel spectrogram, [batchsize, 1, t-steps, n_mel]
        :return: [batchsize, 1, samples]
        """
        assert mel.size()[-1] == 128
        check_cuda_availability(cuda=cuda)
        self.model = try_tensor_cuda(self.model, cuda=cuda)
        mel = try_tensor_cuda(mel, cuda=cuda)
        self.weight_torch = self.weight_torch.type_as(mel)
        mel = mel / self.weight_torch
        mel = tr_normalize(tr_amp_to_db(torch.abs(mel)) - 20.0)
        mel = tr_pre(mel[:, 0, ...])
        wav_re = self.model(mel)
        return wav_re

    def oracle(self, fpath, out_path, cuda=False):
        check_cuda_availability(cuda=cuda)
        self.model = try_tensor_cuda(self.model, cuda=cuda)
        wav = read_wave(fpath, sample_rate=self.rate)[..., 0]
        wav = wav / np.max(np.abs(wav))
        stft = np.abs(
            librosa.stft(
                wav,
                hop_length=Config.hop_length,
                win_length=Config.win_size,
                n_fft=Config.n_fft,
            )
        )
        mel = linear_to_mel(stft)
        mel = normalize(amp_to_db(np.abs(mel)) - 20)
        mel = pre(np.transpose(mel, (1, 0)))
        mel = try_tensor_cuda(mel, cuda=cuda)
        with torch.no_grad():
            wav_re = self.model(mel)
        save_wave(tensor2numpy(wav_re * 2**15), out_path, sample_rate=self.rate)


if __name__ == "__main__":
    model = Vocoder(sample_rate=44100)
    print(next(model.parameters()).device)  # nn.Module itself has no .device attribute
    # model.load_pretrain(Config.ckpt)
    # model.oracle(path="/Users/liuhaohe/Desktop/test.wav",
    #              sample_rate=44100,
    #              save_dir="/Users/liuhaohe/Desktop/test_vocoder.wav")
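A minimal sketch of driving the vocoder directly, assuming the checkpoint is already cached. The input shape follows the forward() docstring; the output length is inferred from the generator's upsample scales (7 * 7 * 3 * 3 = 441 samples per mel frame), so 100 frames should come out as roughly one second of 44.1 kHz audio:

import torch
from voicefixer.vocoder.base import Vocoder

vocoder = Vocoder(sample_rate=44100)
mel = torch.rand(1, 1, 100, 128)  # [batch, 1, t-steps, n_mel], non-normalized
with torch.no_grad():
    wav = vocoder(mel, cuda=False)  # -> about [1, 1, 100 * 441] samples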
voicefixer/vocoder/config.py
ADDED
@@ -0,0 +1,316 @@
import torch
import numpy as np
import os
from voicefixer.tools.path import root_path


class Config:
    @classmethod
    def refresh(cls, sr):
        if sr == 44100:
            Config.ckpt = os.path.join(
                os.path.expanduser("~"),
                ".cache/voicefixer/synthesis_module/44100/model.ckpt-1490000_trimed.pt",
            )
            Config.cond_channels = 512
            Config.m_channels = 768
            Config.resstack_depth = [8, 8, 8, 8]
            Config.channels = 1024
            Config.cin_channels = 128
            Config.upsample_scales = [7, 7, 3, 3]
            Config.num_mels = 128
            Config.n_fft = 2048
            Config.hop_length = 441
            Config.sample_rate = 44100
            Config.fmax = 22000
            Config.mel_win = 128
            Config.local_condition_dim = 128
        else:
            raise RuntimeError(
                "Error: The vocoder currently only supports a 44100 Hz sample rate."
            )

    ckpt = os.path.join(
        os.path.expanduser("~"),
        ".cache/voicefixer/synthesis_module/44100/model.ckpt-1490000_trimed.pt",
    )
    m_channels = 384
    bits = 10
    opt = "Ralamb"
    cond_channels = 256
    clip = 0.5
    num_bands = 1
    cin_channels = 128
    upsample_scales = [7, 7, 3, 3]
    filterbands = "test/filterbanks_4bands.dat"
    # For inference
    tag = ""
    min_db = -115
    num_mels = 128
    n_fft = 2048
    hop_length = 441
    win_size = None
    sample_rate = 44100
    frame_shift_ms = None

    trim_fft_size = 512
    trim_hop_size = 128
    trim_top_db = 23

    signal_normalization = True
    allow_clipping_in_normalization = True
    symmetric_mels = True
    max_abs_value = 4.0

    preemphasis = 0.85
    min_level_db = -100
    ref_level_db = 20
    fmin = 50
    fmax = 22000
    power = 1.5
    griffin_lim_iters = 60
    rescale = False
    rescaling_max = 0.95
    trim_silence = False
    clip_mels_length = True
    max_mel_frames = 2000

    mel_win = 128
    batch_size = 24
    g_learning_rate = 0.001
    d_learning_rate = 0.001
    warmup_steps = 100000
    decay_learning_rate = 0.5
    exponential_moving_average = True
    ema_decay = 0.99

    reset_opt = False
    reset_g_opt = False
    reset_d_opt = False

    local_condition_dim = 128
    lambda_update_G = 1
    multiscale_D = 3

    lambda_adv = 4.0
    lambda_fm_loss = 0.0
    lambda_sc_loss = 5.0
    lambda_mag_loss = 5.0
    lambda_mel_loss = 50.0
    use_mle_loss = False
    lambda_mle_loss = 5.0

    lambda_freq_loss = 2.0
    lambda_energy_loss = 100.0
    lambda_t_loss = 200.0
    lambda_phase_loss = 100.0
    lambda_f0_loss = 1.0
    use_elu = False
    de_preem = False  # train
    up_org = False
    use_one = True
    use_small_D = False
    use_condnet = True
    use_depreem = False  # inference
    use_msd = False
    model_type = "tfgan"  # or bytewave, a frame-level vocoder using iSTFT
    use_hjcud = False
    no_skip = False
    out_channels = 1
    use_postnet = False  # wn in postnet
    use_wn = False  # wn in resstack
    up_type = "transpose"
    use_smooth = False
    use_drop = False
    use_shift_scale = False
    use_gcnn = False
    resstack_depth = [6, 6, 6, 6]
    kernel_size = [3, 3, 3, 3]
    channels = 512
    use_f0_loss = False
    use_sine = False
    use_cond_rnn = False
    use_rnn = False

    f0_step = 120
    use_lowfreq_loss = False
    lambda_lowfreq_loss = 1.0
    use_film = False
    use_mb_mr_gan = False

    use_mssl = False
    use_ml_gan = False
    use_mb_gan = True
    use_mpd = False
    use_spec_gan = True
    use_rwd = False
    use_mr_gan = True
    use_pqmf_rwd = False
    no_sine = False
    use_frame_mask = False

    lambda_var_loss = 0.0
    discriminator_train_start_steps = 40000  # 80k
    aux_d_train_start_steps = 40000  # 100k
    rescale_out = 0.40
    use_dist = True
    dist_backend = "nccl"
    dist_url = "tcp://localhost:12345"
    world_size = 1

    # 128 entries, one weight per mel bin
    mel_weight_torch = torch.tensor(
        [
            19.40951426, 19.94047336, 20.4859038, 21.04629067, 21.62194148, 22.21335214,
            22.8210215, 23.44529231, 24.08660962, 24.74541882, 25.42234287, 26.11770576,
            26.83212784, 27.56615283, 28.32007747, 29.0947679, 29.89060111, 30.70832636,
            31.54828121, 32.41121487, 33.29780773, 34.20865341, 35.14437675, 36.1056621,
            37.09332763, 38.10795802, 39.15039691, 40.22119881, 41.32154931, 42.45172373,
            43.61293329, 44.80609379, 46.031602, 47.29070223, 48.58427549, 49.91327905,
            51.27863232, 52.68119708, 54.1222372, 55.60274206, 57.12364703, 58.68617876,
            60.29148652, 61.94081306, 63.63501986, 65.37562658, 67.16408954, 69.00109084,
            70.88850318, 72.82736101, 74.81985537, 76.86654792, 78.96885475, 81.12900906,
            83.34840929, 85.62810662, 87.97005418, 90.37689804, 92.84887686, 95.38872881,
            97.99777002, 100.67862715, 103.43232942, 106.26140638, 109.16827015, 112.15470471,
            115.22184756, 118.37439245, 121.6122689, 124.93877158, 128.35661454, 131.86761321,
            135.47417938, 139.18059494, 142.98713744, 146.89771854, 150.91684347, 155.0446638,
            159.28614648, 163.64270198, 168.12035831, 172.71749158, 177.44220154, 182.29556933,
            187.28286676, 192.40502126, 197.6682721, 203.07516896, 208.63088733, 214.33770931,
            220.19910108, 226.22363072, 232.41087124, 238.76803591, 245.30079083, 252.01064464,
            258.90261676, 265.98474, 273.26010248, 280.73496362, 288.41440094, 296.30489752,
            304.41180337, 312.7377183, 321.28877878, 330.07870237, 339.10812951, 348.38276173,
            357.91393924, 367.70513992, 377.76413924, 388.09467408, 398.70920178, 409.61813793,
            420.81980127, 432.33215467, 444.16083117, 456.30919947, 468.78589276, 481.61325588,
            494.78824596, 508.31969844, 522.2238331, 536.51163441, 551.18859414, 566.26142988,
            581.75006061, 597.66210737,
        ]
    )

    x_orig = np.linspace(1, mel_weight_torch.shape[0], num=mel_weight_torch.shape[0])

    x_orig_torch = torch.linspace(
        1, mel_weight_torch.shape[0], steps=mel_weight_torch.shape[0]
    )

    @classmethod
    def get_mel_weight(cls, percent=1, a=18.8927416350036, b=0.0269863588184314):
        b = percent * b

        def func(a, b, x):
            return a * np.exp(b * x)

        return func(a, b, Config.x_orig)

    @classmethod
    def get_mel_weight_torch(cls, percent=1, a=18.8927416350036, b=0.0269863588184314):
        b = percent * b

        def func(a, b, x):
            return a * torch.exp(b * x)

        return func(a, b, Config.x_orig_torch)
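The mel_weight table above is just a sampled exponential; get_mel_weight regenerates it from the fitted constants a and b, which a quick check confirms:

import numpy as np
from voicefixer.vocoder.config import Config

w = Config.get_mel_weight(percent=1.0)
print(w.shape)  # (128,) -- one weight per mel bin
print(w[0])     # ~19.4095, i.e. 18.8927416350036 * exp(0.0269863588184314 * 1)
print(float(Config.mel_weight_torch[0]))  # 19.40951426, the tabulated value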
voicefixer/vocoder/model/__init__.py
ADDED
@@ -0,0 +1,11 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    : __init__.py.py
@Contact : [email protected]
@License : (C)Copyright 2020-2100

@Modify Time        @Author     @Version    @Description
------------        -------     --------    ------------
9/14/21 1:00 AM     Haohe Liu   1.0         None
"""
voicefixer/vocoder/model/generator.py
ADDED
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
import numpy as np
from voicefixer.vocoder.model.modules import UpsampleNet, ResStack
from voicefixer.vocoder.config import Config
from voicefixer.vocoder.model.pqmf import PQMF
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


class Generator(nn.Module):
    def __init__(
        self,
        in_channels=128,
        use_elu=False,
        use_gcnn=False,
        up_org=False,
        group=1,
        hp=None,
    ):
        super(Generator, self).__init__()
        self.hp = hp
        channels = Config.channels
        self.upsample_scales = Config.upsample_scales
        self.use_condnet = Config.use_condnet
        self.out_channels = Config.out_channels
        self.resstack_depth = Config.resstack_depth
        self.use_postnet = Config.use_postnet
        self.use_cond_rnn = Config.use_cond_rnn
        if self.use_condnet:
            cond_channels = Config.cond_channels
            self.condnet = nn.Sequential(
                nn.utils.weight_norm(
                    nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
                nn.utils.weight_norm(
                    nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
                nn.utils.weight_norm(
                    nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
                nn.utils.weight_norm(
                    nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
                nn.utils.weight_norm(
                    nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
            )
            in_channels = cond_channels
            if self.use_cond_rnn:
                self.rnn = nn.GRU(
                    cond_channels,
                    cond_channels // 2,
                    num_layers=1,
                    batch_first=True,
                    bidirectional=True,
                )

        if use_elu:
            act = nn.ELU()
        else:
            act = nn.LeakyReLU(0.2, True)

        kernel_size = Config.kernel_size

        if self.out_channels == 1:
            self.generator = nn.Sequential(
                nn.ReflectionPad1d(3),
                nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)),
                act,
                UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp, 0),
                ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp),
                act,
                UpsampleNet(
                    channels // 2, channels // 4, self.upsample_scales[1], hp, 1
                ),
                ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp),
                act,
                UpsampleNet(
                    channels // 4, channels // 8, self.upsample_scales[2], hp, 2
                ),
                ResStack(channels // 8, kernel_size[2], self.resstack_depth[2], hp),
                act,
                UpsampleNet(
                    channels // 8, channels // 16, self.upsample_scales[3], hp, 3
                ),
                ResStack(channels // 16, kernel_size[3], self.resstack_depth[3], hp),
                act,
                nn.ReflectionPad1d(3),
                nn.utils.weight_norm(
                    nn.Conv1d(channels // 16, self.out_channels, kernel_size=7)
                ),
                nn.Tanh(),
            )
        else:
            channels = Config.m_channels
            self.generator = nn.Sequential(
                nn.ReflectionPad1d(3),
                nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)),
                act,
                UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp),
                ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp),
                act,
                UpsampleNet(channels // 2, channels // 4, self.upsample_scales[1], hp),
                ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp),
                act,
                UpsampleNet(channels // 4, channels // 8, self.upsample_scales[3], hp),
                ResStack(channels // 8, kernel_size[3], self.resstack_depth[2], hp),
                act,
                nn.ReflectionPad1d(3),
                nn.utils.weight_norm(
                    nn.Conv1d(channels // 8, self.out_channels, kernel_size=7)
                ),
                nn.Tanh(),
            )
        if self.out_channels > 1:
            self.pqmf = PQMF(4, 64)

        self.num_params()

    def forward(self, conditions, use_res=False, f0=None):
        res = conditions
        if self.use_condnet:
            conditions = self.condnet(conditions)
        if self.use_cond_rnn:
            conditions, _ = self.rnn(conditions.transpose(1, 2))
            conditions = conditions.transpose(1, 2)

        wav = self.generator(conditions)
        if self.out_channels > 1:
            B = wav.size(0)
            f_wav = (
                self.pqmf.synthesis(wav)
                .transpose(1, 2)
                .reshape(B, 1, -1)
                .clamp(-0.99, 0.99)
            )
            return f_wav, wav
        return wav

    def num_params(self):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        return parameters
        # print('Trainable Parameters: %.3f million' % parameters)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)


if __name__ == "__main__":
    model = Generator(128)
    x = torch.randn(3, 128, 13)
    print(x.shape)
    y = model(x)
    print(y.shape)
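A quick sanity check on the upsampling chain (an illustrative sketch, not part of the file): the four UpsampleNet stages stretch the time axis by prod(upsample_scales) = 7 * 7 * 3 * 3 = 441, which matches Config.hop_length, so each mel frame maps to exactly one STFT hop of audio:

import torch
from voicefixer.vocoder.model.generator import Generator

model = Generator(128)
mel = torch.randn(3, 128, 13)          # (batch, n_mel, frames)
wav = model(mel)
assert wav.shape == (3, 1, 13 * 441)   # 441 audio samples per mel frame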
voicefixer/vocoder/model/modules.py
ADDED
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from voicefixer.vocoder.config import Config
|
7 |
+
|
8 |
+
# From xin wang of nii
|
9 |
+
class SineGen(torch.nn.Module):
|
10 |
+
"""Definition of sine generator
|
11 |
+
SineGen(samp_rate, harmonic_num = 0,
|
12 |
+
sine_amp = 0.1, noise_std = 0.003,
|
13 |
+
voiced_threshold = 0,
|
14 |
+
flag_for_pulse=False)
|
15 |
+
|
16 |
+
samp_rate: sampling rate in Hz
|
17 |
+
harmonic_num: number of harmonic overtones (default 0)
|
18 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
19 |
+
noise_std: std of Gaussian noise (default 0.003)
|
20 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
21 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
22 |
+
|
23 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
24 |
+
segment is always sin(np.pi) or cos(0)
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
samp_rate=24000,
|
30 |
+
harmonic_num=0,
|
31 |
+
sine_amp=0.1,
|
32 |
+
noise_std=0.003,
|
33 |
+
voiced_threshold=0,
|
34 |
+
flag_for_pulse=False,
|
35 |
+
):
|
36 |
+
super(SineGen, self).__init__()
|
37 |
+
self.sine_amp = sine_amp
|
38 |
+
self.noise_std = noise_std
|
39 |
+
self.harmonic_num = harmonic_num
|
40 |
+
self.dim = self.harmonic_num + 1
|
41 |
+
self.sampling_rate = samp_rate
|
42 |
+
self.voiced_threshold = voiced_threshold
|
43 |
+
self.flag_for_pulse = flag_for_pulse
|
44 |
+
|
45 |
+
def _f02uv(self, f0):
|
46 |
+
# generate uv signal
|
47 |
+
uv = torch.ones_like(f0)
|
48 |
+
uv = uv * (f0 > self.voiced_threshold)
|
49 |
+
return uv
|
50 |
+
|
51 |
+
def _f02sine(self, f0_values):
|
52 |
+
"""f0_values: (batchsize, length, dim)
|
53 |
+
where dim indicates fundamental tone and overtones
|
54 |
+
"""
|
55 |
+
# convert to F0 in rad. The interger part n can be ignored
|
56 |
+
# because 2 * np.pi * n doesn't affect phase
|
57 |
+
rad_values = (f0_values / self.sampling_rate) % 1
|
58 |
+
|
59 |
+
# initial phase noise (no noise for fundamental component)
|
60 |
+
rand_ini = torch.rand(
|
61 |
+
f0_values.shape[0], f0_values.shape[2], device=f0_values.device
|
62 |
+
)
|
63 |
+
rand_ini[:, 0] = 0
|
64 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
65 |
+
|
66 |
+
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
67 |
+
if not self.flag_for_pulse:
|
68 |
+
# for normal case
|
69 |
+
|
70 |
+
# To prevent torch.cumsum numerical overflow,
|
71 |
+
# it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
72 |
+
# Buffer tmp_over_one_idx indicates the time step to add -1.
|
73 |
+
# This will not change F0 of sine because (x-1) * 2*pi = x *2*pi
|
74 |
+
tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
75 |
+
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
|
76 |
+
cumsum_shift = torch.zeros_like(rad_values)
|
77 |
+
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
78 |
+
|
79 |
+
sines = torch.sin(
|
80 |
+
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
|
81 |
+
)
|
82 |
+
else:
|
83 |
+
# If necessary, make sure that the first time step of every
|
84 |
+
# voiced segments is sin(pi) or cos(0)
|
85 |
+
# This is used for pulse-train generation
|
86 |
+
|
87 |
+
# identify the last time step in unvoiced segments
|
88 |
+
uv = self._f02uv(f0_values)
|
89 |
+
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
90 |
+
uv_1[:, -1, :] = 1
|
91 |
+
u_loc = (uv < 1) * (uv_1 > 0)
|
92 |
+
|
93 |
+
# get the instantanouse phase
|
94 |
+
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
95 |
+
# different batch needs to be processed differently
|
96 |
+
for idx in range(f0_values.shape[0]):
|
97 |
+
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
98 |
+
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
99 |
+
# stores the accumulation of i.phase within
|
100 |
+
# each voiced segments
|
101 |
+
tmp_cumsum[idx, :, :] = 0
|
102 |
+
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
103 |
+
|
104 |
+
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
105 |
+
# within the previous voiced segment.
|
106 |
+
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
107 |
+
|
108 |
+
# get the sines
|
109 |
+
sines = torch.cos(i_phase * 2 * np.pi)
|
110 |
+
return sines
|
111 |
+
|
112 |
+
def forward(self, f0):
|
113 |
+
"""sine_tensor, uv = forward(f0)
|
114 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
115 |
+
f0 for unvoiced steps should be 0
|
116 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
117 |
+
output uv: tensor(batchsize=1, length, 1)
|
118 |
+
"""
|
119 |
+
|
120 |
+
with torch.no_grad():
|
121 |
+
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
122 |
+
# fundamental component
|
123 |
+
f0_buf[:, :, 0] = f0[:, :, 0]
|
124 |
+
for idx in np.arange(self.harmonic_num):
|
125 |
+
# idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
|
126 |
+
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
|
127 |
+
|
128 |
+
# generate sine waveforms
|
129 |
+
sine_waves = self._f02sine(f0_buf) * self.sine_amp
|
130 |
+
|
131 |
+
# generate uv signal
|
132 |
+
# uv = torch.ones(f0.shape)
|
133 |
+
# uv = uv * (f0 > self.voiced_threshold)
|
134 |
+
uv = self._f02uv(f0)
|
135 |
+
|
136 |
+
# noise: for unvoiced should be similar to sine_amp
|
137 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
138 |
+
# . for voiced regions is self.noise_std
|
139 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
140 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
141 |
+
|
142 |
+
# first: set the unvoiced part to 0 by uv
|
143 |
+
# then: additive noise
|
144 |
+
sine_waves = sine_waves * uv + noise
|
145 |
+
return sine_waves, uv, noise
|
146 |
+
|
147 |
+
|
148 |
+
class LowpassBlur(nn.Module):
|
149 |
+
"""perform low pass filter after upsampling for anti-aliasing"""
|
150 |
+
|
151 |
+
def __init__(self, channels=128, filt_size=3, pad_type="reflect", pad_off=0):
|
152 |
+
super(LowpassBlur, self).__init__()
|
153 |
+
self.filt_size = filt_size
|
154 |
+
self.pad_off = pad_off
|
155 |
+
self.pad_sizes = [
|
156 |
+
int(1.0 * (filt_size - 1) / 2),
|
157 |
+
int(np.ceil(1.0 * (filt_size - 1) / 2)),
|
158 |
+
]
|
159 |
+
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
|
160 |
+
self.off = 0
|
161 |
+
self.channels = channels
|
162 |
+
|
163 |
+
if self.filt_size == 1:
|
164 |
+
a = np.array(
|
165 |
+
[
|
166 |
+
1.0,
|
167 |
+
]
|
168 |
+
)
|
169 |
+
elif self.filt_size == 2:
|
170 |
+
a = np.array([1.0, 1.0])
|
171 |
+
elif self.filt_size == 3:
|
172 |
+
a = np.array([1.0, 2.0, 1.0])
|
173 |
+
elif self.filt_size == 4:
|
174 |
+
a = np.array([1.0, 3.0, 3.0, 1.0])
|
175 |
+
elif self.filt_size == 5:
|
176 |
+
a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
|
177 |
+
elif self.filt_size == 6:
|
178 |
+
a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
|
179 |
+
elif self.filt_size == 7:
|
180 |
+
a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
|
181 |
+
|
182 |
+
filt = torch.Tensor(a)
|
183 |
+
filt = filt / torch.sum(filt)
|
184 |
+
self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))
|
185 |
+
|
186 |
+
self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
|
187 |
+
|
188 |
+
def forward(self, inp):
|
189 |
+
if self.filt_size == 1:
|
190 |
+
return inp
|
191 |
+
return F.conv1d(self.pad(inp), self.filt, groups=inp.shape[1])
|
192 |
+
|
193 |
+
|
194 |
+
def get_pad_layer_1d(pad_type):
|
195 |
+
if pad_type in ["refl", "reflect"]:
|
196 |
+
PadLayer = nn.ReflectionPad1d
|
197 |
+
elif pad_type in ["repl", "replicate"]:
|
198 |
+
PadLayer = nn.ReplicationPad1d
|
199 |
+
elif pad_type == "zero":
|
200 |
+
PadLayer = nn.ZeroPad1d
|
201 |
+
else:
|
202 |
+
print("Pad type [%s] not recognized" % pad_type)
|
203 |
+
return PadLayer
|
204 |
+
|
205 |
+
|
206 |
+
class MovingAverageSmooth(torch.nn.Conv1d):
|
207 |
+
def __init__(self, channels, window_len=3):
|
208 |
+
"""Initialize Conv1d module."""
|
209 |
+
super(MovingAverageSmooth, self).__init__(
|
210 |
+
in_channels=channels,
|
211 |
+
out_channels=channels,
|
212 |
+
kernel_size=1,
|
213 |
+
groups=channels,
|
214 |
+
bias=False,
|
215 |
+
)
|
216 |
+
|
217 |
+
torch.nn.init.constant_(self.weight, 1.0 / window_len)
|
218 |
+
for p in self.parameters():
|
219 |
+
p.requires_grad = False
|
220 |
+
|
221 |
+
def forward(self, data):
|
222 |
+
return super(MovingAverageSmooth, self).forward(data)
|
223 |
+
|
224 |
+
|
225 |
+
class Conv1d(torch.nn.Conv1d):
|
226 |
+
"""Conv1d module with customized initialization."""
|
227 |
+
|
228 |
+
def __init__(self, *args, **kwargs):
|
229 |
+
"""Initialize Conv1d module."""
|
230 |
+
super(Conv1d, self).__init__(*args, **kwargs)
|
231 |
+
|
232 |
+
def reset_parameters(self):
|
233 |
+
"""Reset parameters."""
|
234 |
+
torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
|
235 |
+
if self.bias is not None:
|
236 |
+
torch.nn.init.constant_(self.bias, 0.0)
|
237 |
+
|
238 |
+
|
239 |
+
class Stretch2d(torch.nn.Module):
|
240 |
+
"""Stretch2d module."""
|
241 |
+
|
242 |
+
def __init__(self, x_scale, y_scale, mode="nearest"):
|
243 |
+
"""Initialize Stretch2d module.
|
244 |
+
Args:
|
245 |
+
x_scale (int): X scaling factor (Time axis in spectrogram).
|
246 |
+
y_scale (int): Y scaling factor (Frequency axis in spectrogram).
|
247 |
+
mode (str): Interpolation mode.
|
248 |
+
"""
|
249 |
+
super(Stretch2d, self).__init__()
|
250 |
+
self.x_scale = x_scale
|
251 |
+
self.y_scale = y_scale
|
252 |
+
self.mode = mode
|
253 |
+
|
254 |
+
def forward(self, x):
|
255 |
+
"""Calculate forward propagation.
|
256 |
+
Args:
|
257 |
+
x (Tensor): Input tensor (B, C, F, T).
|
258 |
+
Returns:
|
259 |
+
Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale),
|
260 |
+
"""
|
261 |
+
return F.interpolate(
|
262 |
+
x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
|
263 |
+
)
|
264 |
+
|
265 |
+
|
266 |
+
class Conv2d(torch.nn.Conv2d):
|
267 |
+
"""Conv2d module with customized initialization."""
|
268 |
+
|
269 |
+
def __init__(self, *args, **kwargs):
|
270 |
+
"""Initialize Conv2d module."""
|
271 |
+
super(Conv2d, self).__init__(*args, **kwargs)
|
272 |
+
|
273 |
+
def reset_parameters(self):
|
274 |
+
"""Reset parameters."""
|
275 |
+
self.weight.data.fill_(1.0 / np.prod(self.kernel_size))
|
276 |
+
if self.bias is not None:
|
277 |
+
torch.nn.init.constant_(self.bias, 0.0)
|
278 |
+
|
279 |
+
|
280 |
+
class UpsampleNetwork(torch.nn.Module):
|
281 |
+
"""Upsampling network module."""
|
282 |
+
|
283 |
+
def __init__(
|
284 |
+
self,
|
285 |
+
upsample_scales,
|
286 |
+
nonlinear_activation=None,
|
287 |
+
nonlinear_activation_params={},
|
288 |
+
interpolate_mode="nearest",
|
289 |
+
freq_axis_kernel_size=1,
|
290 |
+
use_causal_conv=False,
|
291 |
+
):
|
292 |
+
"""Initialize upsampling network module.
|
293 |
+
Args:
|
294 |
+
upsample_scales (list): List of upsampling scales.
|
295 |
+
nonlinear_activation (str): Activation function name.
|
296 |
+
nonlinear_activation_params (dict): Arguments for specified activation function.
|
297 |
+
interpolate_mode (str): Interpolation mode.
|
298 |
+
freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
|
299 |
+
"""
|
300 |
+
super(UpsampleNetwork, self).__init__()
|
301 |
+
self.use_causal_conv = use_causal_conv
|
302 |
+
self.up_layers = torch.nn.ModuleList()
|
303 |
+
for scale in upsample_scales:
|
304 |
+
# interpolation layer
|
305 |
+
stretch = Stretch2d(scale, 1, interpolate_mode)
|
306 |
+
self.up_layers += [stretch]
|
307 |
+
|
308 |
+
# conv layer
|
309 |
+
assert (
|
310 |
+
freq_axis_kernel_size - 1
|
311 |
+
) % 2 == 0, "Not support even number freq axis kernel size."
|
312 |
+
freq_axis_padding = (freq_axis_kernel_size - 1) // 2
|
313 |
+
kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
|
314 |
+
if use_causal_conv:
|
315 |
+
padding = (freq_axis_padding, scale * 2)
|
316 |
+
else:
|
317 |
+
padding = (freq_axis_padding, scale)
|
318 |
+
conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
|
319 |
+
self.up_layers += [conv]
|
320 |
+
|
321 |
+
# nonlinear
|
322 |
+
if nonlinear_activation is not None:
|
323 |
+
nonlinear = getattr(torch.nn, nonlinear_activation)(
|
324 |
+
**nonlinear_activation_params
|
325 |
+
)
|
326 |
+
self.up_layers += [nonlinear]
|
327 |
+
|
328 |
+
def forward(self, c):
|
329 |
+
"""Calculate forward propagation.
|
330 |
+
Args:
|
331 |
+
c : Input tensor (B, C, T).
|
332 |
+
Returns:
|
333 |
+
Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).
|
334 |
+
"""
|
335 |
+
c = c.unsqueeze(1) # (B, 1, C, T)
|
336 |
+
for f in self.up_layers:
|
337 |
+
if self.use_causal_conv and isinstance(f, Conv2d):
|
338 |
+
c = f(c)[..., : c.size(-1)]
|
339 |
+
else:
|
340 |
+
c = f(c)
|
341 |
+
return c.squeeze(1) # (B, C, T')
|
342 |
+
|
343 |
+
|
344 |
+
class ConvInUpsampleNetwork(torch.nn.Module):
|
345 |
+
"""Convolution + upsampling network module."""
|
346 |
+
|
347 |
+
def __init__(
|
348 |
+
self,
|
349 |
+
upsample_scales=[3, 4, 5, 5],
|
350 |
+
nonlinear_activation="ReLU",
|
351 |
+
nonlinear_activation_params={},
|
352 |
+
interpolate_mode="nearest",
|
353 |
+
freq_axis_kernel_size=1,
|
354 |
+
aux_channels=80,
|
355 |
+
aux_context_window=0,
|
356 |
+
use_causal_conv=False,
|
357 |
+
):
|
358 |
+
"""Initialize convolution + upsampling network module.
|
359 |
+
Args:
|
360 |
+
upsample_scales (list): List of upsampling scales.
|
361 |
+
nonlinear_activation (str): Activation function name.
|
362 |
+
nonlinear_activation_params (dict): Arguments for specified activation function.
|
363 |
+
mode (str): Interpolation mode.
|
364 |
+
freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
|
365 |
+
aux_channels (int): Number of channels of pre-convolutional layer.
|
366 |
+
aux_context_window (int): Context window size of the pre-convolutional layer.
|
367 |
+
use_causal_conv (bool): Whether to use causal structure.
|
368 |
+
"""
|
369 |
+
super(ConvInUpsampleNetwork, self).__init__()
|
370 |
+
self.aux_context_window = aux_context_window
|
371 |
+
self.use_causal_conv = use_causal_conv and aux_context_window > 0
|
372 |
+
# To capture wide-context information in conditional features
|
373 |
+
kernel_size = (
|
374 |
+
aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
|
375 |
+
)
|
376 |
+
# NOTE(kan-bayashi): Here do not use padding because the input is already padded
|
377 |
+
self.conv_in = Conv1d(
|
378 |
+
aux_channels, aux_channels, kernel_size=kernel_size, bias=False
|
379 |
+
)
|
380 |
+
self.upsample = UpsampleNetwork(
|
381 |
+
upsample_scales=upsample_scales,
|
382 |
+
nonlinear_activation=nonlinear_activation,
|
383 |
+
nonlinear_activation_params=nonlinear_activation_params,
|
384 |
+
interpolate_mode=interpolate_mode,
|
385 |
+
freq_axis_kernel_size=freq_axis_kernel_size,
|
386 |
+
use_causal_conv=use_causal_conv,
|
387 |
+
)
|
388 |
+
|
389 |
+
def forward(self, c):
|
390 |
+
"""Calculate forward propagation.
|
391 |
+
Args:
|
392 |
+
c : Input tensor (B, C, T').
|
393 |
+
Returns:
|
394 |
+
Tensor: Upsampled tensor (B, C, T),
|
395 |
+
where T = (T' - aux_context_window * 2) * prod(upsample_scales).
|
396 |
+
Note:
|
397 |
+
The length of inputs considers the context window size.
|
398 |
+
"""
|
399 |
+
c_ = self.conv_in(c)
|
400 |
+
c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
|
401 |
+
return self.upsample(c)
|
402 |
+
|
403 |
+
|
404 |
+
class DownsampleNet(nn.Module):
|
405 |
+
def __init__(self, input_size, output_size, upsample_factor, hp=None, index=0):
|
406 |
+
super(DownsampleNet, self).__init__()
|
407 |
+
self.input_size = input_size
|
408 |
+
self.output_size = output_size
|
409 |
+
self.upsample_factor = upsample_factor
|
410 |
+
self.skip_conv = nn.Conv1d(input_size, output_size, kernel_size=1)
|
411 |
+
self.index = index
|
412 |
+
layer = nn.Conv1d(
|
413 |
+
input_size,
|
414 |
+
output_size,
|
415 |
+
kernel_size=upsample_factor * 2,
|
416 |
+
stride=upsample_factor,
|
417 |
+
padding=upsample_factor // 2 + upsample_factor % 2,
|
418 |
+
)
|
419 |
+
|
420 |
+
self.layer = nn.utils.weight_norm(layer)
|
421 |
+
|
422 |
+
def forward(self, inputs):
|
423 |
+
B, C, T = inputs.size()
|
424 |
+
res = inputs[:, :, :: self.upsample_factor]
|
425 |
+
skip = self.skip_conv(res)
|
426 |
+
|
427 |
+
outputs = self.layer(inputs)
|
428 |
+
outputs = outputs + skip
|
429 |
+
|
430 |
+
return outputs
|
431 |
+
|
432 |
+
|
433 |
+
class UpsampleNet(nn.Module):
|
434 |
+
def __init__(self, input_size, output_size, upsample_factor, hp=None, index=0):
|
435 |
+
|
436 |
+
super(UpsampleNet, self).__init__()
|
437 |
+
self.up_type = Config.up_type
|
438 |
+
self.use_smooth = Config.use_smooth
|
439 |
+
self.use_drop = Config.use_drop
|
440 |
+
self.input_size = input_size
|
441 |
+
self.output_size = output_size
|
442 |
+
self.upsample_factor = upsample_factor
|
443 |
+
self.skip_conv = nn.Conv1d(input_size, output_size, kernel_size=1)
|
444 |
+
self.index = index
|
445 |
+
if self.use_smooth:
|
446 |
+
window_lens = [5, 5, 4, 3]
|
447 |
+
self.window_len = window_lens[index]
|
448 |
+
|
449 |
+
if self.up_type != "pn" or self.index < 3:
|
450 |
+
# if self.up_type != "pn":
|
451 |
+
layer = nn.ConvTranspose1d(
|
452 |
+
input_size,
|
453 |
+
output_size,
|
454 |
+
upsample_factor * 2,
|
455 |
+
upsample_factor,
|
456 |
+
padding=upsample_factor // 2 + upsample_factor % 2,
|
457 |
+
output_padding=upsample_factor % 2,
|
458 |
+
)
|
459 |
+
self.layer = nn.utils.weight_norm(layer)
|
460 |
+
else:
|
461 |
+
self.layer = nn.Sequential(
|
462 |
+
nn.ReflectionPad1d(1),
|
463 |
+
nn.utils.weight_norm(
|
464 |
+
nn.Conv1d(input_size, output_size * upsample_factor, kernel_size=3)
|
465 |
+
),
|
466 |
+
nn.LeakyReLU(),
|
467 |
+
nn.ReflectionPad1d(1),
|
468 |
+
nn.utils.weight_norm(
|
469 |
+
nn.Conv1d(
|
470 |
+
output_size * upsample_factor,
|
471 |
+
output_size * upsample_factor,
|
472 |
+
kernel_size=3,
|
473 |
+
)
|
474 |
+
),
|
475 |
+
nn.LeakyReLU(),
|
476 |
+
nn.ReflectionPad1d(1),
|
477 |
+
nn.utils.weight_norm(
|
478 |
+
nn.Conv1d(
|
479 |
+
output_size * upsample_factor,
|
480 |
+
output_size * upsample_factor,
|
481 |
+
kernel_size=3,
|
482 |
+
)
|
483 |
+
),
|
484 |
+
nn.LeakyReLU(),
|
485 |
+
)
|
486 |
+
|
487 |
+
if hp is not None:
|
488 |
+
self.org = Config.up_org
|
489 |
+
self.no_skip = Config.no_skip
|
490 |
+
else:
|
491 |
+
self.org = False
|
492 |
+
self.no_skip = True
|
493 |
+
|
494 |
+
if self.use_smooth:
|
495 |
+
self.mas = nn.Sequential(
|
496 |
+
# LowpassBlur(output_size, self.window_len),
|
497 |
+
MovingAverageSmooth(output_size, self.window_len),
|
498 |
+
# MovingAverageSmooth(output_size, self.window_len),
|
499 |
+
)
|
500 |
+
|
501 |
+
def forward(self, inputs):
|
502 |
+
|
503 |
+
if not self.org:
|
504 |
+
inputs = inputs + torch.sin(inputs)
|
505 |
+
B, C, T = inputs.size()
|
506 |
+
res = inputs.repeat(1, self.upsample_factor, 1).view(B, C, -1)
|
507 |
+
skip = self.skip_conv(res)
|
508 |
+
if self.up_type == "repeat":
|
509 |
+
return skip
|
510 |
+
|
511 |
+
outputs = self.layer(inputs)
|
512 |
+
if self.up_type == "pn" and self.index > 2:
|
513 |
+
B, c, l = outputs.size()
|
514 |
+
outputs = outputs.view(B, -1, l * self.upsample_factor)
|
515 |
+
|
516 |
+
if self.no_skip:
|
517 |
+
return outputs
|
518 |
+
|
519 |
+
if not self.org:
|
520 |
+
outputs = outputs + skip
|
521 |
+
|
522 |
+
if self.use_smooth:
|
523 |
+
outputs = self.mas(outputs)
|
524 |
+
|
525 |
+
if self.use_drop:
|
526 |
+
outputs = F.dropout(outputs, p=0.05)
|
527 |
+
|
528 |
+
return outputs
|
529 |
+
|
530 |
+
|
531 |
+
class ResStack(nn.Module):
|
532 |
+
def __init__(self, channel, kernel_size=3, resstack_depth=4, hp=None):
|
533 |
+
super(ResStack, self).__init__()
|
534 |
+
|
535 |
+
self.use_wn = Config.use_wn
|
536 |
+
self.use_shift_scale = Config.use_shift_scale
|
537 |
+
self.channel = channel
|
538 |
+
|
539 |
+
def get_padding(kernel_size, dilation=1):
|
540 |
+
return int((kernel_size * dilation - dilation) / 2)
|
541 |
+
|
542 |
+
if self.use_shift_scale:
|
543 |
+
self.scale_conv = nn.utils.weight_norm(
|
544 |
+
nn.Conv1d(
|
545 |
+
channel, 2 * channel, kernel_size=kernel_size, dilation=1, padding=1
|
546 |
+
)
|
547 |
+
)
|
548 |
+
|
549 |
+
if not self.use_wn:
|
550 |
+
self.layers = nn.ModuleList(
|
551 |
+
[
|
552 |
+
nn.Sequential(
|
553 |
+
nn.LeakyReLU(),
|
554 |
+
nn.utils.weight_norm(
|
555 |
+
nn.Conv1d(
|
556 |
+
channel,
|
557 |
+
channel,
|
558 |
+
kernel_size=kernel_size,
|
559 |
+
dilation=3 ** (i % 10),
|
560 |
+
padding=get_padding(kernel_size, 3 ** (i % 10)),
|
561 |
+
)
|
562 |
+
),
|
563 |
+
nn.LeakyReLU(),
|
564 |
+
nn.utils.weight_norm(
|
565 |
+
nn.Conv1d(
|
566 |
+
channel,
|
567 |
+
channel,
|
568 |
+
kernel_size=kernel_size,
|
569 |
+
dilation=1,
|
570 |
+
padding=get_padding(kernel_size, 1),
|
571 |
+
)
|
572 |
+
),
|
573 |
+
)
|
574 |
+
for i in range(resstack_depth)
|
575 |
+
]
|
576 |
+
)
|
577 |
+
else:
|
578 |
+
self.wn = WaveNet(
|
579 |
+
in_channels=channel,
|
580 |
+
out_channels=channel,
|
581 |
+
cin_channels=-1,
|
582 |
+
num_layers=resstack_depth,
|
583 |
+
residual_channels=channel,
|
584 |
+
gate_channels=channel,
|
585 |
+
skip_channels=channel,
|
586 |
+
# kernel_size=5,
|
587 |
+
# dilation_rate=3,
|
588 |
+
causal=False,
|
589 |
+
use_downup=False,
|
590 |
+
)
|
591 |
+
|
592 |
+
def forward(self, x):
|
593 |
+
if not self.use_wn:
|
594 |
+
for layer in self.layers:
|
595 |
+
x = x + layer(x)
|
596 |
+
else:
|
597 |
+
x = self.wn(x)
|
598 |
+
|
599 |
+
if self.use_shift_scale:
|
600 |
+
m_s = self.scale_conv(x)
|
601 |
+
m_s = m_s[:, :, :-1]
|
602 |
+
|
603 |
+
m, s = torch.split(m_s, self.channel, dim=1)
|
604 |
+
s = F.softplus(s)
|
605 |
+
|
606 |
+
x = m + s * x[:, :, 1:] # key!!!
|
607 |
+
x = F.pad(x, pad=(1, 0), mode="constant", value=0)
|
608 |
+
|
609 |
+
return x
|
610 |
+
|
611 |
+
|
612 |
+
class WaveNet(nn.Module):
|
613 |
+
def __init__(
|
614 |
+
self,
|
615 |
+
in_channels=1,
|
616 |
+
out_channels=1,
|
617 |
+
num_layers=10,
|
618 |
+
residual_channels=64,
|
619 |
+
gate_channels=64,
|
620 |
+
skip_channels=64,
|
621 |
+
kernel_size=3,
|
622 |
+
dilation_rate=2,
|
623 |
+
cin_channels=80,
|
624 |
+
hp=None,
|
625 |
+
causal=False,
|
626 |
+
use_downup=False,
|
627 |
+
):
|
628 |
+
super(WaveNet, self).__init__()
|
629 |
+
|
630 |
+
self.in_channels = in_channels
|
631 |
+
self.causal = causal
|
632 |
+
self.num_layers = num_layers
|
633 |
+
self.out_channels = out_channels
|
634 |
+
self.gate_channels = gate_channels
|
635 |
+
self.residual_channels = residual_channels
|
636 |
+
self.skip_channels = skip_channels
|
637 |
+
self.cin_channels = cin_channels
|
638 |
+
self.kernel_size = kernel_size
|
639 |
+
self.use_downup = use_downup
|
640 |
+
|
641 |
+
self.front_conv = nn.Sequential(
|
642 |
+
nn.Conv1d(
|
643 |
+
in_channels=self.in_channels,
|
644 |
+
out_channels=self.residual_channels,
|
645 |
+
kernel_size=3,
|
646 |
+
padding=1,
|
647 |
+
),
|
648 |
+
nn.ReLU(),
|
649 |
+
)
|
650 |
+
if self.use_downup:
|
651 |
+
self.downup_conv = nn.Sequential(
|
652 |
+
nn.Conv1d(
|
653 |
+
in_channels=self.residual_channels,
|
654 |
+
out_channels=self.residual_channels,
|
655 |
+
kernel_size=3,
|
656 |
+
stride=2,
|
657 |
+
padding=1,
|
658 |
+
),
|
659 |
+
nn.ReLU(),
|
660 |
+
nn.Conv1d(
|
661 |
+
in_channels=self.residual_channels,
|
662 |
+
out_channels=self.residual_channels,
|
663 |
+
kernel_size=3,
|
664 |
+
stride=2,
|
665 |
+
padding=1,
|
666 |
+
),
|
667 |
+
nn.ReLU(),
|
668 |
+
UpsampleNet(self.residual_channels, self.residual_channels, 4, hp),
|
669 |
+
)
|
670 |
+
|
671 |
+
self.res_blocks = nn.ModuleList()
|
672 |
+
for n in range(self.num_layers):
|
673 |
+
self.res_blocks.append(
|
674 |
+
ResBlock(
|
675 |
+
self.residual_channels,
|
676 |
+
self.gate_channels,
|
677 |
+
self.skip_channels,
|
678 |
+
self.kernel_size,
|
679 |
+
dilation=dilation_rate**n,
|
680 |
+
cin_channels=self.cin_channels,
|
681 |
+
local_conditioning=(self.cin_channels > 0),
|
682 |
+
causal=self.causal,
|
683 |
+
mode="SAME",
|
684 |
+
)
|
685 |
+
)
|
686 |
+
self.final_conv = nn.Sequential(
|
687 |
+
nn.ReLU(),
|
688 |
+
Conv(self.skip_channels, self.skip_channels, 1, causal=self.causal),
|
689 |
+
nn.ReLU(),
|
690 |
+
Conv(self.skip_channels, self.out_channels, 1, causal=self.causal),
|
691 |
+
)
|
692 |
+
|
693 |
+
def forward(self, x, c=None):
|
694 |
+
return self.wavenet(x, c)
|
695 |
+
|
696 |
+
def wavenet(self, tensor, c=None):
|
697 |
+
|
698 |
+
h = self.front_conv(tensor)
|
699 |
+
if self.use_downup:
|
700 |
+
h = self.downup_conv(h)
|
701 |
+
skip = 0
|
702 |
+
for i, f in enumerate(self.res_blocks):
|
703 |
+
h, s = f(h, c)
|
704 |
+
skip += s
|
705 |
+
out = self.final_conv(skip)
|
706 |
+
return out
|
707 |
+
|
708 |
+
def receptive_field_size(self):
|
709 |
+
num_dir = 1 if self.causal else 2
|
710 |
+
dilations = [2 ** (i % self.num_layers) for i in range(self.num_layers)]
|
711 |
+
return (
|
712 |
+
num_dir * (self.kernel_size - 1) * sum(dilations)
|
713 |
+
+ 1
|
714 |
+
+ (self.front_channels - 1)
|
715 |
+
)
|
716 |
+
|
717 |
+
def remove_weight_norm(self):
|
718 |
+
for f in self.res_blocks:
|
719 |
+
f.remove_weight_norm()
|
720 |
+
|
721 |
+
|
722 |
+
class Conv(nn.Module):
|
723 |
+
def __init__(
|
724 |
+
self,
|
725 |
+
in_channels,
|
726 |
+
out_channels,
|
727 |
+
kernel_size,
|
728 |
+
dilation=1,
|
729 |
+
causal=False,
|
730 |
+
mode="SAME",
|
731 |
+
):
|
732 |
+
super(Conv, self).__init__()
|
733 |
+
|
734 |
+
self.causal = causal
|
735 |
+
self.mode = mode
|
736 |
+
if self.causal and self.mode == "SAME":
|
737 |
+
self.padding = dilation * (kernel_size - 1)
|
738 |
+
elif self.mode == "SAME":
|
739 |
+
self.padding = dilation * (kernel_size - 1) // 2
|
740 |
+
else:
|
741 |
+
self.padding = 0
|
742 |
+
self.conv = nn.Conv1d(
|
743 |
+
in_channels,
|
744 |
+
out_channels,
|
745 |
+
kernel_size,
|
746 |
+
dilation=dilation,
|
747 |
+
padding=self.padding,
|
748 |
+
)
|
749 |
+
self.conv = nn.utils.weight_norm(self.conv)
|
750 |
+
nn.init.kaiming_normal_(self.conv.weight)
|
751 |
+
|
752 |
+
def forward(self, tensor):
|
753 |
+
out = self.conv(tensor)
|
754 |
+
if self.causal and self.padding is not 0:
|
755 |
+
out = out[:, :, : -self.padding]
|
756 |
+
return out
|
757 |
+
|
758 |
+
def remove_weight_norm(self):
|
759 |
+
nn.utils.remove_weight_norm(self.conv)
|
760 |
+
|
761 |
+
|
762 |
+
class ResBlock(nn.Module):
|
763 |
+
def __init__(
|
764 |
+
self,
|
765 |
+
in_channels,
|
766 |
+
out_channels,
|
767 |
+
skip_channels,
|
768 |
+
kernel_size,
|
769 |
+
dilation,
|
770 |
+
cin_channels=None,
|
771 |
+
local_conditioning=True,
|
772 |
+
causal=False,
|
773 |
+
mode="SAME",
|
774 |
+
):
|
775 |
+
super(ResBlock, self).__init__()
|
776 |
+
self.causal = causal
|
777 |
+
self.local_conditioning = local_conditioning
|
778 |
+
self.cin_channels = cin_channels
|
779 |
+
self.mode = mode
|
780 |
+
|
781 |
+
self.filter_conv = Conv(
|
782 |
+
in_channels, out_channels, kernel_size, dilation, causal, mode
|
783 |
+
)
|
784 |
+
self.gate_conv = Conv(
|
785 |
+
in_channels, out_channels, kernel_size, dilation, causal, mode
|
786 |
+
)
|
787 |
+
self.res_conv = nn.Conv1d(out_channels, in_channels, kernel_size=1)
|
788 |
+
self.skip_conv = nn.Conv1d(out_channels, skip_channels, kernel_size=1)
|
789 |
+
self.res_conv = nn.utils.weight_norm(self.res_conv)
|
790 |
+
self.skip_conv = nn.utils.weight_norm(self.skip_conv)
|
791 |
+
|
792 |
+
if self.local_conditioning:
|
793 |
+
self.filter_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
|
794 |
+
self.gate_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
|
795 |
+
self.filter_conv_c = nn.utils.weight_norm(self.filter_conv_c)
|
796 |
+
self.gate_conv_c = nn.utils.weight_norm(self.gate_conv_c)
|
797 |
+
|
798 |
+
def forward(self, tensor, c=None):
|
799 |
+
h_filter = self.filter_conv(tensor)
|
800 |
+
h_gate = self.gate_conv(tensor)
|
801 |
+
|
802 |
+
if self.local_conditioning:
|
803 |
+
h_filter += self.filter_conv_c(c)
|
804 |
+
h_gate += self.gate_conv_c(c)
|
805 |
+
|
806 |
+
out = torch.tanh(h_filter) * torch.sigmoid(h_gate)
|
807 |
+
|
808 |
+
res = self.res_conv(out)
|
809 |
+
skip = self.skip_conv(out)
|
810 |
+
if self.mode == "SAME":
|
811 |
+
return (tensor + res) * math.sqrt(0.5), skip
|
812 |
+
else:
|
813 |
+
return (tensor[:, :, 1:] + res) * math.sqrt(0.5), skip
|
814 |
+
|
815 |
+
def remove_weight_norm(self):
|
816 |
+
self.filter_conv.remove_weight_norm()
|
817 |
+
self.gate_conv.remove_weight_norm()
|
818 |
+
nn.utils.remove_weight_norm(self.res_conv)
|
819 |
+
nn.utils.remove_weight_norm(self.skip_conv)
|
820 |
+
nn.utils.remove_weight_norm(self.filter_conv_c)
|
821 |
+
nn.utils.remove_weight_norm(self.gate_conv_c)
|
822 |
+
|
823 |
+
|
824 |
+
@torch.jit.script
|
825 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
826 |
+
n_channels_int = n_channels[0]
|
827 |
+
in_act = input_a + input_b
|
828 |
+
t_act = torch.tanh(in_act[:, :n_channels_int])
|
829 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:])
|
830 |
+
acts = t_act * s_act
|
831 |
+
return acts
|
832 |
+
|
833 |
+
|
834 |
+
@torch.jit.script
|
835 |
+
def fused_res_skip(tensor, res_skip, n_channels):
|
836 |
+
n_channels_int = n_channels[0]
|
837 |
+
res = res_skip[:, :n_channels_int]
|
838 |
+
skip = res_skip[:, n_channels_int:]
|
839 |
+
return (tensor + res), skip
|
840 |
+
|
841 |
+
|
842 |
+
class ResStack2D(nn.Module):
    def __init__(self, channels=16, kernel_size=3, resstack_depth=4, hp=None):
        super(ResStack2D, self).__init__()
        # NOTE: the constructor arguments are overridden with fixed values
        # below, so the stack always runs with 16 channels, 3x3 kernels, and
        # a depth of 2 regardless of what the caller passes.
        channels = 16
        kernel_size = 3
        resstack_depth = 2
        self.channels = channels

        def get_padding(kernel_size, dilation=1):
            return int((kernel_size * dilation - dilation) / 2)

        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LeakyReLU(),
                    nn.utils.weight_norm(
                        nn.Conv2d(
                            1,
                            self.channels,
                            kernel_size,
                            dilation=(1, 3 ** i),
                            padding=(1, get_padding(kernel_size, 3 ** i)),
                        )
                    ),
                    nn.LeakyReLU(),
                    nn.utils.weight_norm(
                        nn.Conv2d(
                            self.channels,
                            self.channels,
                            kernel_size,
                            dilation=(1, 3 ** i),
                            padding=(1, get_padding(kernel_size, 3 ** i)),
                        )
                    ),
                    nn.LeakyReLU(),
                    nn.utils.weight_norm(nn.Conv2d(self.channels, 1, kernel_size=1)),
                )
                for i in range(resstack_depth)
            ]
        )

    def forward(self, tensor):
        x = tensor.unsqueeze(1)
        for layer in self.layers:
            x = x + layer(x)
        x = x.squeeze(1)

        return x

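# Shape note (illustrative): ResStack2D treats its input as a 2-D map --
# forward() unsqueezes (B, F, T) to (B, 1, F, T), applies each dilated 2-D
# residual layer, and squeezes back to (B, F, T). The padding is chosen so
# input and output shapes match, which the residual addition requires.
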
class FiLM(nn.Module):
    """
    Feature-wise linear modulation.
    """

    def __init__(self, input_dim, attribute_dim):
        super().__init__()
        self.input_dim = input_dim
        self.generator = nn.Conv1d(
            attribute_dim, input_dim * 2, kernel_size=3, padding=1
        )

    def forward(self, x, c):
        """
        x: (B, input_dim, seq)
        c: (B, attribute_dim, seq)
        """
        c = self.generator(c)
        m, s = torch.split(c, self.input_dim, dim=1)

        return x * s + m

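# Usage sketch (illustrative, made-up sizes): the generator conv predicts a
# per-channel shift m and scale s from the conditioning signal, and the input
# is modulated as x * s + m.
#
#   film = FiLM(input_dim=32, attribute_dim=80)
#   x = torch.randn(2, 32, 100)   # (B, input_dim, seq)
#   c = torch.randn(2, 80, 100)   # (B, attribute_dim, seq)
#   y = film(x, c)                # -> (2, 32, 100)
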
class FiLMConv1d(nn.Module):
    """
    Conv1d layers with FiLMs in between.
    """

    def __init__(self, in_size, out_size, attribute_dim, ins_norm=True, loop=1):
        super().__init__()
        self.loop = loop
        self.mlps = nn.ModuleList(
            [nn.Conv1d(in_size, out_size, kernel_size=3, padding=1)]
            + [
                nn.Conv1d(out_size, out_size, kernel_size=3, padding=1)
                for i in range(loop - 1)
            ]
        )
        self.films = nn.ModuleList(
            [FiLM(out_size, attribute_dim) for i in range(loop)]
        )
        self.ins_norm = ins_norm
        if self.ins_norm:
            self.norm = nn.InstanceNorm1d(attribute_dim)

    def forward(self, x, c):
        """
        x: (B, in_size, seq)
        c: (B, attribute_dim, seq)
        """
        if self.ins_norm:
            c = self.norm(c)
        for i in range(self.loop):
            x = self.mlps[i](x)
            x = F.relu(x)
            x = self.films[i](x, c)

        return x
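
# Minimal smoke test (illustrative sketch): chains FiLMConv1d over random
# tensors to make the expected shapes explicit.
if __name__ == "__main__":
    net = FiLMConv1d(in_size=32, out_size=64, attribute_dim=80, loop=2)
    x = torch.randn(2, 32, 100)   # (B, in_size, seq)
    c = torch.randn(2, 80, 100)   # (B, attribute_dim, seq)
    print(net(x, c).shape)        # torch.Size([2, 64, 100])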
voicefixer/vocoder/model/pqmf.py
ADDED
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
import numpy as np


class PQMF(nn.Module):
    def __init__(self, N, M, file_path="utils/pqmf_hk_4_64.dat"):
        super().__init__()
        self.N = N  # number of subbands
        self.M = M  # prototype filter length
        # Analysis filterbank: a strided Conv1d whose weights are loaded from
        # the precomputed prototype filter file.
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )
        data = np.reshape(np.fromfile(file_path, dtype=np.float32), (N, M))
        data = np.flipud(data.T).T
        gk = data.copy()
        data = np.reshape(data, (N, 1, M)).copy()
        dict_new = self.ana_conv_filter.state_dict().copy()
        dict_new["weight"] = torch.from_numpy(data)
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
        self.ana_conv_filter.load_state_dict(dict_new)

        # Synthesis filterbank, built from the same prototype. NOTE: the
        # reshape below hard-codes (4, 16, 4), so it only works for the
        # shipped N=4, M=64 filterbank.
        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        gk = np.transpose(np.reshape(gk, (4, 16, 4)), (1, 0, 2)) * N
        gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
        dict_new = self.syn_conv_filter.state_dict().copy()
        dict_new["weight"] = torch.from_numpy(gk)
        self.syn_conv_filter.load_state_dict(dict_new)

        # The filterbank is fixed; freeze all weights.
        for param in self.parameters():
            param.requires_grad = False

    def analysis(self, inputs):
        return self.ana_conv_filter(self.ana_pad(inputs))

    def synthesis(self, inputs):
        return self.syn_conv_filter(self.syn_pad(inputs))

    def forward(self, inputs):
        return self.ana_conv_filter(self.ana_pad(inputs))


if __name__ == "__main__":
    a = PQMF(4, 64)
    # x = np.load('data/train/audio/010000.npy')
    x = np.zeros([8, 24000], np.float32)
    x = np.reshape(x, (8, 1, -1))
    x = torch.from_numpy(x)
    b = a.analysis(x)
    c = a.synthesis(b)
    print(x.shape, b.shape, c.shape)
    b = (b * 32768).numpy()
    b = np.reshape(np.transpose(b, (0, 2, 1)), (-1, 1)).astype(np.int16)
    # b.tofile('1.pcm')
    # np.reshape(np.transpose(c.numpy()*32768, (0, 2, 1)), (-1,1)).astype(np.int16).tofile('2.pcm')
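
# Shape note (illustrative): for PQMF(4, 64), analysis() maps a waveform of
# shape (B, 1, T) to subband signals of shape (B, 4, T // 4); synthesis()
# applies the synthesis filters in the same subband domain, as the shape
# printout in the __main__ block above shows.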
voicefixer/vocoder/model/res_msd.py
ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResStack(nn.Module):
    def __init__(self, channels=384, kernel_size=3, resstack_depth=3, hp=None):
        super(ResStack, self).__init__()
        dilation = [2 * i + 1 for i in range(resstack_depth)]  # [1, 3, 5] for depth 3
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[i],
                        padding=get_padding(kernel_size, dilation[i]),
                    )
                )
                for i in range(resstack_depth)
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for i in range(resstack_depth)
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        # Each block: LReLU -> dilated conv -> LReLU -> plain conv, with a
        # residual connection around the pair (HiFi-GAN-style ResBlock).
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)
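
# Quick shape check (illustrative sketch): the residual stack preserves the
# (B, C, T) shape of its input.
if __name__ == "__main__":
    stack = ResStack(channels=8, kernel_size=3, resstack_depth=3)
    x = torch.randn(2, 8, 50)
    print(stack(x).shape)  # torch.Size([2, 8, 50])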
voicefixer/vocoder/model/util.py
ADDED
@@ -0,0 +1,135 @@
from voicefixer.vocoder.config import Config
from voicefixer.tools.pytorch_util import try_tensor_cuda, check_cuda_availability
import torch
import librosa
import numpy as np


def tr_normalize(S):
    if Config.allow_clipping_in_normalization:
        if Config.symmetric_mels:
            return torch.clip(
                (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db))
                - Config.max_abs_value,
                -Config.max_abs_value,
                Config.max_abs_value,
            )
        else:
            return torch.clip(
                Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)),
                0,
                Config.max_abs_value,
            )

    assert S.max() <= 0 and S.min() - Config.min_db >= 0
    if Config.symmetric_mels:
        return (2 * Config.max_abs_value) * (
            (S - Config.min_db) / (-Config.min_db)
        ) - Config.max_abs_value
    else:
        return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db))


def tr_amp_to_db(x):
    min_level = torch.exp(Config.min_level_db / 20 * torch.log(torch.tensor(10.0)))
    min_level = min_level.type_as(x)
    return 20 * torch.log10(torch.maximum(min_level, x))


def normalize(S):
    # NumPy twin of tr_normalize(); keep the two implementations in sync.
    if Config.allow_clipping_in_normalization:
        if Config.symmetric_mels:
            return np.clip(
                (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db))
                - Config.max_abs_value,
                -Config.max_abs_value,
                Config.max_abs_value,
            )
        else:
            return np.clip(
                Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)),
                0,
                Config.max_abs_value,
            )

    assert S.max() <= 0 and S.min() - Config.min_db >= 0
    if Config.symmetric_mels:
        return (2 * Config.max_abs_value) * (
            (S - Config.min_db) / (-Config.min_db)
        ) - Config.max_abs_value
    else:
        return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db))


def amp_to_db(x):
    min_level = np.exp(Config.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

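
# Numeric sketch (hypothetical values, for intuition only): if
# Config.min_level_db were -100, amp_to_db() would clamp amplitudes below
# 10 ** (-100 / 20) = 1e-5, so x = 0.1 maps to -20 dB; normalize() then
# linearly rescales the [min_db, 0] dB range into the model's target range
# ([-max_abs_value, max_abs_value] when symmetric_mels is set).
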
def tr_pre(npy):
    # conditions = torch.FloatTensor(npy).type_as(npy)  # to(device)
    conditions = npy.transpose(1, 2)
    l = conditions.size(-1)
    # Pad the mel tail to an even length plus lookahead; -4.0 acts as a
    # silence floor in the normalized mel domain.
    pad_tail = l % 2 + 4
    zeros = (
        torch.zeros([conditions.size()[0], Config.num_mels, pad_tail]).type_as(
            conditions
        )
        + -4.0
    )
    return torch.cat([conditions, zeros], dim=-1)


def pre(npy):
    conditions = npy
    # Pad the tail, as in tr_pre(), after converting to a batched tensor.
    if type(conditions) == np.ndarray:
        conditions = torch.FloatTensor(conditions).unsqueeze(0)
    else:
        conditions = torch.FloatTensor(conditions.float()).unsqueeze(0)
    conditions = conditions.transpose(1, 2)
    l = conditions.size(-1)
    pad_tail = l % 2 + 4
    zeros = torch.zeros([1, Config.num_mels, pad_tail]) + -4.0
    return torch.cat([conditions, zeros], dim=-1)


def load_try(state, model):
    model_dict = model.state_dict()
    try:
        model_dict.update(state)
        model.load_state_dict(model_dict)
    except RuntimeError as e:
        # Fall back to assigning the checkpoint entries one by one so a
        # partially mismatched state dict can still be loaded.
        print(str(e))
        model_dict = model.state_dict()
        for k, v in state.items():
            model_dict[k] = v
        model.load_state_dict(model_dict)


def load_checkpoint(checkpoint_path, device):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    return checkpoint


def build_mel_basis():
    # Keyword arguments keep this call compatible with librosa >= 0.10,
    # where positional sr/n_fft were removed.
    return librosa.filters.mel(
        sr=Config.sample_rate,
        n_fft=Config.n_fft,
        htk=True,
        n_mels=Config.num_mels,
        fmin=0,
        fmax=int(Config.sample_rate // 2),
    )


def linear_to_mel(spectrogram):
    # Note: the mel basis is rebuilt on every call; cache it if this ends up
    # on a hot path.
    _mel_basis = build_mel_basis()
    return np.dot(_mel_basis, spectrogram)


if __name__ == "__main__":
    # The torch and numpy normalization paths should agree to numerical
    # precision.
    data = torch.randn((3, 5, 100))
    b = normalize(amp_to_db(data.numpy()))
    a = tr_normalize(tr_amp_to_db(data)).numpy()
    print(a - b)
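
    # Additional sanity check (illustrative): the mel filterbank from
    # build_mel_basis() should have shape (Config.num_mels, n_fft // 2 + 1).
    basis = build_mel_basis()
    print(basis.shape)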