initial commit
- .gitattributes +1 -0
- .gitignore +1 -0
- Arg_Parser.py +30 -0
- Checkpoint/S_200000.pt +3 -0
- Datasets.py +146 -0
- Hyper_Parameters.yaml +45 -0
- Inference.py +168 -0
- Modules/Diffusion.py +403 -0
- Modules/Layer.py +317 -0
- Modules/Modules.py +265 -0
- Pattern_Generator.py +64 -0
- README.md +4 -4
- YAML/Genre_Info.yaml +1 -0
- YAML/Log_Energy_Info.yaml +3 -0
- YAML/Log_F0_Info.yaml +3 -0
- YAML/Mel_Range_Info.yaml +3 -0
- YAML/Singer_Info.yaml +1 -0
- YAML/Spectrogram_Range_Info.yaml +3 -0
- YAML/Token.yaml +71 -0
- app.py +81 -0
- meldataset.py +230 -0
- requirements.txt +7 -0
- vocoder.pts +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pts filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
*.pyc
Arg_Parser.py
ADDED
@@ -0,0 +1,30 @@
from argparse import Namespace

def Recursive_Parse(args_dict):
    parsed_dict = {}
    for key, value in args_dict.items():
        if isinstance(value, dict):
            value = Recursive_Parse(value)
        parsed_dict[key] = value

    args = Namespace()
    args.__dict__ = parsed_dict
    return args

def To_Non_Recursive_Dict(
    args: Namespace
    ):
    parsed_dict = {}
    for key, value in args.__dict__.items():
        if isinstance(value, Namespace):
            value_dict = To_Non_Recursive_Dict(value)
            for sub_key, sub_value in value_dict.items():
                parsed_dict[f'{key}.{sub_key}'] = sub_value
        else:
            parsed_dict[key] = value

    return parsed_dict
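As a quick illustration of how Recursive_Parse and To_Non_Recursive_Dict mirror each other, here is a minimal, self-contained sketch; the nested dictionary is a made-up example, not part of the repository.

from Arg_Parser import Recursive_Parse, To_Non_Recursive_Dict

# A toy nested config, standing in for a parsed hyper-parameter YAML.
config = {'Sound': {'Sample_Rate': 22050, 'Frame_Shift': 256}, 'Feature_Type': 'Mel'}

hp = Recursive_Parse(config)
print(hp.Sound.Sample_Rate)   # 22050, nested dicts become nested Namespaces
print(hp.Feature_Type)        # 'Mel'

flat = To_Non_Recursive_Dict(hp)
print(flat)   # {'Sound.Sample_Rate': 22050, 'Sound.Frame_Shift': 256, 'Feature_Type': 'Mel'}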
Checkpoint/S_200000.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6482992a43b8a98554e7ef9e487a381c2717c5828d564e6dfc6cac16a0e16092
size 682529563
Datasets.py
ADDED
@@ -0,0 +1,146 @@
from argparse import Namespace
import torch
import numpy as np
import pickle, os, logging
from typing import Dict, List, Optional
import hgtk

from Pattern_Generator import Convert_Feature_Based_Music, Expand_by_Duration

def Decompose(syllable: str):
    onset, nucleus, coda = hgtk.letter.decompose(syllable)
    coda += '_'

    return onset, nucleus, coda

def Lyric_to_Token(lyric: List[str], token_dict: Dict[str, int]):
    return [
        token_dict[letter]
        for letter in list(lyric)
        ]

def Token_Stack(tokens: List[List[int]], token_dict: Dict[str, int], max_length: Optional[int]= None):
    max_token_length = max_length or max([len(token) for token in tokens])
    tokens = np.stack(
        [np.pad(token[:max_token_length], [0, max_token_length - len(token[:max_token_length])], constant_values= token_dict['<X>']) for token in tokens],
        axis= 0
        )
    return tokens

def Note_Stack(notes: List[List[int]], max_length: Optional[int]= None):
    max_note_length = max_length or max([len(note) for note in notes])
    notes = np.stack(
        [np.pad(note[:max_note_length], [0, max_note_length - len(note[:max_note_length])], constant_values= 0) for note in notes],
        axis= 0
        )
    return notes

def Duration_Stack(durations: List[List[int]], max_length: Optional[int]= None):
    max_duration_length = max_length or max([len(duration) for duration in durations])
    durations = np.stack(
        [np.pad(duration[:max_duration_length], [0, max_duration_length - len(duration[:max_duration_length])], constant_values= 0) for duration in durations],
        axis= 0
        )
    return durations

def Feature_Stack(features: List[np.array], max_length: Optional[int]= None):
    max_feature_length = max_length or max([feature.shape[0] for feature in features])
    features = np.stack(
        [np.pad(feature, [[0, max_feature_length - feature.shape[0]], [0, 0]], constant_values= -1.0) for feature in features],
        axis= 0
        )
    return features

def Log_F0_Stack(log_f0s: List[np.array], max_length: int= None):
    max_log_f0_length = max_length or max([len(log_f0) for log_f0 in log_f0s])
    log_f0s = np.stack(
        [np.pad(log_f0, [0, max_log_f0_length - len(log_f0)], constant_values= 0.0) for log_f0 in log_f0s],
        axis= 0
        )
    return log_f0s

class Inference_Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        token_dict: Dict[str, int],
        singer_info_dict: Dict[str, int],
        genre_info_dict: Dict[str, int],
        durations: List[List[float]],
        lyrics: List[List[str]],
        notes: List[List[int]],
        singers: List[str],
        genres: List[str],
        sample_rate: int,
        frame_shift: int,
        equality_duration: bool= False,
        consonant_duration: int= 3
        ):
        super().__init__()
        self.token_dict = token_dict
        self.singer_info_dict = singer_info_dict
        self.genre_info_dict = genre_info_dict
        self.equality_duration = equality_duration
        self.consonant_duration = consonant_duration

        self.patterns = []
        for index, (duration, lyric, note, singer, genre) in enumerate(zip(durations, lyrics, notes, singers, genres)):
            if not singer in self.singer_info_dict.keys():
                logging.warning('The singer \'{}\' is incorrect. The pattern \'{}\' is ignored.'.format(singer, index))
                continue
            if not genre in self.genre_info_dict.keys():
                logging.warning('The genre \'{}\' is incorrect. The pattern \'{}\' is ignored.'.format(genre, index))
                continue

            music = [x for x in zip(duration, lyric, note)]
            singer_label = singer
            text = lyric

            lyric, note, duration = Convert_Feature_Based_Music(
                music= music,
                sample_rate= sample_rate,
                frame_shift= frame_shift,
                consonant_duration= consonant_duration,
                equality_duration= equality_duration
                )
            lyric_expand, note_expand, duration_expand = Expand_by_Duration(lyric, note, duration)

            singer = self.singer_info_dict[singer]
            genre = self.genre_info_dict[genre]

            self.patterns.append((lyric_expand, note_expand, duration_expand, singer, genre, singer_label, text))

    def __getitem__(self, idx):
        lyric, note, duration, singer, genre, singer_label, text = self.patterns[idx]

        return Lyric_to_Token(lyric, self.token_dict), note, duration, singer, genre, singer_label, text

    def __len__(self):
        return len(self.patterns)

class Inference_Collater:
    def __init__(self,
        token_dict: Dict[str, int]
        ):
        self.token_dict = token_dict

    def __call__(self, batch):
        tokens, notes, durations, singers, genres, singer_labels, lyrics = zip(*batch)

        lengths = np.array([len(token) for token in tokens])

        max_length = max(lengths)

        tokens = Token_Stack(tokens, self.token_dict, max_length)
        notes = Note_Stack(notes, max_length)
        durations = Duration_Stack(durations, max_length)

        tokens = torch.LongTensor(tokens)   # [Batch, Time]
        notes = torch.LongTensor(notes)   # [Batch, Time]
        durations = torch.LongTensor(durations)   # [Batch, Time]
        lengths = torch.LongTensor(lengths)   # [Batch]
        singers = torch.LongTensor(singers)   # [Batch]
        genres = torch.LongTensor(genres)   # [Batch]

        lyrics = [''.join([(x if x != '<X>' else ' ') for x in lyric]) for lyric in lyrics]

        return tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics
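The *_Stack helpers pad ragged per-pattern sequences into rectangular arrays for batching. The sketch below shows that behavior with a toy token dictionary; the real dictionary comes from YAML/Token.yaml, so only the '<X>' padding entry here mirrors the repository.

from Datasets import Token_Stack, Note_Stack

# Toy token dictionary; only the '<X>' padding id matters for Token_Stack.
token_dict = {'<X>': 0, 'ㄱ': 1, 'ㅏ': 2, 'ㄴ': 3}

tokens = [[1, 2, 3], [1, 2]]        # ragged token id lists
notes = [[60, 62, 64], [65, 67]]    # ragged note lists

print(Token_Stack(tokens, token_dict))   # [[1 2 3] [1 2 0]], padded with token_dict['<X>']
print(Note_Stack(notes))                 # [[60 62 64] [65 67 0]], padded with 0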
Hyper_Parameters.yaml
ADDED
@@ -0,0 +1,45 @@
Sound:
    N_FFT: 2048
    Mel_Dim: 80
    Frame_Length: 1024
    Frame_Shift: 256
    Sample_Rate: 22050
    Mel_F_Min: 0
    Mel_F_Max: 8000

Feature_Type: 'Mel' #'Spectrogram', 'Mel'

Tokens: 77
Notes: 128
Durations: 5000
Genres: 1
Singers: 1
Duration:
    Equality: false
    Consonant_Duration: 3   # This is only used when Equality is False.

Encoder:
    Size: 384
    ConvFFT:
        Stack: 6
        Head: 2
        Dropout_Rate: 0.1
        Conv:
            Stack: 2
            Kernel_Size: 5
        FFN:
            Kernel_Size: 17

Diffusion:
    Max_Step: 100
    Size: 256
    Kernel_Size: 5
    Stack: 20

Token_Path: './YAML/Token.yaml'
Spectrogram_Range_Info_Path: './YAML/Spectrogram_Range_Info.yaml'
Mel_Range_Info_Path: './YAML/Mel_Range_Info.yaml'
Log_F0_Info_Path: './YAML/Log_F0_Info.yaml'
Log_Energy_Info_Path: './YAML/Log_Energy_Info.yaml'
Singer_Info_Path: './YAML/Singer_Info.yaml'
Genre_Info_Path: './YAML/Genre_Info.yaml'
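This file is consumed through Recursive_Parse, as Inference.py does, so nested keys become attribute access. A minimal loading sketch:

import yaml
from Arg_Parser import Recursive_Parse

hp = Recursive_Parse(yaml.load(
    open('Hyper_Parameters.yaml', encoding='utf-8'),
    Loader=yaml.Loader
    ))

print(hp.Feature_Type)             # 'Mel'
print(hp.Sound.Frame_Shift)        # 256 samples per frame at 22050 Hz
print(hp.Encoder.ConvFFT.Stack)    # 6 FFT blocks in the encoder
print(hp.Diffusion.Max_Step)       # 100 diffusion timesteps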
Inference.py
ADDED
@@ -0,0 +1,168 @@
import torch
import numpy as np
import logging, yaml, os, sys, argparse, math
import matplotlib.pyplot as plt
from tqdm import tqdm
from librosa import griffinlim

from Modules.Modules import DiffSinger
from Datasets import Inference_Dataset as Dataset, Inference_Collater as Collater
from meldataset import spectral_de_normalize_torch
from Arg_Parser import Recursive_Parse

import matplotlib as mpl
# Fix broken unicode rendering of the minus sign
mpl.rcParams['axes.unicode_minus'] = False
# Apply the NanumGothic font
plt.rcParams["font.family"] = 'NanumGothic'

logging.basicConfig(
    level=logging.INFO, stream=sys.stdout,
    format= '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
    )

class Inferencer:
    def __init__(
        self,
        hp_path: str,
        checkpoint_path: str,
        batch_size= 1
        ):
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

        self.hp = Recursive_Parse(yaml.load(
            open(hp_path, encoding='utf-8'),
            Loader=yaml.Loader
            ))

        self.model = DiffSinger(self.hp).to(self.device)
        if self.hp.Feature_Type == 'Mel':
            self.vocoder = torch.jit.load('vocoder.pts', map_location='cpu').to(self.device)

        if self.hp.Feature_Type == 'Spectrogram':
            self.feature_range_info_dict = yaml.load(open(self.hp.Spectrogram_Range_Info_Path), Loader=yaml.Loader)
        if self.hp.Feature_Type == 'Mel':
            self.feature_range_info_dict = yaml.load(open(self.hp.Mel_Range_Info_Path), Loader=yaml.Loader)
        self.index_singer_dict = {
            value: key
            for key, value in yaml.load(open(self.hp.Singer_Info_Path), Loader=yaml.Loader).items()
            }

        if self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1
        elif self.hp.Feature_Type == 'Mel':
            self.feature_size = self.hp.Sound.Mel_Dim
        else:
            raise ValueError('Unknown feature type: {}'.format(self.hp.Feature_Type))

        self.Load_Checkpoint(checkpoint_path)
        self.batch_size = batch_size

    def Dataset_Generate(self, message_times_list, lyrics, notes, singers, genres):
        token_dict = yaml.load(open(self.hp.Token_Path), Loader=yaml.Loader)
        singer_info_dict = yaml.load(open(self.hp.Singer_Info_Path), Loader=yaml.Loader)
        genre_info_dict = yaml.load(open(self.hp.Genre_Info_Path), Loader=yaml.Loader)

        return torch.utils.data.DataLoader(
            dataset= Dataset(
                token_dict= token_dict,
                singer_info_dict= singer_info_dict,
                genre_info_dict= genre_info_dict,
                durations= message_times_list,
                lyrics= lyrics,
                notes= notes,
                singers= singers,
                genres= genres,
                sample_rate= self.hp.Sound.Sample_Rate,
                frame_shift= self.hp.Sound.Frame_Shift,
                equality_duration= self.hp.Duration.Equality,
                consonant_duration= self.hp.Duration.Consonant_Duration
                ),
            shuffle= False,
            collate_fn= Collater(
                token_dict= token_dict
                ),
            batch_size= self.batch_size,
            num_workers= 0,
            pin_memory= True
            )

    def Load_Checkpoint(self, path):
        state_dict = torch.load(path, map_location= 'cpu')
        self.model.load_state_dict(state_dict['Model']['DiffSVS'])
        self.steps = state_dict['Steps']

        self.model.eval()

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

    @torch.inference_mode()
    def Inference_Step(self, tokens, notes, durations, lengths, singers, genres, singer_labels, ddim_steps):
        tokens = tokens.to(self.device, non_blocking=True)
        notes = notes.to(self.device, non_blocking=True)
        durations = durations.to(self.device, non_blocking=True)
        lengths = lengths.to(self.device, non_blocking=True)
        singers = singers.to(self.device, non_blocking=True)
        genres = genres.to(self.device, non_blocking=True)

        linear_predictions, diffusion_predictions, _, _ = self.model(
            tokens= tokens,
            notes= notes,
            durations= durations,
            lengths= lengths,
            genres= genres,
            singers= singers,
            ddim_steps= ddim_steps
            )
        linear_predictions = linear_predictions.clamp(-1.0, 1.0)
        diffusion_predictions = diffusion_predictions.clamp(-1.0, 1.0)

        linear_prediction_list, diffusion_prediction_list = [], []
        for linear_prediction, diffusion_prediction, singer in zip(linear_predictions, diffusion_predictions, singer_labels):
            feature_max = self.feature_range_info_dict[singer]['Max']
            feature_min = self.feature_range_info_dict[singer]['Min']
            linear_prediction_list.append((linear_prediction + 1.0) / 2.0 * (feature_max - feature_min) + feature_min)
            diffusion_prediction_list.append((diffusion_prediction + 1.0) / 2.0 * (feature_max - feature_min) + feature_min)
        linear_predictions = torch.stack(linear_prediction_list, dim= 0)
        diffusion_predictions = torch.stack(diffusion_prediction_list, dim= 0)

        if self.hp.Feature_Type == 'Mel':
            audios = self.vocoder(diffusion_predictions)
            if audios.ndim == 1:    # This is temporary because of the vocoder problem.
                audios = audios.unsqueeze(0)
            audios = [
                audio[:min(length * self.hp.Sound.Frame_Shift, audio.size(0))].cpu().numpy()
                for audio, length in zip(audios, lengths)
                ]
        elif self.hp.Feature_Type == 'Spectrogram':
            audios = []
            for prediction, length in zip(
                diffusion_predictions,
                lengths
                ):
                prediction = spectral_de_normalize_torch(prediction).cpu().numpy()
                audio = griffinlim(prediction)[:min(prediction.shape[1], length) * self.hp.Sound.Frame_Shift]
                audio = (audio / np.abs(audio).max() * 32767.5).astype(np.int16)
                audios.append(audio)

        return audios

    def Inference_Epoch(self, message_times_list, lyrics, notes, singers, genres, ddim_steps= None, use_tqdm= True):
        dataloader = self.Dataset_Generate(
            message_times_list= message_times_list,
            lyrics= lyrics,
            notes= notes,
            singers= singers,
            genres= genres
            )
        if use_tqdm:
            dataloader = tqdm(
                dataloader,
                desc='[Inference]',
                total= math.ceil(len(dataloader.dataset) / self.batch_size)
                )
        audios = []
        for tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics in dataloader:
            audios.extend(self.Inference_Step(tokens, notes, durations, lengths, singers, genres, singer_labels, ddim_steps))

        return audios
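A usage sketch for the class above, assuming the checkpoint and hyper-parameter files added in this commit. The lyric, note, duration, singer, and genre values are placeholders: valid singer and genre keys come from YAML/Singer_Info.yaml and YAML/Genre_Info.yaml, durations are assumed to be per-syllable times in seconds, and writing the result with soundfile is an assumption about how the returned audio arrays would be consumed.

import soundfile as sf   # assumed helper for saving; not part of this commit's shown files
from Inference import Inferencer

inferencer = Inferencer(
    hp_path= 'Hyper_Parameters.yaml',
    checkpoint_path= 'Checkpoint/S_200000.pt'
    )

audios = inferencer.Inference_Epoch(
    message_times_list= [[0.5, 0.5, 1.0]],   # assumed per-syllable durations in seconds
    lyrics= [['가', '나', '다']],             # placeholder Korean syllables
    notes= [[64, 66, 68]],                   # placeholder MIDI note numbers
    singers= ['CSD'],                        # placeholder key; must exist in YAML/Singer_Info.yaml
    genres= ['Children'],                    # placeholder key; must exist in YAML/Genre_Info.yaml
    ddim_steps= 25
    )
sf.write('result.wav', audios[0], inferencer.hp.Sound.Sample_Rate)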
Modules/Diffusion.py
ADDED
@@ -0,0 +1,403 @@
import torch
import math
from argparse import Namespace
from typing import Optional, List, Dict, Union
from tqdm import tqdm

from .Layer import Conv1d, Lambda

class Diffusion(torch.nn.Module):
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            self.feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1

        self.denoiser = Denoiser(
            hyper_parameters= self.hp
            )

        self.timesteps = self.hp.Diffusion.Max_Step
        betas = torch.linspace(1e-4, 0.06, self.timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis= 0)
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('alphas_cumprod', alphas_cumprod)   # [Diffusion_t]
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)   # [Diffusion_t]
        self.register_buffer('sqrt_alphas_cumprod', alphas_cumprod.sqrt())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1.0 - alphas_cumprod).sqrt())
        self.register_buffer('sqrt_recip_alphas_cumprod', (1.0 / alphas_cumprod).sqrt())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', (1.0 / alphas_cumprod - 1.0).sqrt())

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance', torch.maximum(posterior_variance, torch.tensor([1e-20])).log())
        self.register_buffer('posterior_mean_coef1', betas * alphas_cumprod_prev.sqrt() / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * alphas.sqrt() / (1.0 - alphas_cumprod))

    def forward(
        self,
        encodings: torch.Tensor,
        features: torch.Tensor= None
        ):
        '''
        encodings: [Batch, Enc_d, Enc_t]
        features: [Batch, Feature_d, Feature_t]
        '''
        if not features is None:    # train
            diffusion_steps = torch.randint(
                low= 0,
                high= self.timesteps,
                size= (encodings.size(0),),
                dtype= torch.long,
                device= encodings.device
                )    # random single step

            noises, epsilons = self.Get_Noise_Epsilon_for_Train(
                features= features,
                encodings= encodings,
                diffusion_steps= diffusion_steps,
                )
            return None, noises, epsilons
        else:    # inference
            features = self.Sampling(
                encodings= encodings,
                )
            return features, None, None

    def Sampling(
        self,
        encodings: torch.Tensor,
        ):
        features = torch.randn(
            size= (encodings.size(0), self.feature_size, encodings.size(2)),
            device= encodings.device
            )
        for diffusion_step in reversed(range(self.timesteps)):
            features = self.P_Sampling(
                features= features,
                encodings= encodings,
                diffusion_steps= torch.full(
                    size= (encodings.size(0), ),
                    fill_value= diffusion_step,
                    dtype= torch.long,
                    device= encodings.device
                    ),
                )

        return features

    def P_Sampling(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor,
        ):
        posterior_means, posterior_log_variances = self.Get_Posterior(
            features= features,
            encodings= encodings,
            diffusion_steps= diffusion_steps,
            )

        noises = torch.randn_like(features) # [Batch, Feature_d, Feature_t]
        masks = (diffusion_steps > 0).float().unsqueeze(1).unsqueeze(1)  # [Batch, 1, 1], no noise added at step 0

        return posterior_means + masks * (0.5 * posterior_log_variances).exp() * noises

    def Get_Posterior(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        noised_predictions = self.denoiser(
            features= features,
            encodings= encodings,
            diffusion_steps= diffusion_steps
            )

        # 'epsilons' here is the reconstructed x_0 estimate derived from the predicted noise.
        epsilons = \
            features * self.sqrt_recip_alphas_cumprod[diffusion_steps][:, None, None] - \
            noised_predictions * self.sqrt_recipm1_alphas_cumprod[diffusion_steps][:, None, None]
        epsilons.clamp_(-1.0, 1.0)  # clipped

        posterior_means = \
            epsilons * self.posterior_mean_coef1[diffusion_steps][:, None, None] + \
            features * self.posterior_mean_coef2[diffusion_steps][:, None, None]
        posterior_log_variances = \
            self.posterior_log_variance[diffusion_steps][:, None, None]

        return posterior_means, posterior_log_variances

    def Get_Noise_Epsilon_for_Train(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor,
        ):
        noises = torch.randn_like(features)

        noised_features = \
            features * self.sqrt_alphas_cumprod[diffusion_steps][:, None, None] + \
            noises * self.sqrt_one_minus_alphas_cumprod[diffusion_steps][:, None, None]

        epsilons = self.denoiser(
            features= noised_features,
            encodings= encodings,
            diffusion_steps= diffusion_steps
            )

        return noises, epsilons

    def DDIM(
        self,
        encodings: torch.Tensor,
        ddim_steps: int,
        eta: float= 0.0,
        temperature: float= 1.0,
        use_tqdm: bool= False
        ):
        ddim_timesteps = self.Get_DDIM_Steps(
            ddim_steps= ddim_steps
            )
        sigmas, alphas, alphas_prev = self.Get_DDIM_Sampling_Parameters(
            ddim_timesteps= ddim_timesteps,
            eta= eta
            )
        sqrt_one_minus_alphas = (1. - alphas).sqrt()

        features = torch.randn(
            size= (encodings.size(0), self.feature_size, encodings.size(2)),
            device= encodings.device
            )

        step_range = reversed(range(ddim_steps))
        if use_tqdm:
            step_range = tqdm(
                step_range,
                desc= '[Diffusion]',
                total= ddim_steps
                )

        for diffusion_steps in step_range:
            noised_predictions = self.denoiser(
                features= features,
                encodings= encodings,
                diffusion_steps= torch.full(
                    size= (encodings.size(0), ),
                    fill_value= diffusion_steps,
                    dtype= torch.long,
                    device= encodings.device
                    )
                )

            feature_starts = (features - sqrt_one_minus_alphas[diffusion_steps] * noised_predictions) / alphas[diffusion_steps].sqrt()
            direction_pointings = (1.0 - alphas_prev[diffusion_steps] - sigmas[diffusion_steps].pow(2.0)).sqrt() * noised_predictions  # sqrt per the standard DDIM update
            noises = sigmas[diffusion_steps] * torch.randn_like(features) * temperature

            features = alphas_prev[diffusion_steps].sqrt() * feature_starts + direction_pointings + noises

        return features

    # https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
    def Get_DDIM_Steps(
        self,
        ddim_steps: int,
        ddim_discr_method: str= 'uniform'
        ):
        if ddim_discr_method == 'uniform':
            ddim_timesteps = torch.arange(0, self.timesteps, self.timesteps // ddim_steps).long()
        elif ddim_discr_method == 'quad':
            ddim_timesteps = torch.linspace(0, math.sqrt(self.timesteps * 0.8), ddim_steps).pow(2.0).long()
        else:
            raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

        ddim_timesteps[-1] = self.timesteps - 1

        return ddim_timesteps

    def Get_DDIM_Sampling_Parameters(self, ddim_timesteps, eta):
        alphas = self.alphas_cumprod[ddim_timesteps]
        alphas_prev = self.alphas_cumprod_prev[ddim_timesteps]
        sigmas = eta * ((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)).sqrt()

        return sigmas, alphas, alphas_prev

class Denoiser(torch.nn.Module):
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            feature_size = self.hp.Sound.N_FFT // 2 + 1

        self.prenet = torch.nn.Sequential(
            Conv1d(
                in_channels= feature_size,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.Mish()
            )

        self.step_ffn = torch.nn.Sequential(
            Diffusion_Embedding(
                channels= self.hp.Diffusion.Size
                ),
            Lambda(lambda x: x.unsqueeze(2)),
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= self.hp.Diffusion.Size * 4,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.Mish(),
            Conv1d(
                in_channels= self.hp.Diffusion.Size * 4,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'linear'
                )
            )

        self.residual_blocks = torch.nn.ModuleList([
            Residual_Block(
                in_channels= self.hp.Diffusion.Size,
                kernel_size= self.hp.Diffusion.Kernel_Size,
                condition_channels= self.hp.Encoder.Size + feature_size
                )
            for _ in range(self.hp.Diffusion.Stack)
            ])

        self.projection = torch.nn.Sequential(
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.ReLU(),
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= feature_size,
                kernel_size= 1
                ),
            )
        torch.nn.init.zeros_(self.projection[-1].weight)    # This zero initialization is a key factor.

    def forward(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        '''
        features: [Batch, Feature_d, Feature_t]
        encodings: [Batch, Enc_d, Feature_t]
        diffusion_steps: [Batch]
        '''
        x = self.prenet(features)

        diffusion_steps = self.step_ffn(diffusion_steps) # [Batch, Res_d, 1]

        skips_list = []
        for residual_block in self.residual_blocks:
            x, skips = residual_block(
                x= x,
                conditions= encodings,
                diffusion_steps= diffusion_steps
                )
            skips_list.append(skips)

        x = torch.stack(skips_list, dim= 0).sum(dim= 0) / math.sqrt(self.hp.Diffusion.Stack)
        x = self.projection(x)

        return x

class Diffusion_Embedding(torch.nn.Module):
    def __init__(
        self,
        channels: int
        ):
        super().__init__()
        self.channels = channels

    def forward(self, x: torch.Tensor):
        half_channels = self.channels // 2  # sine and cosine
        embeddings = math.log(10000.0) / (half_channels - 1)
        embeddings = torch.exp(torch.arange(half_channels, device= x.device) * -embeddings)
        embeddings = x.unsqueeze(1) * embeddings.unsqueeze(0)
        embeddings = torch.cat([embeddings.sin(), embeddings.cos()], dim= -1)

        return embeddings

class Residual_Block(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        kernel_size: int,
        condition_channels: int
        ):
        super().__init__()
        self.in_channels = in_channels

        self.condition = Conv1d(
            in_channels= condition_channels,
            out_channels= in_channels * 2,
            kernel_size= 1
            )
        self.diffusion_step = Conv1d(
            in_channels= in_channels,
            out_channels= in_channels,
            kernel_size= 1
            )

        self.conv = Conv1d(
            in_channels= in_channels,
            out_channels= in_channels * 2,
            kernel_size= kernel_size,
            padding= kernel_size // 2
            )

        self.projection = Conv1d(
            in_channels= in_channels,
            out_channels= in_channels * 2,
            kernel_size= 1
            )

    def forward(
        self,
        x: torch.Tensor,
        conditions: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        residuals = x

        conditions = self.condition(conditions)
        diffusion_steps = self.diffusion_step(diffusion_steps)

        x = self.conv(x + diffusion_steps) + conditions
        x_a, x_b = x.chunk(chunks= 2, dim= 1)
        x = x_a.sigmoid() * x_b.tanh()

        x = self.projection(x)
        x, skips = x.chunk(chunks= 2, dim= 1)

        return (x + residuals) / math.sqrt(2.0), skips
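The training path above relies on the closed-form forward process q(x_t | x_0). Below is a small standalone reference sketch of that relation using the same linear beta schedule the module registers; it is an illustration under assumed toy shapes, not code from the repository.

import torch

timesteps = 100                                  # hp.Diffusion.Max_Step
betas = torch.linspace(1e-4, 0.06, timesteps)    # same schedule as Diffusion.__init__
alphas_cumprod = torch.cumprod(1.0 - betas, dim= 0)

x_0 = torch.randn(1, 80, 200)                    # a fake [Batch, Mel_Dim, Time] feature
t = torch.randint(0, timesteps, (1,))            # one random diffusion step per sample
noise = torch.randn_like(x_0)

# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
# which is what Get_Noise_Epsilon_for_Train feeds to the denoiser.
x_t = alphas_cumprod[t].sqrt()[:, None, None] * x_0 \
    + (1.0 - alphas_cumprod[t]).sqrt()[:, None, None] * noise
print(x_t.shape)   # torch.Size([1, 80, 200])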
Modules/Layer.py
ADDED
@@ -0,0 +1,317 @@
import torch

class Conv1d(torch.nn.Conv1d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain is None:
            pass
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if not self.bias is None:
            torch.nn.init.zeros_(self.bias)

class ConvTranspose1d(torch.nn.ConvTranspose1d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if not self.bias is None:
            torch.nn.init.zeros_(self.bias)

class Conv2d(torch.nn.Conv2d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if not self.bias is None:
            torch.nn.init.zeros_(self.bias)

class ConvTranspose2d(torch.nn.ConvTranspose2d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if not self.bias is None:
            torch.nn.init.zeros_(self.bias)

class Linear(torch.nn.Linear):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            # torch.nn.Linear exposes out_features rather than out_channels.
            assert self.out_features % 2 == 0, 'The out_features of GLU requires even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_features // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_features // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if not self.bias is None:
            torch.nn.init.zeros_(self.bias)

class Lambda(torch.nn.Module):
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)

class Residual(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

class LayerNorm(torch.nn.Module):
    def __init__(self, num_features: int, eps: float= 1e-5):
        super().__init__()

        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(num_features))
        self.beta = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, inputs: torch.Tensor):
        means = inputs.mean(dim= 1, keepdim= True)
        variances = (inputs - means).pow(2.0).mean(dim= 1, keepdim= True)

        x = (inputs - means) * (variances + self.eps).rsqrt()

        shape = [1, -1] + [1] * (x.ndim - 2)

        return x * self.gamma.view(*shape) + self.beta.view(*shape)

class LightweightConv1d(torch.nn.Module):
    '''
    Args:
        input_size: # of channels of the input and output
        kernel_size: convolution channels
        padding: padding
        num_heads: number of heads used. The weight is of shape
            `(num_heads, 1, kernel_size)`
        weight_softmax: normalize the weight with softmax before the convolution

    Shape:
        Input: BxCxT, i.e. (batch_size, input_size, timesteps)
        Output: BxCxT, i.e. (batch_size, input_size, timesteps)

    Attributes:
        weight: the learnable weights of the module of shape
            `(num_heads, 1, kernel_size)`
        bias: the learnable bias of the module of shape `(input_size)`
    '''

    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding=0,
        num_heads=1,
        weight_softmax=False,
        bias=False,
        weight_dropout=0.0,
        w_init_gain= 'linear'
        ):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.num_heads = num_heads
        self.padding = padding
        self.weight_softmax = weight_softmax
        self.weight = torch.nn.Parameter(torch.Tensor(num_heads, 1, kernel_size))
        self.w_init_gain = w_init_gain

        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
            )
        self.reset_parameters()

    def reset_parameters(self):
        if self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if not self.bias is None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, input):
        """
        input size: B x C x T
        output size: B x C x T
        """
        B, C, T = input.size()
        H = self.num_heads

        weight = self.weight
        if self.weight_softmax:
            weight = weight.softmax(dim=-1)

        weight = self.weight_dropout_module(weight)
        # Merge every C/H entries into the batch dimension (C = self.input_size)
        # B x C x T -> (B * C/H) x H x T
        # One can also expand the weight to C x 1 x K by a factor of C/H
        # and do not reshape the input instead, which is slow though
        input = input.view(-1, H, T)
        output = torch.nn.functional.conv1d(input, weight, padding=self.padding, groups=self.num_heads)
        output = output.view(B, C, T)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)

        return output

class FairseqDropout(torch.nn.Module):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        if self.training or self.apply_during_inference:
            return torch.nn.functional.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x

class LinearAttention(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        calc_channels: int,
        num_heads: int,
        dropout_rate: float= 0.1,
        use_scale: bool= True,
        use_residual: bool= True,
        use_norm: bool= True
        ):
        super().__init__()
        assert calc_channels % num_heads == 0
        self.calc_channels = calc_channels
        self.num_heads = num_heads
        self.use_scale = use_scale
        self.use_residual = use_residual
        self.use_norm = use_norm

        self.prenet = Conv1d(
            in_channels= channels,
            out_channels= calc_channels * 3,
            kernel_size= 1,
            bias=False,
            w_init_gain= 'linear'
            )
        self.projection = Conv1d(
            in_channels= calc_channels,
            out_channels= channels,
            kernel_size= 1,
            w_init_gain= 'linear'
            )
        self.dropout = torch.nn.Dropout(p= dropout_rate)

        if use_scale:
            self.scale = torch.nn.Parameter(torch.zeros(1))

        if use_norm:
            self.norm = LayerNorm(num_features= channels)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        '''
        x: [Batch, Enc_d, Enc_t]
        '''
        residuals = x

        x = self.prenet(x)  # [Batch, Calc_d * 3, Enc_t]
        x = x.view(x.size(0), self.num_heads, x.size(1) // self.num_heads, x.size(2))   # [Batch, Head, Calc_d // Head * 3, Enc_t]
        queries, keys, values = x.chunk(chunks= 3, dim= 2)  # [Batch, Head, Calc_d // Head, Enc_t] * 3
        keys = (keys + 1e-5).softmax(dim= 3)

        contexts = keys @ values.permute(0, 1, 3, 2)    # [Batch, Head, Calc_d // Head, Calc_d // Head]
        contexts = contexts.permute(0, 1, 3, 2) @ queries   # [Batch, Head, Calc_d // Head, Enc_t]
        contexts = contexts.view(contexts.size(0), contexts.size(1) * contexts.size(2), contexts.size(3))  # [Batch, Calc_d, Enc_t]
        contexts = self.projection(contexts)    # [Batch, Enc_d, Enc_t]

        if self.use_scale:
            contexts = self.scale * contexts

        contexts = self.dropout(contexts)

        if self.use_residual:
            contexts = contexts + residuals

        if self.use_norm:
            contexts = self.norm(contexts)

        return contexts
Modules/Modules.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import Namespace
|
2 |
+
import torch
|
3 |
+
import math
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
from .Layer import Conv1d, LayerNorm, LinearAttention
|
7 |
+
from .Diffusion import Diffusion
|
8 |
+
|
9 |
+
class DiffSinger(torch.nn.Module):
|
10 |
+
def __init__(self, hyper_parameters: Namespace):
|
11 |
+
super().__init__()
|
12 |
+
self.hp = hyper_parameters
|
13 |
+
|
14 |
+
self.encoder = Encoder(self.hp)
|
15 |
+
self.diffusion = Diffusion(self.hp)
|
16 |
+
|
17 |
+
def forward(
|
18 |
+
self,
|
19 |
+
tokens: torch.LongTensor,
|
20 |
+
notes: torch.LongTensor,
|
21 |
+
durations: torch.LongTensor,
|
22 |
+
lengths: torch.LongTensor,
|
23 |
+
genres: torch.LongTensor,
|
24 |
+
singers: torch.LongTensor,
|
25 |
+
features: Union[torch.FloatTensor, None]= None,
|
26 |
+
ddim_steps: Union[int, None]= None
|
27 |
+
):
|
28 |
+
encodings, linear_predictions = self.encoder(
|
29 |
+
tokens= tokens,
|
30 |
+
notes= notes,
|
31 |
+
durations= durations,
|
32 |
+
lengths= lengths,
|
33 |
+
genres= genres,
|
34 |
+
singers= singers
|
35 |
+
) # [Batch, Enc_d, Feature_t]
|
36 |
+
|
37 |
+
encodings = torch.cat([encodings, linear_predictions], dim= 1) # [Batch, Enc_d + Feature_d, Feature_t]
|
38 |
+
|
39 |
+
if not features is None or ddim_steps is None or ddim_steps == self.hp.Diffusion.Max_Step:
|
40 |
+
diffusion_predictions, noises, epsilons = self.diffusion(
|
41 |
+
encodings= encodings,
|
42 |
+
features= features,
|
43 |
+
)
|
44 |
+
else:
|
45 |
+
noises, epsilons = None, None
|
46 |
+
diffusion_predictions = self.diffusion.DDIM(
|
47 |
+
encodings= encodings,
|
48 |
+
ddim_steps= ddim_steps
|
49 |
+
)
|
50 |
+
|
51 |
+
return linear_predictions, diffusion_predictions, noises, epsilons
|
52 |
+
|
53 |
+
|
54 |
+
class Encoder(torch.nn.Module):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
hyper_parameters: Namespace
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
self.hp = hyper_parameters
|
61 |
+
|
62 |
+
if self.hp.Feature_Type == 'Mel':
|
63 |
+
            self.feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1

        self.token_embedding = torch.nn.Embedding(
            num_embeddings= self.hp.Tokens,
            embedding_dim= self.hp.Encoder.Size
            )
        self.note_embedding = torch.nn.Embedding(
            num_embeddings= self.hp.Notes,
            embedding_dim= self.hp.Encoder.Size
            )
        self.duration_embedding = Duration_Positional_Encoding(
            num_embeddings= self.hp.Durations,
            embedding_dim= self.hp.Encoder.Size
            )
        self.genre_embedding = torch.nn.Embedding(
            num_embeddings= self.hp.Genres,
            embedding_dim= self.hp.Encoder.Size,
            )
        self.singer_embedding = torch.nn.Embedding(
            num_embeddings= self.hp.Singers,
            embedding_dim= self.hp.Encoder.Size,
            )
        torch.nn.init.xavier_uniform_(self.token_embedding.weight)
        torch.nn.init.xavier_uniform_(self.note_embedding.weight)
        torch.nn.init.xavier_uniform_(self.genre_embedding.weight)
        torch.nn.init.xavier_uniform_(self.singer_embedding.weight)

        self.fft_blocks = torch.nn.ModuleList([
            FFT_Block(
                channels= self.hp.Encoder.Size,
                num_head= self.hp.Encoder.ConvFFT.Head,
                ffn_kernel_size= self.hp.Encoder.ConvFFT.FFN.Kernel_Size,
                dropout_rate= self.hp.Encoder.ConvFFT.Dropout_Rate
                )
            for _ in range(self.hp.Encoder.ConvFFT.Stack)
            ])

        self.linear_projection = Conv1d(
            in_channels= self.hp.Encoder.Size,
            out_channels= self.feature_size,
            kernel_size= 1,
            bias= True,
            w_init_gain= 'linear'
            )

    def forward(
        self,
        tokens: torch.Tensor,
        notes: torch.Tensor,
        durations: torch.Tensor,
        lengths: torch.Tensor,
        genres: torch.Tensor,
        singers: torch.Tensor
        ):
        x = \
            self.token_embedding(tokens) + \
            self.note_embedding(notes) + \
            self.duration_embedding(durations) + \
            self.genre_embedding(genres).unsqueeze(1) + \
            self.singer_embedding(singers).unsqueeze(1)
        x = x.permute(0, 2, 1)  # [Batch, Enc_d, Enc_t]

        for block in self.fft_blocks:
            x = block(x, lengths)  # [Batch, Enc_d, Enc_t]

        linear_predictions = self.linear_projection(x)  # [Batch, Feature_d, Enc_t]

        return x, linear_predictions

class FFT_Block(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        num_head: int,
        ffn_kernel_size: int,
        dropout_rate: float= 0.1,
        ) -> None:
        super().__init__()

        self.attention = LinearAttention(
            channels= channels,
            calc_channels= channels,
            num_heads= num_head,
            dropout_rate= dropout_rate
            )

        self.ffn = FFN(
            channels= channels,
            kernel_size= ffn_kernel_size,
            dropout_rate= dropout_rate
            )

    def forward(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor
        ) -> torch.Tensor:
        '''
        x: [Batch, Dim, Time]
        '''
        masks = (~Mask_Generate(lengths= lengths, max_length= torch.ones_like(x[0, 0]).sum())).unsqueeze(1).float()  # float mask

        # Attention + Dropout + LayerNorm
        x = self.attention(x)

        # FFN + Dropout + LayerNorm
        x = self.ffn(x, masks)

        return x * masks

class FFN(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int,
        dropout_rate: float= 0.1,
        ) -> None:
        super().__init__()
        self.conv_0 = Conv1d(
            in_channels= channels,
            out_channels= channels,
            kernel_size= kernel_size,
            padding= (kernel_size - 1) // 2,
            w_init_gain= 'relu'
            )
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p= dropout_rate)
        self.conv_1 = Conv1d(
            in_channels= channels,
            out_channels= channels,
            kernel_size= kernel_size,
            padding= (kernel_size - 1) // 2,
            w_init_gain= 'linear'
            )
        self.norm = LayerNorm(
            num_features= channels,
            )

    def forward(
        self,
        x: torch.Tensor,
        masks: torch.Tensor
        ) -> torch.Tensor:
        '''
        x: [Batch, Dim, Time]
        '''
        residuals = x

        x = self.conv_0(x * masks)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv_1(x * masks)
        x = self.dropout(x)
        x = self.norm(x + residuals)

        return x * masks

# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
class Duration_Positional_Encoding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        ):
        positional_embedding = torch.zeros(num_embeddings, embedding_dim)
        position = torch.arange(0, num_embeddings, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        positional_embedding[:, 0::2] = torch.sin(position * div_term)
        positional_embedding[:, 1::2] = torch.cos(position * div_term)
        super().__init__(
            num_embeddings= num_embeddings,
            embedding_dim= embedding_dim,
            _weight= positional_embedding
            )
        self.weight.requires_grad = False

        self.alpha = torch.nn.Parameter(
            data= torch.ones(1) * 0.01,
            requires_grad= True
            )

    def forward(self, durations):
        '''
        durations: [Batch, Length]
        '''
        return self.alpha * super().forward(durations)  # [Batch, Dim, Length]

@torch.jit.script
def get_pe(x: torch.Tensor, pe: torch.Tensor):
    pe = pe.repeat(1, 1, math.ceil(x.size(2) / pe.size(2)))
    return pe[:, :, :x.size(2)]

def Mask_Generate(lengths: torch.Tensor, max_length: Union[torch.Tensor, int, None]= None):
    '''
    lengths: [Batch]
    max_lengths: an int value. If None, max_lengths == max(lengths)
    '''
    max_length = max_length or torch.max(lengths)
    sequence = torch.arange(max_length)[None, :].to(lengths.device)
    return sequence >= lengths[:, None]  # [Batch, Time]
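A minimal sketch of how the Mask_Generate helper above behaves, assuming it is imported from Modules/Modules.py; the example lengths are illustrative only:

import torch

lengths = torch.tensor([3, 5])   # two sequences of length 3 and 5
masks = Mask_Generate(lengths)   # [Batch, Time] boolean, Time == max(lengths)
# masks == [[False, False, False,  True,  True],
#           [False, False, False, False, False]]
# FFT_Block inverts and floats this mask, so frames beyond each sequence's length
# are zeroed after the attention and FFN stages.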
Pattern_Generator.py
ADDED
@@ -0,0 +1,64 @@
import numpy as np
import mido, os, pickle, yaml, argparse, math, librosa, hgtk, logging
from tqdm import tqdm
from pysptk.sptk import rapt
from typing import List, Tuple
from argparse import Namespace  # for type
import torch
from typing import Dict

from meldataset import mel_spectrogram, spectrogram, spec_energy
from Arg_Parser import Recursive_Parse

def Convert_Feature_Based_Music(
    music: List[Tuple[float, str, int]],
    sample_rate: int,
    frame_shift: int,
    consonant_duration: int= 3,
    equality_duration: bool= False
    ):
    previous_used = 0
    lyrics = []
    notes = []
    durations = []
    for message_time, lyric, note in music:
        duration = round(message_time * sample_rate) + previous_used
        previous_used = duration % frame_shift
        duration = duration // frame_shift

        if lyric == '<X>':
            lyrics.append(lyric)
            notes.append(note)
            durations.append(duration)
        else:
            lyrics.extend(Decompose(lyric))
            notes.extend([note] * 3)
            if equality_duration or duration < consonant_duration * 3:
                split_duration = [duration // 3] * 3
                split_duration[1] += duration % 3
                durations.extend(split_duration)
            else:
                durations.extend([
                    consonant_duration,  # onset
                    duration - consonant_duration * 2,  # nucleus
                    consonant_duration  # coda
                    ])

    return lyrics, notes, durations

def Expand_by_Duration(
    lyrics: List[str],
    notes: List[int],
    durations: List[int],
    ):
    lyrics = sum([[lyric] * duration for lyric, duration in zip(lyrics, durations)], [])
    notes = sum([*[[note] * duration for note, duration in zip(notes, durations)]], [])
    durations = [index for duration in durations for index in range(duration)]

    return lyrics, notes, durations

def Decompose(syllable: str):
    onset, nucleus, coda = hgtk.letter.decompose(syllable)
    coda += '_'

    return onset, nucleus, coda
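A small usage sketch of the two helpers above with a made-up three-note score; the sample_rate and frame_shift values are assumptions (the real ones come from Hyper_Parameters.yaml):

# (seconds, syllable, MIDI note); '<X>' marks a rest.
music = [(0.5, '떴', 76), (0.25, '<X>', 0), (0.5, '다', 74)]

lyrics, notes, durations = Convert_Feature_Based_Music(
    music= music,
    sample_rate= 22050,  # assumed
    frame_shift= 256,    # assumed hop size in samples
    )
# Each sung syllable is decomposed into onset/nucleus/coda and contributes three
# entries to lyrics/notes/durations; '<X>' stays as a single entry.

frame_lyrics, frame_notes, frame_positions = Expand_by_Duration(lyrics, notes, durations)
# Frame-level sequences: each symbol/note is repeated for its duration, and
# frame_positions counts 0..duration-1 within every note.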
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Diffsvs
+emoji: 🐢
+colorFrom: blue
+colorTo: blue
 sdk: streamlit
 sdk_version: 1.17.0
 app_file: app.py
YAML/Genre_Info.yaml
ADDED
@@ -0,0 +1 @@
Children: 0
YAML/Log_Energy_Info.yaml
ADDED
@@ -0,0 +1,3 @@
CSD:
  Mean: 3.540642499923706
  Std: 2.1372854709625244
YAML/Log_F0_Info.yaml
ADDED
@@ -0,0 +1,3 @@
CSD:
  Mean: 5.851496696472168
  Std: 0.2526451647281647
YAML/Mel_Range_Info.yaml
ADDED
@@ -0,0 +1,3 @@
CSD:
  Max: 2.6226840019226074
  Min: -11.512925148010254
YAML/Singer_Info.yaml
ADDED
@@ -0,0 +1 @@
CSD: 0
YAML/Spectrogram_Range_Info.yaml
ADDED
@@ -0,0 +1,3 @@
CSD:
  Max: 5.292316913604736
  Min: -10.36163330078125
YAML/Token.yaml
ADDED
@@ -0,0 +1,71 @@
<E>: 1
<S>: 0
<X>: 2
_: 3
"\u3131": 4
"\u3131_": 5
"\u3132": 6
"\u3132_": 7
"\u3133_": 8
"\u3134": 9
"\u3134_": 10
"\u3135_": 11
"\u3136_": 12
"\u3137": 13
"\u3137_": 14
"\u3138": 15
"\u3139": 16
"\u3139_": 17
"\u313A_": 18
"\u313B_": 19
"\u313C_": 20
"\u313D_": 21
"\u313E_": 22
"\u313F_": 23
"\u3140_": 24
"\u3141": 25
"\u3141_": 26
"\u3142": 27
"\u3142_": 28
"\u3143": 29
"\u3144_": 30
"\u3145": 31
"\u3145_": 32
"\u3146": 33
"\u3146_": 34
"\u3147": 35
"\u3147_": 36
"\u3148": 37
"\u3148_": 38
"\u3149": 39
"\u314A": 40
"\u314A_": 41
"\u314B": 42
"\u314B_": 43
"\u314C": 44
"\u314C_": 45
"\u314D": 46
"\u314D_": 47
"\u314E": 48
"\u314E_": 49
"\u314F": 50
"\u3150": 51
"\u3151": 52
"\u3152": 53
"\u3153": 54
"\u3154": 55
"\u3155": 56
"\u3156": 57
"\u3157": 58
"\u3158": 59
"\u3159": 60
"\u315A": 61
"\u315B": 62
"\u315C": 63
"\u315D": 64
"\u315E": 65
"\u315F": 66
"\u3160": 67
"\u3161": 68
"\u3162": 69
"\u3163": 70
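A short sketch of how this token table can be combined with the Decompose helper from Pattern_Generator.py; the path is the file added above, everything else is illustrative:

import yaml, hgtk

token_dict = yaml.safe_load(open('YAML/Token.yaml', encoding='utf-8'))  # jamo -> index

onset, nucleus, coda = hgtk.letter.decompose('날')  # ㄴ, ㅏ, ㄹ
coda += '_'                                         # codas are stored with a trailing underscore
token_ids = [token_dict[x] for x in (onset, nucleus, coda)]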
app.py
ADDED
@@ -0,0 +1,81 @@
import streamlit as st

from Inference import Inferencer

def app_diffsingerkr():
    if not 'diffsingerkr_duration' in st.session_state.keys():
        st.session_state.diffsingerkr_duration = ''
    if not 'diffsingerkr_lyric' in st.session_state.keys():
        st.session_state.diffsingerkr_lyric = ''
    if not 'diffsingerkr_note' in st.session_state.keys():
        st.session_state.diffsingerkr_note = ''
    if not 'inferencer' in st.session_state.keys():
        st.session_state.inferencer = Inferencer(
            hp_path= 'Hyper_Parameters.yaml',
            checkpoint_path= 'Checkpoint/S_200000.pt',
            batch_size= 1
            )

    st.title('DiffSinger-KR')
    st.markdown('* This code is an implementation of DiffSinger for Korean.')
    st.markdown('* When music score which is note, duration, and lyric information are entered, singing voices are synthesized accordingly.')
    st.markdown('* Due to the range of the trained dataset, the supported notes are between 65 and 89.')
    st.markdown('* Please refer to the [here](https://github.com/CODEJIN/DiffSingerKR) for the source code for training the model.')

    st.markdown('''---''')
    status_indicator = st.empty()
    status_indicator.header('Insert the music!')
    st.markdown('''---''')
    example1_col, example2_col, example3_col, _ = st.columns(4)
    if example1_col.button('Example 1'):
        st.session_state.diffsingerkr_duration = '0.52,0.17,0.35,0.35,0.35,0.35,0.70,0.35,0.35,0.70,0.35,0.35,0.70,0.52,0.17,0.35,0.35,0.35,0.35,0.70,0.35,0.35,0.35,0.35,1.39'
        st.session_state.diffsingerkr_lyric = '떴,다,떴,다,비,행,기,날,아,라,날,아,라,높,이,높,이,날,아,라,우,리,비,행,기'
        st.session_state.diffsingerkr_note = '76,74,72,74,76,76,76,74,74,74,76,79,79,76,74,72,74,76,76,76,74,74,76,74,72'
        st.experimental_rerun()
    if example2_col.button('Example 2'):
        st.session_state.diffsingerkr_duration = '0.53,0.52,0.50,0.57,0.58,0.46,0.48,0.50,0.37,0.13,0.43,0.21,0.57,0.43,0.49,1.44,0.26,0.49,0.14,0.13,0.57,0.26,0.06,0.15,0.63,0.26,0.51,0.20,0.48,0.72,0.22'
        st.session_state.diffsingerkr_lyric = '만,나,고,<X>,난,외,로,움,을,<X>,알,았,어,내,겐,<X>,관,심,조,<X>,차,<X>,없,<X>,다,는,걸,<X>,알,면,서'
        st.session_state.diffsingerkr_note = '76,78,79,0,71,74,72,71,72,0,71,69,69,71,74,0,79,78,79,0,71,0,74,0,74,72,72,0,71,71,69'
        st.experimental_rerun()
    if example3_col.button('Example 3'):
        st.session_state.diffsingerkr_duration = '0.33,0.16,0.33,0.49,0.33,0.16,0.81,0.33,0.16,0.16,0.33,0.16,0.49,0.16,0.82,0.33,0.16,0.33,0.49,0.33,0.16,0.33,0.49,0.33,0.33,0.16,0.33,1.47,0.33,0.16,0.33,0.49,0.33,0.16,0.81,0.33,0.16,0.16,0.33,0.16,0.49,0.16,0.82,0.33,0.16,0.33,0.16,0.33,0.49,0.16,0.33,0.33,0.33,0.33,0.16,0.33,0.82'
        st.session_state.diffsingerkr_lyric = '마,음,울,적,한,날,에,<X>,거,리,를,걸,어,보,고,향,기,로,운,칵,테,일,에,취,해,도,보,고,한,편,의,시,가,있,는,<X>,전,시,회,장,도,가,고,밤,새,도,<X>,록,그,리,움,에,편,질,쓰,고,파'
        st.session_state.diffsingerkr_note = '80,80,80,87,85,84,82,0,84,84,84,85,84,79,79,77,77,77,80,80,78,77,75,77,80,79,80,82,80,80,80,87,85,84,82,0,84,84,84,85,84,79,79,77,77,77,79,80,80,77,75,75,77,80,79,82,80'
        st.experimental_rerun()
    st.markdown('''---''')
    duration = st.text_input('Duration', value= st.session_state.diffsingerkr_duration)
    lyric = st.text_input('Lyric', value= st.session_state.diffsingerkr_lyric)
    note = st.text_input('Note', value= st.session_state.diffsingerkr_note)
    singer = 'CSD'
    genre = 'Children'
    key_adjustment = st.select_slider(
        label= 'Key adjustment',
        options= [x for x in range(-6, 7)],
        value= 0
        )

    if st.button("Generate!"):
        if duration != '' and lyric != '' and note != '':
            status_indicator.header('Generating...')
            audio = st.session_state.inferencer.Inference_Epoch(
                message_times_list= [[float(x) for x in duration.strip().split(',')]],
                lyrics= [[x for x in lyric.strip().split(',')]],
                notes= [[
                    (int(x) + key_adjustment if int(x) != 0 else int(x))
                    for x in note.strip().split(',')
                    ]],
                singers= [singer],
                genres= [genre]
                )[0]

            st.audio(
                audio,
                format="audio/wav",
                start_time=0,
                sample_rate= st.session_state.inferencer.hp.Sound.Sample_Rate
                )

            status_indicator.header('Done.')

if __name__ == '__main__':
    app_diffsingerkr()
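A condensed restatement, for illustration only, of the parsing that the Generate! button performs before calling Inference_Epoch; the shortened inputs below are the first three entries of Example 1:

duration = '0.52,0.17,0.35'
lyric = '떴,다,떴'
note = '76,74,72'

message_times = [float(x) for x in duration.strip().split(',')]  # [0.52, 0.17, 0.35]
lyrics = [x for x in lyric.strip().split(',')]                   # ['떴', '다', '떴']
notes = [int(x) for x in note.strip().split(',')]                # [76, 74, 72]; nonzero notes get key_adjustment added
# Inference_Epoch receives each of these wrapped in one more list (a batch of one),
# together with singers=['CSD'] and genres=['Children'].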
meldataset.py
ADDED
@@ -0,0 +1,230 @@
###############################################################################
# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
###############################################################################

import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec

def spectrogram(y, n_fft, hop_size, win_size, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global hann_window
    hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = spectral_normalize_torch(spec)

    return spec

def spec_energy(y, n_fft, hop_size, win_size, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global hann_window
    hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    energy = torch.norm(spec, dim= 1)

    return energy

def get_dataset_filelist(a):
    with open(a.input_training_file, 'r', encoding='utf-8') as fi:
        training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
                          for x in fi.read().split('\n') if len(x) > 0]

    with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
        validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
                            for x in fi.read().split('\n') if len(x) > 0]
    return training_files, validation_files


class MelDataset(torch.utils.data.Dataset):
    def __init__(self, training_files, segment_size, n_fft, num_mels,
                 hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

    def __getitem__(self, index):
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            audio, sampling_rate = load_wav(filename)
            audio = audio / MAX_WAV_VALUE
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            self.cached_wav = audio
            if sampling_rate != self.sampling_rate:
                raise ValueError("{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate))
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio)
        audio = audio.unsqueeze(0)

        if not self.fine_tuning:
            if self.split:
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start:audio_start+self.segment_size]
                else:
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

            mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                  self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
                                  center=False)
        else:
            mel = np.load(
                os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start:mel_start + frames_per_seg]
                    audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
                else:
                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        return len(self.audio_files)
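A minimal sketch of calling mel_spectrogram above on a dummy waveform. The STFT/mel settings are assumptions for illustration (they are not read from Hyper_Parameters.yaml here), and it presumes torch/librosa versions compatible with the positional librosa_mel_fn and torch.stft calls used in this file:

import torch

y = torch.zeros(1, 22050)  # one second of silence at an assumed 22.05 kHz sample rate
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000,
    )
# mel: [1, 80, frames]; magnitudes are log-compressed by spectral_normalize_torch.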
requirements.txt
ADDED
@@ -0,0 +1,7 @@
streamlit
torch
librosa
mido
hgtk
pysptk
matplotlib
vocoder.pts
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b47a5d03d744861f94ee973294317f738ccc6dc6d27bafa5d8db5ed18f95566
size 55884400