Respair committed
Commit 302641c · verified · 1 Parent(s): 9c3ea77

Upload Sana/models.py with huggingface_hub

Files changed (1): Sana/models.py +1011 -0
Sana/models.py ADDED
@@ -0,0 +1,1011 @@
import os
import os.path as osp

import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

from torchaudio.models import Conformer

from Utils.ASR.models import ASRCNN
from Utils.JDC.model import JDCNet


from transformers import AutoModelForSequenceClassification, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer

from Modules.KotoDama_sampler import KotoDama_Prompt, KotoDama_Text
from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
from Modules.diffusion.diffusion import AudioDiffusionConditional
from Modules.diffusion.audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler, DiffusionUpsampler

from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator

from munch import Munch
import yaml

# from hflayers import Hopfield, HopfieldPooling, HopfieldLayer
# from hflayers.auxiliary.data import BitPatternSet

# Import auxiliary modules.
from distutils.version import LooseVersion
from typing import List, Tuple

# from liger_kernel.ops.layer_norm import LigerLayerNormFunction
# from liger_kernel.transformers.experimental.embedding import nn.Embedding

from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)


class LearnedDownSample(nn.Module):
    def __init__(self, layer_type, dim_in):
        super().__init__()
        self.layer_type = layer_type

        if self.layer_type == 'none':
            self.conv = nn.Identity()
        elif self.layer_type == 'timepreserve':
            self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
        elif self.layer_type == 'half':
            self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
        else:
            raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)

    def forward(self, x):
        return self.conv(x)

class LearnedUpSample(nn.Module):
    def __init__(self, layer_type, dim_in):
        super().__init__()
        self.layer_type = layer_type

        if self.layer_type == 'none':
            self.conv = nn.Identity()
        elif self.layer_type == 'timepreserve':
            self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
        elif self.layer_type == 'half':
            self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
        else:
            raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)

    def forward(self, x):
        return self.conv(x)

class DownSample(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        elif self.layer_type == 'timepreserve':
            return F.avg_pool2d(x, (2, 1))
        elif self.layer_type == 'half':
            if x.shape[-1] % 2 != 0:
                x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
            return F.avg_pool2d(x, 2)
        else:
            raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)


class UpSample(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        elif self.layer_type == 'timepreserve':
            return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
        elif self.layer_type == 'half':
            return F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)

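# A hedged shape sketch (illustration only, not part of the committed code): the
# 'timepreserve' variants halve only the frequency axis, while 'half' halves both.
#
#   x = torch.randn(1, 48, 80, 100)        # [B, C, freq, time]
#   DownSample('timepreserve')(x).shape    # -> [1, 48, 40, 100]
#   DownSample('half')(x).shape            # -> [1, 48, 40, 50]
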
class ResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample='none'):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = DownSample(downsample)
        self.downsample_res = LearnedDownSample(downsample, dim_in)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
        self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = self.downsample(x)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        x = self.downsample_res(x)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / math.sqrt(2)  # unit variance

# class StyleEncoder(nn.Module):
#     def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
#         super().__init__()
#         blocks = []
#         blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]

#         repeat_num = 4
#         for _ in range(repeat_num):
#             dim_out = min(dim_in*2, max_conv_dim)
#             blocks += [ResBlk(dim_in, dim_out, downsample='half')]
#             dim_in = dim_out

#         blocks += [nn.LeakyReLU(0.2)]
#         blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
#         blocks += [nn.AdaptiveAvgPool2d(1)]
#         blocks += [nn.LeakyReLU(0.2)]
#         self.shared = nn.Sequential(*blocks)

#         self.unshared = nn.Linear(dim_out, style_dim)

#     def forward(self, x):
#         h = self.shared(x)
#         h = h.view(h.size(0), -1)
#         s = self.unshared(h)

#         return s

class StyleEncoder(nn.Module):
    def __init__(self, mel_dim=80, hidden_dim=512, style_dim=128, num_heads=8, num_layers=6):
        super().__init__()

        self.mel_proj = nn.Conv1d(mel_dim, hidden_dim, kernel_size=3, padding=1)

        self.conformer_pre = Conformer(
            input_dim=hidden_dim,
            num_heads=num_heads,
            ffn_dim=hidden_dim * 2,
            num_layers=1,
            depthwise_conv_kernel_size=31,
            use_group_norm=True,
        )
        self.conformer_body = Conformer(
            input_dim=hidden_dim,
            num_heads=num_heads,
            ffn_dim=hidden_dim * 2,
            num_layers=num_layers - 1,
            depthwise_conv_kernel_size=15,
            use_group_norm=True,
        )

        self.out = nn.Linear(hidden_dim, style_dim)

    def forward(self, mel):

        mel = self.mel_proj(mel)

        mel_len = mel.size(-1)  # length of mel

        batch_size = mel.size(0)
        input_lengths = torch.full((batch_size,), mel_len, device=mel.device)

        x, output_lengths = self.conformer_pre(mel.transpose(-1, -2), input_lengths)
        x, output_lengths = self.conformer_body(x, input_lengths)

        x = x.transpose(-1, -2)

        s = self.out(x.mean(dim=-1))

        return s

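# A hedged usage sketch (illustration only, not part of the committed code): with
# the constructor defaults above, the encoder maps a batch of mel-spectrograms to
# one style vector per utterance.
#
#   style_enc = StyleEncoder(mel_dim=80, hidden_dim=512, style_dim=128)
#   mel = torch.randn(2, 80, 120)   # [batch, n_mels, frames]
#   s = style_enc(mel)              # -> [2, 128] utterance-level style vector
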
class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)

class Discriminator2d(nn.Module):
    def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
        super().__init__()
        blocks = []
        blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]

        for lid in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample='half')]
            dim_in = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.AdaptiveAvgPool2d(1)]
        blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
        self.main = nn.Sequential(*blocks)

    def get_feature(self, x):
        features = []
        for l in self.main:
            x = l(x)
            features.append(x)
        out = features[-1]
        out = out.view(out.size(0), -1)  # (batch, num_domains)
        return out, features

    def forward(self, x):
        out, features = self.get_feature(x)
        out = out.squeeze()  # (batch)
        return out, features

class ResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample='none', dropout_p=0.2):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample_type = downsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)
        self.dropout_p = dropout_p

        if self.downsample_type == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        if self.normalize:
            self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def downsample(self, x):
        if self.downsample_type == 'none':
            return x
        else:
            if x.shape[-1] % 2 != 0:
                x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
            return F.avg_pool1d(x, 2)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        x = self.downsample(x)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)

        x = self.conv1(x)
        x = self.pool(x)
        if self.normalize:
            x = self.norm2(x)

        x = self.actv(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)

        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / math.sqrt(2)  # unit variance

class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)

class TextEncoder(nn.Module):
    def __init__(self, channels, kernel_size, depth, n_symbols, conv1d_kernel_size, qkv_proj_blocksize, num_heads, actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)

        self.prepare_projection = LinearNorm(channels, channels // 2)
        self.post_projection = LinearNorm(channels // 2, channels)
        self.cfg = xLSTMBlockStackConfig(
            mlstm_block=mLSTMBlockConfig(
                mlstm=mLSTMLayerConfig(
                    conv1d_kernel_size=conv1d_kernel_size, qkv_proj_blocksize=qkv_proj_blocksize, num_heads=num_heads
                )
            ),
            context_length=channels,
            num_blocks=num_heads * 2,
            embedding_dim=channels // 2,
        )

        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
                actv,
                nn.Dropout(0.2),
            ))
        # self.cnn = nn.Sequential(*self.cnn)

        self.lstm = xLSTMBlockStack(self.cfg)

    def forward(self, x, input_lengths, m):

        x = self.embedding(x)  # [B, T, emb]

        x = x.transpose(1, 2)  # [B, emb, T]
        m = m.to(input_lengths.device).unsqueeze(1)
        x.masked_fill_(m, 0.0)

        for c in self.cnn:
            x = c(x)
            x.masked_fill_(m, 0.0)

        x = x.transpose(1, 2)  # [B, T, chn]

        input_lengths = input_lengths.cpu().numpy()

        x = self.prepare_projection(x)

        # x = nn.utils.rnn.pack_padded_sequence(
        #     x, input_lengths, batch_first=True, enforce_sorted=False)

        # self.lstm.flatten_parameters()
        x = self.lstm(x)

        x = self.post_projection(x)
        # x, _ = nn.utils.rnn.pad_packed_sequence(
        #     x, batch_first=True)

        x = x.transpose(-1, -2)
        # x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

        # x_pad[:, :, :x.shape[-1]] = x
        # x = x_pad.to(x.device)

        x.masked_fill_(m, 0.0)

        return x

    def inference(self, x):
        x = self.embedding(x)
        x = x.transpose(1, 2)
        # self.cnn is a ModuleList, so apply each conv block in turn
        for c in self.cnn:
            x = c(x)
        x = x.transpose(1, 2)
        # mirror forward(): project down to the xLSTM embedding dim and back up
        x = self.prepare_projection(x)
        x = self.lstm(x)
        x = self.post_projection(x)
        return x

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

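# A hedged usage sketch (illustration only; the token count 178 and the text
# lengths are made-up numbers, channels/depth mirror the values used in
# build_model below):
#
#   text_enc = TextEncoder(channels=512, kernel_size=5, depth=3, n_symbols=178,
#                          conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=8)
#   tokens = torch.randint(0, 178, (2, 50))   # [B, T] phoneme ids
#   lengths = torch.tensor([50, 42])
#   mask = text_enc.length_to_mask(lengths)   # [B, T], True over padded positions
#   h = text_enc(tokens, lengths, mask)       # -> [2, 512, 50]
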
class AdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)

        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta

class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')

class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / math.sqrt(2)
        return out

class AdaLayerNorm(nn.Module):
    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.fc = nn.Linear(style_dim, channels*2)

    def forward(self, x, s):
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)

        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)

        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)

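# A hedged shape sketch (illustration only): AdaIN1d predicts per-channel
# (gamma, beta) from a style vector and applies them after instance norm.
#
#   adain = AdaIN1d(style_dim=128, num_features=512)
#   feats = torch.randn(2, 512, 100)   # [B, C, T]
#   style = torch.randn(2, 128)        # [B, style_dim]
#   out = adain(feats, style)          # -> [2, 512, 100]
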
class ProsodyPredictor(nn.Module):

    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
        super().__init__()

        self.cfg = xLSTMBlockStackConfig(
            mlstm_block=mLSTMBlockConfig(
                mlstm=mLSTMLayerConfig(
                    conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
                )
            ),
            context_length=d_hid,
            num_blocks=8,
            embedding_dim=d_hid + style_dim,
        )

        self.cfg_pred = xLSTMBlockStackConfig(
            mlstm_block=mLSTMBlockConfig(
                mlstm=mLSTMLayerConfig(
                    conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
                )
            ),
            context_length=4096,
            num_blocks=8,
            embedding_dim=d_hid + style_dim,
        )

        # self.shared = Hopfield(input_size=d_hid + style_dim,
        #                        hidden_size=d_hid // 2,
        #                        num_heads=32,
        #                        # scaling=.75,
        #                        add_zero_association=True,
        #                        batch_first=True)

        # if you want to use Hopfield, uncomment the block above and comment out the "self.shared" below

        self.text_encoder = DurationEncoder(sty_dim=style_dim,
                                            d_model=d_hid,
                                            nlayers=nlayers,
                                            dropout=dropout)

        self.lstm = xLSTMBlockStack(self.cfg)

        self.prepare_projection = nn.Linear(d_hid + style_dim, d_hid)

        self.duration_proj = LinearNorm(d_hid, max_dur)

        self.shared = xLSTMBlockStack(self.cfg_pred)

        # self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)

        self.F0 = nn.ModuleList()
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.N = nn.ModuleList()
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)


    def forward(self, texts, style, text_lengths=None, alignment=None, m=None, f0=False):

        if f0:
            x, s = texts, style
            # x = self.prepare_projection(x.transpose(-1, -2))
            # x = self.shared(x)

            x = self.shared(x.transpose(-1, -2))
            x = self.prepare_projection(x)

            F0 = x.transpose(-1, -2)
            for block in self.F0:
                F0 = block(F0, s)
            F0 = self.F0_proj(F0)

            N = x.transpose(-1, -2)
            for block in self.N:
                N = block(N, s)
            N = self.N_proj(N)

            return F0.squeeze(1), N.squeeze(1)

        else:
            # Problem is here
            d = self.text_encoder(texts, style, text_lengths, m)

            batch_size = d.shape[0]
            text_size = d.shape[1]

            # predict duration

            input_lengths = text_lengths.cpu().numpy()

            # x = nn.utils.rnn.pack_padded_sequence(
            #     d, input_lengths, batch_first=True, enforce_sorted=False)

            x = d  # the xLSTM stack handles variable sequence lengths, so no packing is needed

            m = m.to(text_lengths.device).unsqueeze(1)

            # self.lstm.flatten_parameters()
            x = self.lstm(x)  # despite the name, this is the xLSTM stack, not an nn.LSTM
            x = self.prepare_projection(x)

            # x, _ = nn.utils.rnn.pad_packed_sequence(
            #     x, batch_first=True)

            # x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])

            # x_pad[:, :x.shape[1], :] = x
            # x = x_pad.to(x.device)

            x = x.transpose(-1, -2)
            x = x.permute(0, 2, 1)
            duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))

            en = (d.transpose(-1, -2) @ alignment)

            return duration.squeeze(-1), en


    def F0Ntrain(self, x, s):

        # x = self.prepare_projection(x.transpose(-1, -2))
        # x = self.shared(x)

        x = self.shared(x.transpose(-1, -2))
        x = self.prepare_projection(x)

        F0 = x.transpose(-1, -2)

        for block in self.F0:
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)

        N = x.transpose(-1, -2)
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)

        return F0.squeeze(1), N.squeeze(1)

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

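# A hedged shape sketch for the F0/energy path (illustration only; the sizes
# below are placeholders, style_dim/d_hid mirror build_model below):
#
#   pred = ProsodyPredictor(style_dim=128, d_hid=512, nlayers=3)
#   en = torch.randn(2, 512 + 128, 200)   # style-conditioned prosody features
#   s = torch.randn(2, 128)               # acoustic style vector
#   F0, N = pred.F0Ntrain(en, s)          # -> each [2, 400] (one 2x upsampling block)
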
class DurationEncoder(nn.Module):

    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            self.lstms.append(nn.LSTM(d_model + sty_dim,
                                      d_model // 2,
                                      num_layers=1,
                                      batch_first=True,
                                      bidirectional=True,
                                      dropout=dropout))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))

        self.dropout = dropout
        self.d_model = d_model
        self.sty_dim = sty_dim

    def forward(self, x, style, text_lengths, m):
        masks = m.to(text_lengths.device)

        x = x.permute(2, 0, 1)
        s = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, s], axis=-1)
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)

        x = x.transpose(0, 1)
        input_lengths = text_lengths.cpu().numpy()
        x = x.transpose(-1, -2)

        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
            else:
                x = x.transpose(-1, -2)
                x = nn.utils.rnn.pack_padded_sequence(
                    x, input_lengths, batch_first=True, enforce_sorted=False)
                block.flatten_parameters()
                x, _ = block(x)
                x, _ = nn.utils.rnn.pad_packed_sequence(
                    x, batch_first=True)
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = x.transpose(-1, -2)

                x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

                x_pad[:, :, :x.shape[-1]] = x
                x = x_pad.to(x.device)

        return x.transpose(-1, -2)

    def inference(self, x, style):
        x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
        style = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, style], axis=-1)
        src = self.pos_encoder(x)
        output = self.transformer_encoder(src).transpose(0, 1)
        return output

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask


def load_F0_models(path):
    # load F0 model

    F0_model = JDCNet(num_class=1, seq_len=192)
    params = torch.load(path, map_location='cpu')['net']
    F0_model.load_state_dict(params)
    _ = F0_model.train()

    return F0_model


def load_KotoDama_Prompter(path, cfg=None, model_ckpt="ku-nlp/deberta-v3-base-japanese"):

    cfg = AutoConfig.from_pretrained(model_ckpt)
    cfg.update({
        "num_labels": 256
    })

    kotodama_prompt = KotoDama_Prompt.from_pretrained(path, config=cfg)

    return kotodama_prompt


def load_KotoDama_TextSampler(path, cfg=None, model_ckpt="line-corporation/line-distilbert-base-japanese"):

    cfg = AutoConfig.from_pretrained(model_ckpt)
    cfg.update({
        "num_labels": 256
    })

    kotodama_sampler = KotoDama_Text.from_pretrained(path, config=cfg)

    return kotodama_sampler


# def reconstruction_head(path):  # didn't make a lot of difference, disabling it for now until I find / train a better net

#     recon_model = DiffusionUpsampler(
#         net_t=UNetV0,
#         upsample_factor=2,
#         in_channels=1,
#         channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],
#         factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
#         items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
#         diffusion_t=VDiffusion,
#         sampler_t=VSampler,
#     )

#     checkpoint = torch.load(path, map_location='cpu')

#     new_state_dict = {}
#     for key, value in checkpoint['model_state_dict'].items():
#         new_key = key.replace('module.', '')  # remove the 'module.' prefix
#         new_state_dict[new_key] = value

#     recon_model.load_state_dict(new_state_dict)
#     recon_model.eval()

#     recon_model = recon_model.to('cuda')

#     return recon_model


def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
    # load ASR model
    def _load_config(path):
        with open(path) as f:
            config = yaml.safe_load(f)
        model_config = config['model_params']
        return model_config

    def _load_model(model_config, model_path):
        model = ASRCNN(**model_config)
        params = torch.load(model_path, map_location='cpu')['model']
        model.load_state_dict(params)
        return model

    asr_model_config = _load_config(ASR_MODEL_CONFIG)
    asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
    _ = asr_model.train()

    return asr_model

def build_model(args, text_aligner, pitch_extractor, bert):
    assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'

    if args.decoder.type == "istftnet":
        from Modules.istftnet import Decoder
        decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                          resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                          upsample_rates=args.decoder.upsample_rates,
                          upsample_initial_channel=args.decoder.upsample_initial_channel,
                          resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                          upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
                          gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
    else:
        from Modules.hifigan import Decoder
        decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                          resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                          upsample_rates=args.decoder.upsample_rates,
                          upsample_initial_channel=args.decoder.upsample_initial_channel,
                          resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                          upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)

    text_encoder = TextEncoder(channels=args.hidden_dim,
                               kernel_size=5,
                               depth=args.n_layer,
                               n_symbols=args.n_token,
                               conv1d_kernel_size=4,
                               qkv_proj_blocksize=4,
                               num_heads=8)

    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)

    style_encoder = StyleEncoder(mel_dim=args.n_mels, hidden_dim=args.hidden_dim, style_dim=args.style_dim, num_heads=8, num_layers=args.n_layer_conformer)
    predictor_encoder = StyleEncoder(mel_dim=args.n_mels, hidden_dim=args.hidden_dim, style_dim=args.style_dim, num_heads=8, num_layers=args.n_layer_conformer)

    # define diffusion model
    if args.multispeaker:
        transformer = StyleTransformer1d(channels=args.style_dim*2,
                                         context_embedding_features=bert.config.hidden_size,
                                         context_features=args.style_dim*2,
                                         **args.diffusion.transformer)
    else:
        transformer = Transformer1d(channels=args.style_dim*2,
                                    context_embedding_features=bert.config.hidden_size,
                                    **args.diffusion.transformer)

    diffusion = AudioDiffusionConditional(
        in_channels=1,
        embedding_max_length=bert.config.max_position_embeddings,
        embedding_features=bert.config.hidden_size,
        embedding_mask_proba=args.diffusion.embedding_mask_proba,  # conditional dropout of batch elements
        channels=args.style_dim*2,
        context_features=args.style_dim*2,
    )

    diffusion.diffusion = KDiffusion(
        net=diffusion.unet,
        sigma_distribution=LogNormalDistribution(mean=args.diffusion.dist.mean, std=args.diffusion.dist.std),
        sigma_data=args.diffusion.dist.sigma_data,  # a placeholder, changed dynamically when diffusion training starts
        dynamic_threshold=0.0
    )
    diffusion.diffusion.net = transformer
    diffusion.unet = transformer

    nets = Munch(

        bert=bert,
        bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),

        predictor=predictor,
        decoder=decoder,
        text_encoder=text_encoder,

        predictor_encoder=predictor_encoder,
        style_encoder=style_encoder,
        diffusion=diffusion,

        text_aligner=text_aligner,
        pitch_extractor=pitch_extractor,

        mpd=MultiPeriodDiscriminator(),
        msd=MultiResSpecDiscriminator(),

        # slm discriminator head
        wd=WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),

        # KotoDama_Prompt=KotoDama_Prompt,
        # KotoDama_Text=KotoDama_Text,

        # recon_diff=recon_diff,

    )

    return nets

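# A hedged usage sketch (illustration only; the config keys, file paths and the
# choice of BERT checkpoint are placeholders, not the repository's actual layout):
#
#   config = yaml.safe_load(open("Configs/config.yml"))
#   args = Munch.fromDict(config['model_params'])
#   text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
#   pitch_extractor = load_F0_models(config['F0_path'])
#   bert = AutoModel.from_pretrained("ku-nlp/deberta-v3-base-japanese")
#   nets = build_model(args, text_aligner, pitch_extractor, bert)
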
def load_checkpoint(model, optimizer, path, load_only_params=False, ignore_modules=[]):
    state = torch.load(path, map_location='cpu')
    params = state['net']
    print('loading the ckpt using the correct function.')

    for key in model:
        if key in params and key not in ignore_modules:
            try:
                model[key].load_state_dict(params[key], strict=True)
            except:
                from collections import OrderedDict
                state_dict = params[key]
                new_state_dict = OrderedDict()
                print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict key length: {len(state_dict.keys())}')
                for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
                    new_state_dict[k_m] = v_c
                model[key].load_state_dict(new_state_dict, strict=True)
            print('%s loaded' % key)

    if not load_only_params:
        epoch = state["epoch"]
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters
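

# A hedged usage sketch (illustration only; the checkpoint path and optimizer
# setup are placeholders, and `nets` is the Munch returned by build_model above):
#
#   params = [p for net in nets.values() for p in net.parameters()]
#   optimizer = torch.optim.AdamW(params, lr=1e-4)
#   nets, optimizer, epoch, iters = load_checkpoint(
#       nets, optimizer, "Models/checkpoint.pth", load_only_params=True)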