Dionyssos committed on
Commit
c7362aa
·
1 Parent(s): 780c8d5

oscillate vits duration

Browse files
Files changed (16)
  1. Modules/hifigan.py +29 -82
  2. Modules/vits/models.py +121 -862
  3. README.md +1 -1
  4. Utils/JDC/__init__.py +0 -1
  5. Utils/JDC/bst.pth +0 -3
  6. Utils/JDC/model.py +0 -190
  7. Utils/PLBERT/util.py +1 -1
  8. Utils/text_utils.py +101 -61
  9. api.py +6 -16
  10. audiobook.py +14 -27
  11. demo.py +15 -43
  12. live_demo.py +5 -7
  13. models.py +167 -251
  14. msinference.py +67 -139
  15. requirements.txt +1 -1
  16. tts.py +2 -2
Modules/hifigan.py CHANGED
@@ -2,12 +2,12 @@ import torch
2
  import torch.nn.functional as F
3
  import torch.nn as nn
4
  from torch.nn import Conv1d, ConvTranspose1d
5
- from torch.nn.utils import weight_norm, remove_weight_norm
6
  import math
7
  import numpy as np
8
 
9
 
10
- LRELU_SLOPE = 0.1
11
 
12
 
13
  def get_padding(kernel_size, dilation=1):
@@ -93,80 +93,38 @@ class AdaINResBlock1(torch.nn.Module):
93
  x = xt + x
94
  return x
95
 
96
- def remove_weight_norm(self):
97
- for l in self.convs1:
98
- remove_weight_norm(l)
99
- for l in self.convs2:
100
- remove_weight_norm(l)
101
 
 
102
 
103
- class SineGen(torch.nn.Module):
104
 
105
- def __init__(self,
106
- samp_rate=24000,
107
- upsample_scale=300,
108
- harmonic_num=8, # HARDCODED due to nn.Linear() of SourceModuleHnNSF
109
- voiced_threshold=10):
110
-
111
- super(SineGen, self).__init__()
112
- self.harmonic_num = harmonic_num
113
- self.sampling_rate = samp_rate
114
- self.voiced_threshold = voiced_threshold
115
- self.upsample_scale = upsample_scale
116
-
117
- def _f02sine(self, f0_values):
118
- # --
119
- # 134 HIFI
120
- # torch.Size([1, 145200, 9])
121
- # torch.Size([1, 145200, 9]) torch.Size([1, 145200, 9]) HIFi
122
 
123
  # modulo of negative f0_values => -21 % 10 = 9 since -21 = -3*10 + 9; NOTICE THAT f0_values IS SIGNED
124
- rad_values = (f0_values / self.sampling_rate) % 1
125
-
126
- rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
 
127
  scale_factor=1/self.upsample_scale,
128
- mode="linear").transpose(1, 2)
129
 
130
  # 1.89 also sounds nice; adds a woofer-like effect at punctuation
131
  phase = torch.cumsum(rad_values, dim=1) * 1.84 * np.pi
132
- phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
133
- scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
134
- sines = torch.sin(phase)
135
- return sines
136
-
137
- def forward(self, f0):
138
- # print('____________________________________\nF0 F0\n', f0.abs().mean(), f0.mean(), f0.max(), f0.min()) # male voices sound less muffed via higher scaler in sine_waves
139
- # f0 is already full length - [1, 142600, 1]
140
-
141
- amplif = .0104 if f0.abs().mean() < 100 else .009 # vary amplif based on f0.abs().mean() - voice sensitive
142
-
143
- fn = torch.multiply(f0, torch.FloatTensor(
144
- [[range(1, self.harmonic_num + 2)]]).to(f0.device)) # [1, 145200, 9]
145
-
146
- # .007 # very important effect DEFAULT=0.1 very sensitive to speaker - heuristically
147
- sine_waves = self._f02sine(fn) * amplif # .009
148
-
149
- uv = (f0 > self.voiced_threshold).type(torch.float32)
150
-
151
- return sine_waves * uv
152
-
153
-
154
- class SourceModuleHnNSF(torch.nn.Module):
155
-
156
- def __init__(self, harmonic_num=8):
157
-
158
- super(SourceModuleHnNSF, self).__init__()
159
- self.l_sin_gen = SineGen()
160
- # harmonic=8 is hard fixed due to this nn.Linear()
161
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
162
- self.l_tanh = torch.nn.Tanh()
163
-
164
- def forward(self, x):
165
- # print(' HNnSF', x.shape) # why this is [1, 300, 1, 535800]
166
- sine_wavs = self.l_sin_gen(x)
167
- # This linear sums all 9 harmonics
168
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
169
- return sine_merge
170
 
171
 
172
  class Generator(torch.nn.Module):
@@ -239,7 +197,7 @@ class Generator(torch.nn.Module):
239
  x_source = self.noise_res[i](x_source, s)
240
 
241
  x = self.ups[i](x)
242
- # print(x.min(), x.max(), x_source.min(), x_source.max())
243
  x = x + x_source
244
 
245
  xs = None
@@ -250,22 +208,12 @@ class Generator(torch.nn.Module):
250
  else:
251
  xs += self.resblocks[i*self.num_kernels+j](x, s)
252
  x = xs / self.num_kernels
253
- x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)
254
  x = self.conv_post(x)
255
  x = torch.tanh(x)
256
 
257
  return x
258
 
259
- def remove_weight_norm(self):
260
- print('Removing weight norm...')
261
- for l in self.ups:
262
- remove_weight_norm(l)
263
- for l in self.resblocks:
264
- l.remove_weight_norm()
265
- remove_weight_norm(self.conv_pre)
266
- remove_weight_norm(self.conv_post)
267
-
268
-
269
  class AdainResBlk1d(nn.Module):
270
 
271
  # also used in ProsodyPredictor()
@@ -324,7 +272,7 @@ class UpSample1d(nn.Module):
324
  if self.layer_type == 'none':
325
  return x
326
  else:
327
- return F.interpolate(x, scale_factor=2, mode='nearest')
328
 
329
 
330
  class Decoder(nn.Module):
@@ -361,11 +309,10 @@ class Decoder(nn.Module):
361
 
362
  def forward(self, asr=None, F0_curve=None, N=None, s=None):
363
 
364
- # print('p', asr.shape, F0_curve.shape, N.shape)
365
  F0 = self.F0_conv(F0_curve)
366
  N = self.N_conv(N)
367
 
368
- # print(asr.shape, F0.shape, N.shape, 'TF')
369
 
370
  x = torch.cat([asr, F0, N], axis=1)
371
 
 
2
  import torch.nn.functional as F
3
  import torch.nn as nn
4
  from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils.parametrizations import weight_norm
6
  import math
7
  import numpy as np
8
 
9
 
10
+
11
 
12
 
13
  def get_padding(kernel_size, dilation=1):
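Aside: the import hunk above replaces the deprecated `torch.nn.utils.weight_norm` / `remove_weight_norm` pair with the parametrization-based API. A minimal sketch of that API on a toy layer (layer sizes are illustrative, not taken from this repo; assumes a recent PyTorch that provides `nn.utils.parametrizations`):

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

# Wrapping works like the old API: the layer behaves exactly like a plain Conv1d.
conv = weight_norm(nn.Conv1d(16, 16, kernel_size=3, padding=1), name="weight")
y = conv(torch.randn(1, 16, 100))

# Rough equivalent of the deleted remove_weight_norm() calls: drop the
# parametrization and bake the current weight back into a plain tensor.
remove_parametrizations(conv, "weight", leave_parametrized=True)
```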
 
93
  x = xt + x
94
  return x
95
 
 
 
 
 
 
96
 
97
+ class SourceModuleHnNSF(torch.nn.Module):
98
 
99
+ def __init__(self):
100
 
101
+ super().__init__()
102
+ self.harmonic_num = 8
103
+ self.l_linear = torch.nn.Linear(self.harmonic_num + 1, 1)
104
+ self.upsample_scale = 300
105
+
106
 
107
+ def forward(self, x):
108
+ # --
109
+ x = torch.multiply(x, torch.FloatTensor(
110
+ [[range(1, self.harmonic_num + 2)]]).to(x.device)) # [1, 145200, 9]
111
+
112
  # modulo of negative f0_values => -21 % 10 = 9 since -21 = -3*10 + 9; NOTICE THAT f0_values IS SIGNED
113
+ rad_values = x / 25647 #).clamp(0, 1)
114
+ # rad_values = torch.where(torch.logical_or(rad_values < 0, rad_values > 1), 0.5, rad_values)
115
+ rad_values = rad_values % 1 # % of neg values
116
+ rad_values = F.interpolate(rad_values.transpose(1, 2),
117
  scale_factor=1/self.upsample_scale,
118
+ mode='linear').transpose(1, 2)
119
 
120
  # 1.89 also sounds nice; adds a woofer-like effect at punctuation
121
  phase = torch.cumsum(rad_values, dim=1) * 1.84 * np.pi
122
+ phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale,
123
+ scale_factor=self.upsample_scale, mode='linear').transpose(1, 2)
124
+ x = .009 * phase.sin()
125
+ # --
126
+ x = self.l_linear(x).tanh()
127
+ return x
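The hunk above folds the old `SineGen` into `SourceModuleHnNSF`. A standalone sketch of the same computation on a dummy flat 120 Hz F0 contour (tensor sizes are illustrative; shapes follow the `[B, T, 1]` comments in this diff):

```python
import numpy as np
import torch
import torch.nn.functional as F

harmonic_num, upsample_scale = 8, 300
f0 = torch.full((1, 30_000, 1), 120.0)                 # [B, T, 1] flat 120 Hz contour

harmonics = torch.arange(1, harmonic_num + 2).float()  # 1..9
fn = f0 * harmonics                                    # [B, T, 9] harmonic frequencies

rad = (fn / 25647) % 1                                 # normalised phase increments
rad = F.interpolate(rad.transpose(1, 2), scale_factor=1 / upsample_scale,
                    mode='linear').transpose(1, 2)     # downsample before the cumsum
phase = torch.cumsum(rad, dim=1) * 1.84 * np.pi
phase = F.interpolate(phase.transpose(1, 2) * upsample_scale,
                      scale_factor=upsample_scale, mode='linear').transpose(1, 2)
sines = .009 * phase.sin()                             # [B, T, 9] harmonic sinusoids

l_linear = torch.nn.Linear(harmonic_num + 1, 1)        # sums the 9 harmonics
source = l_linear(sines).tanh()                        # [B, T, 1] excitation signal
print(source.shape)
```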
128
 
129
 
130
  class Generator(torch.nn.Module):
 
197
  x_source = self.noise_res[i](x_source, s)
198
 
199
  x = self.ups[i](x)
200
+
201
  x = x + x_source
202
 
203
  xs = None
 
208
  else:
209
  xs += self.resblocks[i*self.num_kernels+j](x, s)
210
  x = xs / self.num_kernels
211
+ # x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2) # noisy
212
  x = self.conv_post(x)
213
  x = torch.tanh(x)
214
 
215
  return x
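The line disabled as "noisy" above has the form of the Snake periodic activation, `x + sin(alpha*x)**2 / alpha`, used in BigVGAN-style vocoders (here `self.alphas` appear to be learned per-stage parameters). A one-line sketch with a scalar `alpha`:

```python
import torch

def snake(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # x + (1/alpha) * sin(alpha * x)^2, the term commented out in the Generator
    return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2

print(snake(torch.linspace(-2, 2, 5), alpha=0.5))
```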
216
 
217
  class AdainResBlk1d(nn.Module):
218
 
219
  # also used in ProsodyPredictor()
 
272
  if self.layer_type == 'none':
273
  return x
274
  else:
275
+ return F.interpolate(x, scale_factor=2, mode='nearest-exact')
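`UpSample1d` now upsamples with `mode='nearest-exact'`. According to the PyTorch docs this variant rounds source indices from pixel centres (matching common image libraries), while `'nearest'` keeps the historical, slightly offset rule; for an exact factor of 2 the two typically coincide. A tiny comparison sketch:

```python
import torch
import torch.nn.functional as F

x = torch.arange(4.).view(1, 1, 4)                            # [B, C, T]
print(F.interpolate(x, scale_factor=2, mode='nearest'))
print(F.interpolate(x, scale_factor=2, mode='nearest-exact'))
```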
276
 
277
 
278
  class Decoder(nn.Module):
 
309
 
310
  def forward(self, asr=None, F0_curve=None, N=None, s=None):
311
 
312
+
313
  F0 = self.F0_conv(F0_curve)
314
  N = self.N_conv(N)
315
 
 
316
 
317
  x = torch.cat([asr, F0, N], axis=1)
318
 
Modules/vits/models.py CHANGED
@@ -1,15 +1,30 @@
1
  import math
2
  from dataclasses import dataclass
3
- from typing import Any, Optional, Tuple, Union
4
  import numpy as np
5
  import torch
6
- import torch.utils.checkpoint
7
  from torch import nn
8
-
9
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
10
  from transformers.modeling_outputs import BaseModelOutput, ModelOutput
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.configuration_utils import PretrainedConfig
13
 
14
  class VitsConfig(PretrainedConfig):
15
 
@@ -74,11 +89,9 @@ class VitsConfig(PretrainedConfig):
74
  self.ffn_kernel_size = ffn_kernel_size
75
  self.flow_size = flow_size
76
  self.spectrogram_bins = spectrogram_bins
77
-
78
-
79
  self.initializer_range = initializer_range
80
  self.layer_norm_eps = layer_norm_eps
81
- self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
82
  self.num_speakers = num_speakers
83
  self.speaker_embedding_size = speaker_embedding_size
84
  self.upsample_initial_channel = upsample_initial_channel
@@ -92,7 +105,6 @@ class VitsConfig(PretrainedConfig):
92
  self.duration_predictor_flow_bins = duration_predictor_flow_bins
93
  self.duration_predictor_tail_bound = duration_predictor_tail_bound
94
  self.duration_predictor_kernel_size = duration_predictor_kernel_size
95
-
96
  self.duration_predictor_num_flows = duration_predictor_num_flows
97
  self.duration_predictor_filter_channels = duration_predictor_filter_channels
98
  self.prior_encoder_num_flows = prior_encoder_num_flows
@@ -100,8 +112,6 @@ class VitsConfig(PretrainedConfig):
100
  self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
101
  self.wavenet_kernel_size = wavenet_kernel_size
102
  self.wavenet_dilation_rate = wavenet_dilation_rate
103
-
104
-
105
  self.noise_scale = noise_scale
106
  self.noise_scale_duration = noise_scale_duration
107
  self.sampling_rate = sampling_rate
@@ -121,183 +131,9 @@ class VitsTextEncoderOutput(ModelOutput):
121
  last_hidden_state: torch.FloatTensor = None
122
  prior_means: torch.FloatTensor = None
123
  prior_log_variances: torch.FloatTensor = None
124
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
125
- attentions: Optional[Tuple[torch.FloatTensor]] = None
126
-
127
- def _unconstrained_rational_quadratic_spline(
128
- inputs,
129
- unnormalized_widths,
130
- unnormalized_heights,
131
- unnormalized_derivatives,
132
- reverse=False,
133
- tail_bound=5.0,
134
- min_bin_width=1e-3,
135
- min_bin_height=1e-3,
136
- min_derivative=1e-3,
137
- ):
138
- inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
139
- outside_interval_mask = ~inside_interval_mask
140
-
141
- outputs = torch.zeros_like(inputs)
142
-
143
- constant = np.log(np.exp(1 - min_derivative) - 1)
144
-
145
- unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
146
- unnormalized_derivatives[..., 0] = constant
147
- unnormalized_derivatives[..., -1] = constant
148
-
149
- outputs[outside_interval_mask] = inputs[outside_interval_mask]
150
-
151
-
152
- outputs[inside_interval_mask] = _rational_quadratic_spline(
153
- inputs=inputs[inside_interval_mask],
154
- unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
155
- unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
156
- unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
157
- reverse=reverse,
158
- tail_bound=tail_bound,
159
- min_bin_width=min_bin_width,
160
- min_bin_height=min_bin_height,
161
- min_derivative=min_derivative,
162
- )
163
- return outputs
164
-
165
-
166
- def _rational_quadratic_spline(
167
- inputs,
168
- unnormalized_widths,
169
- unnormalized_heights,
170
- unnormalized_derivatives,
171
- reverse,
172
- tail_bound,
173
- min_bin_width,
174
- min_bin_height,
175
- min_derivative,
176
- ):
177
- """
178
- This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
179
- function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`.
180
-
181
- Args:
182
- inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
183
- Second half of the hidden-states input to the Vits convolutional flow module.
184
- unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
185
- First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
186
- layer in the convolutional flow module
187
- unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
188
- Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
189
- layer in the convolutional flow module
190
- unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
191
- Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
192
- layer in the convolutional flow module
193
- reverse (`bool`):
194
- Whether the model is being run in reverse mode.
195
- tail_bound (`float`):
196
- Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
197
- transform behaves as an identity function.
198
- min_bin_width (`float`):
199
- Minimum bin value across the width dimension for the piecewise rational quadratic function.
200
- min_bin_height (`float`):
201
- Minimum bin value across the height dimension for the piecewise rational quadratic function.
202
- min_derivative (`float`):
203
- Minimum bin value across the derivatives for the piecewise rational quadratic function.
204
- Returns:
205
- outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
206
- Hidden-states as transformed by the piecewise rational quadratic function.
207
- log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
208
- Logarithm of the absolute value of the determinants corresponding to the `outputs`.
209
- """
210
- upper_bound = tail_bound
211
- lower_bound = -tail_bound
212
-
213
- if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
214
- raise ValueError("Input to a transform is not within its domain")
215
-
216
- num_bins = unnormalized_widths.shape[-1]
217
-
218
- if min_bin_width * num_bins > 1.0:
219
- raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
220
- if min_bin_height * num_bins > 1.0:
221
- raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")
222
-
223
- widths = nn.functional.softmax(unnormalized_widths, dim=-1)
224
- widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
225
- cumwidths = torch.cumsum(widths, dim=-1)
226
- cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
227
- cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
228
- cumwidths[..., 0] = lower_bound
229
- cumwidths[..., -1] = upper_bound
230
- widths = cumwidths[..., 1:] - cumwidths[..., :-1]
231
-
232
- derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)
233
-
234
- heights = nn.functional.softmax(unnormalized_heights, dim=-1)
235
- heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
236
- cumheights = torch.cumsum(heights, dim=-1)
237
- cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
238
- cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
239
- cumheights[..., 0] = lower_bound
240
- cumheights[..., -1] = upper_bound
241
- heights = cumheights[..., 1:] - cumheights[..., :-1]
242
-
243
- bin_locations = cumheights if reverse else cumwidths
244
- bin_locations[..., -1] += 1e-6
245
- bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
246
- bin_idx = bin_idx[..., None]
247
-
248
- input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
249
- input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
250
-
251
- input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
252
- delta = heights / widths
253
- input_delta = delta.gather(-1, bin_idx)[..., 0]
254
-
255
- input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
256
- input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
257
-
258
- input_heights = heights.gather(-1, bin_idx)[..., 0]
259
-
260
- intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
261
- if not reverse:
262
- raise ValueError
263
- # theta = (inputs - input_cumwidths) / input_bin_widths
264
- # theta_one_minus_theta = theta * (1 - theta)
265
-
266
- # numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
267
- # denominator = input_delta + intermediate1 * theta_one_minus_theta
268
- # outputs = input_cumheights + numerator / denominator
269
-
270
- # derivative_numerator = input_delta.pow(2) * (
271
- # input_derivatives_plus_one * theta.pow(2)
272
- # + 2 * input_delta * theta_one_minus_theta
273
- # + input_derivatives * (1 - theta).pow(2)
274
- # )
275
- # log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
276
- # return outputs, log_abs_det
277
- else:
278
- # find the roots of a quadratic equation
279
- intermediate2 = inputs - input_cumheights
280
- intermediate3 = intermediate2 * intermediate1
281
- a = input_heights * (input_delta - input_derivatives) + intermediate3
282
- b = input_heights * input_derivatives - intermediate3
283
- c = -input_delta * intermediate2
284
-
285
- discriminant = b.pow(2) - 4 * a * c
286
- if not (discriminant >= 0).all():
287
- raise RuntimeError(f"invalid discriminant {discriminant}")
288
-
289
- root = (2 * c) / (-b - torch.sqrt(discriminant))
290
- outputs = root * input_bin_widths + input_cumwidths
291
-
292
- # theta_one_minus_theta = root * (1 - root)
293
- # denominator = input_delta + intermediate1 * theta_one_minus_theta
294
- # derivative_numerator = input_delta.pow(2) * (
295
- # input_derivatives_plus_one * root.pow(2)
296
- # + 2 * input_delta * theta_one_minus_theta
297
- # + input_derivatives * (1 - root).pow(2)
298
- # )
299
- # log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
300
- return outputs #, -log_abs_det
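The deleted `_rational_quadratic_spline` only implements the reverse direction: it inverts each bin's monotone rational-quadratic segment by solving a quadratic in theta and takes the numerically stable root `2c / (-b - sqrt(b^2 - 4ac))`. A small sketch of just that step, with names mirroring the deleted code:

```python
import torch

def invert_bin(a, b, c, input_cumwidths, input_bin_widths):
    # a, b, c are built from the bin heights, widths and derivatives as in the
    # deleted _rational_quadratic_spline; theta lands in [0, 1] inside the bin.
    discriminant = b.pow(2) - 4 * a * c
    assert (discriminant >= 0).all(), "invalid discriminant"
    theta = (2 * c) / (-b - torch.sqrt(discriminant))   # stable quadratic root
    return theta * input_bin_widths + input_cumwidths   # absolute position in the bin
```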
301
 
302
 
303
  class VitsWaveNet(torch.nn.Module):
@@ -305,20 +141,14 @@ class VitsWaveNet(torch.nn.Module):
305
  super().__init__()
306
  self.hidden_size = config.hidden_size
307
  self.num_layers = num_layers
308
-
309
  self.in_layers = torch.nn.ModuleList()
310
  self.res_skip_layers = torch.nn.ModuleList()
311
-
312
-
313
- if hasattr(nn.utils.parametrizations, "weight_norm"):
314
- weight_norm = nn.utils.parametrizations.weight_norm
315
- else:
316
- weight_norm = nn.utils.weight_norm
317
-
318
- if config.speaker_embedding_size != 0:
319
- cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
320
- self.cond_layer = weight_norm(cond_layer, name="weight")
321
-
322
  for i in range(num_layers):
323
  dilation = config.wavenet_dilation_rate**i
324
  padding = (config.wavenet_kernel_size * dilation - dilation) // 2
@@ -337,53 +167,36 @@ class VitsWaveNet(torch.nn.Module):
337
  res_skip_channels = 2 * config.hidden_size
338
  else:
339
  res_skip_channels = config.hidden_size
340
-
341
  res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
342
  res_skip_layer = weight_norm(res_skip_layer, name="weight")
343
  self.res_skip_layers.append(res_skip_layer)
344
 
345
- def forward(self, inputs, padding_mask, global_conditioning=None):
 
346
  outputs = torch.zeros_like(inputs)
347
  num_channels = torch.IntTensor([self.hidden_size])[0]
348
-
349
-
350
  for i in range(self.num_layers):
351
  in_act = self.in_layers[i](inputs)
352
-
353
-
354
  # global_states = torch.zeros_like(hidden_states) # style ?
355
-
356
  # acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
357
-
358
  # --
359
  # def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
360
  # in_act = input_a # + input_b
361
  t_act = torch.tanh(in_act[:, :num_channels, :])
362
  s_act = torch.sigmoid(in_act[:, num_channels:, :])
363
  acts = t_act * s_act
364
-
365
-
366
- #
367
-
368
-
369
-
370
  res_skip_acts = self.res_skip_layers[i](acts)
371
  if i < self.num_layers - 1:
372
  res_acts = res_skip_acts[:, : self.hidden_size, :]
373
- inputs = (inputs + res_acts) * padding_mask
374
  outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
375
  else:
376
  outputs = outputs + res_skip_acts
 
 
 
377
 
378
- return outputs * padding_mask
379
 
380
- def remove_weight_norm(self):
381
- if self.speaker_embedding_size != 0:
382
- torch.nn.utils.remove_weight_norm(self.cond_layer)
383
- for layer in self.in_layers:
384
- torch.nn.utils.remove_weight_norm(layer)
385
- for layer in self.res_skip_layers:
386
- torch.nn.utils.remove_weight_norm(layer)
387
 
388
 
389
 
@@ -425,22 +238,6 @@ class HifiGanResidualBlock(nn.Module):
425
  def get_padding(self, kernel_size, dilation=1):
426
  return (kernel_size * dilation - dilation) // 2
427
 
428
- def apply_weight_norm(self):
429
- weight_norm = nn.utils.weight_norm
430
- if hasattr(nn.utils.parametrizations, "weight_norm"):
431
- weight_norm = nn.utils.parametrizations.weight_norm
432
-
433
- for layer in self.convs1:
434
- weight_norm(layer)
435
- for layer in self.convs2:
436
- weight_norm(layer)
437
-
438
- def remove_weight_norm(self):
439
- for layer in self.convs1:
440
- nn.utils.remove_weight_norm(layer)
441
- for layer in self.convs2:
442
- nn.utils.remove_weight_norm(layer)
443
-
444
  def forward(self, hidden_states):
445
  for conv1, conv2 in zip(self.convs1, self.convs2):
446
  residual = hidden_states
@@ -483,44 +280,18 @@ class VitsHifiGan(nn.Module):
483
  channels = config.upsample_initial_channel // (2 ** (i + 1))
484
  for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
485
  self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
486
-
487
  self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
488
 
489
- if config.speaker_embedding_size != 0:
490
- self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
491
-
492
- def apply_weight_norm(self):
493
- weight_norm = nn.utils.weight_norm
494
- if hasattr(nn.utils.parametrizations, "weight_norm"):
495
- weight_norm = nn.utils.parametrizations.weight_norm
496
-
497
- for layer in self.upsampler:
498
- weight_norm(layer)
499
- for layer in self.resblocks:
500
- layer.apply_weight_norm()
501
-
502
- def remove_weight_norm(self):
503
- for layer in self.upsampler:
504
- nn.utils.remove_weight_norm(layer)
505
- for layer in self.resblocks:
506
- layer.remove_weight_norm()
507
-
508
- def forward(
509
- self,
510
- spectrogram,
511
- global_conditioning=None):
512
-
513
  hidden_states = self.conv_pre(spectrogram)
514
-
515
  for i in range(self.num_upsamples):
516
  hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
517
  hidden_states = self.upsampler[i](hidden_states)
518
-
519
  res_state = self.resblocks[i * self.num_kernels](hidden_states)
520
  for j in range(1, self.num_kernels):
521
  res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
522
  hidden_states = res_state / self.num_kernels
523
-
524
  hidden_states = nn.functional.leaky_relu(hidden_states)
525
  hidden_states = self.conv_post(hidden_states)
526
  waveform = torch.tanh(hidden_states)
@@ -531,27 +302,20 @@ class VitsResidualCouplingLayer(nn.Module):
531
  def __init__(self, config):
532
  super().__init__()
533
  self.half_channels = config.flow_size // 2
534
-
535
  self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
536
  self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
537
  self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
538
 
539
- def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
540
- first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
541
- hidden_states = self.conv_pre(first_half) * padding_mask
542
- hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
543
- mean = self.conv_post(hidden_states) * padding_mask
544
- log_stddev = torch.zeros_like(mean)
545
-
546
- if not reverse:
547
- second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
548
- outputs = torch.cat([first_half, second_half], dim=1)
549
- log_determinant = torch.sum(log_stddev, [1, 2])
550
- return outputs, log_determinant
551
- else:
552
- second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
553
- outputs = torch.cat([first_half, second_half], dim=1)
554
- return outputs, None
555
 
556
 
557
  class VitsResidualCouplingBlock(nn.Module):
@@ -561,226 +325,20 @@ class VitsResidualCouplingBlock(nn.Module):
561
  for _ in range(config.prior_encoder_num_flows):
562
  self.flows.append(VitsResidualCouplingLayer(config))
563
 
564
- def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
565
- if not reverse:
566
- for flow in self.flows:
567
- inputs, _ = flow(inputs, padding_mask, global_conditioning)
568
- inputs = torch.flip(inputs, [1])
569
- else:
570
- for flow in reversed(self.flows):
571
- inputs = torch.flip(inputs, [1])
572
- inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
573
- return inputs
574
-
575
-
576
- class VitsDilatedDepthSeparableConv(nn.Module):
577
- def __init__(self, config, dropout_rate=0.0):
578
- super().__init__()
579
- kernel_size = config.duration_predictor_kernel_size
580
- channels = config.hidden_size
581
- self.num_layers = config.depth_separable_num_layers
582
-
583
- self.convs_dilated = nn.ModuleList()
584
- self.convs_pointwise = nn.ModuleList()
585
- self.norms_1 = nn.ModuleList()
586
- self.norms_2 = nn.ModuleList()
587
- for i in range(self.num_layers):
588
- dilation = kernel_size**i
589
- padding = (kernel_size * dilation - dilation) // 2
590
- self.convs_dilated.append(
591
- nn.Conv1d(
592
- in_channels=channels,
593
- out_channels=channels,
594
- kernel_size=kernel_size,
595
- groups=channels,
596
- dilation=dilation,
597
- padding=padding,
598
- )
599
- )
600
- self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
601
- self.norms_1.append(nn.LayerNorm(channels))
602
- self.norms_2.append(nn.LayerNorm(channels))
603
-
604
- def forward(self, inputs, padding_mask, global_conditioning=None):
605
- if global_conditioning is not None:
606
- inputs = inputs + global_conditioning
607
-
608
- for i in range(self.num_layers):
609
- hidden_states = self.convs_dilated[i](inputs * padding_mask)
610
- hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
611
- hidden_states = nn.functional.gelu(hidden_states)
612
- hidden_states = self.convs_pointwise[i](hidden_states)
613
- hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
614
- hidden_states = nn.functional.gelu(hidden_states)
615
-
616
- inputs = inputs + hidden_states
617
-
618
- return inputs * padding_mask
619
-
620
-
621
- class VitsConvFlow(nn.Module):
622
- def __init__(self, config):
623
- super().__init__()
624
- self.filter_channels = config.hidden_size
625
- self.half_channels = config.depth_separable_channels // 2
626
- self.num_bins = config.duration_predictor_flow_bins
627
- self.tail_bound = config.duration_predictor_tail_bound
628
-
629
- self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
630
- self.conv_dds = VitsDilatedDepthSeparableConv(config)
631
- self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)
632
-
633
- def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
634
- first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
635
-
636
- hidden_states = self.conv_pre(first_half)
637
- hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
638
- hidden_states = self.conv_proj(hidden_states) * padding_mask
639
-
640
- batch_size, channels, length = first_half.shape
641
- hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)
642
-
643
- unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
644
- unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
645
- unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]
646
-
647
- second_half = _unconstrained_rational_quadratic_spline(
648
- second_half,
649
- unnormalized_widths,
650
- unnormalized_heights,
651
- unnormalized_derivatives,
652
- reverse=reverse,
653
- tail_bound=self.tail_bound,
654
- )
655
-
656
- outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
657
-
658
- return outputs, None
659
-
660
-
661
- class VitsElementwiseAffine(nn.Module):
662
- def __init__(self, config):
663
- super().__init__()
664
- self.channels = config.depth_separable_channels
665
- self.translate = nn.Parameter(torch.zeros(self.channels, 1))
666
- self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))
667
 
668
- def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
669
- if not reverse:
670
- raise ValueError
671
- # outputs = self.translate + torch.exp(self.log_scale) * inputs
672
- # outputs = outputs * padding_mask
673
- # log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
674
- # return outputs, log_determinant
675
- else:
676
- outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
677
- return outputs, None
678
-
679
-
680
- class VitsStochasticDurationPredictor(nn.Module):
681
- def __init__(self, config):
682
- super().__init__()
683
- embed_dim = config.speaker_embedding_size
684
- filter_channels = config.hidden_size
685
-
686
- self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
687
- self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
688
- self.conv_dds = VitsDilatedDepthSeparableConv(config)
689
-
690
- if embed_dim != 0:
691
- self.cond = nn.Conv1d(embed_dim, filter_channels, 1)
692
-
693
- self.flows = nn.ModuleList()
694
- self.flows.append(VitsElementwiseAffine(config))
695
- for _ in range(config.duration_predictor_num_flows):
696
- self.flows.append(VitsConvFlow(config))
697
-
698
- # self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
699
- # self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
700
- # self.post_conv_dds = VitsDilatedDepthSeparableConv(
701
- # config,
702
- # dropout_rate=config.duration_predictor_dropout,
703
- # )
704
-
705
- # self.post_flows = nn.ModuleList()
706
- # self.post_flows.append(VitsElementwiseAffine(config))
707
- # for _ in range(config.duration_predictor_num_flows):
708
- # self.post_flows.append(VitsConvFlow(config))
709
-
710
- def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
711
- inputs = torch.detach(inputs)
712
- inputs = self.conv_pre(inputs)
713
-
714
- if global_conditioning is not None:
715
- raise ValueError
716
- # global_conditioning = torch.detach(global_conditioning)
717
- # inputs = inputs + self.cond(global_conditioning)
718
-
719
- inputs = self.conv_dds(inputs, padding_mask)
720
- inputs = self.conv_proj(inputs) * padding_mask
721
-
722
- if not reverse:
723
- raise ValueError
724
- # hidden_states = self.post_conv_pre(durations)
725
- # hidden_states = self.post_conv_dds(hidden_states, padding_mask)
726
- # hidden_states = self.post_conv_proj(hidden_states) * padding_mask
727
-
728
- # random_posterior = (
729
- # torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
730
- # * padding_mask
731
- # )
732
- # log_determinant_posterior_sum = 0
733
- # latents_posterior = random_posterior
734
- # for flow in self.post_flows:
735
- # latents_posterior, log_determinant = flow(
736
- # latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
737
- # )
738
- # latents_posterior = torch.flip(latents_posterior, [1])
739
- # log_determinant_posterior_sum += log_determinant
740
-
741
- # first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)
742
-
743
- # log_determinant_posterior_sum += torch.sum(
744
- # (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
745
- # )
746
- # logq = (
747
- # torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
748
- # - log_determinant_posterior_sum
749
- # )
750
-
751
- # first_half = (durations - torch.sigmoid(first_half)) * padding_mask
752
- # first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
753
- # log_determinant_sum = torch.sum(-first_half, [1, 2])
754
-
755
- # latents = torch.cat([first_half, second_half], dim=1)
756
- # for flow in self.flows:
757
- # latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
758
- # latents = torch.flip(latents, [1])
759
- # log_determinant_sum += log_determinant
760
-
761
- # nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
762
- # return nll + logq
763
- else:
764
- flows = list(reversed(self.flows))
765
- flows = flows[:-2] + [flows[-1]] # remove a useless vflow
766
-
767
- latents = (
768
- torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
769
- * noise_scale
770
- )
771
- for flow in flows:
772
- latents = torch.flip(latents, [1])
773
- latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)
774
-
775
- log_duration, _ = torch.split(latents, [1, 1], dim=1)
776
- return log_duration
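The deleted `VitsStochasticDurationPredictor` is only ever run with `reverse=True`: it samples a 2-channel noise latent, pushes it backwards through the flow stack (flipping channels between flows), and reads the first channel as the log-duration. A sketch of that control flow; `flows` here are stand-in callables (already reversed and pruned, as in the deleted code), not the real `VitsConvFlow` modules:

```python
import torch

def sample_log_duration(cond, padding_mask, flows, noise_scale=1.0):
    # cond: [B, C, T] text conditioning, padding_mask: [B, 1, T]
    latents = torch.randn(cond.size(0), 2, cond.size(2),
                          device=cond.device, dtype=cond.dtype) * noise_scale
    for flow in flows:
        latents = torch.flip(latents, [1])
        latents, _ = flow(latents, padding_mask, global_conditioning=cond, reverse=True)
    log_duration, _ = torch.split(latents, [1, 1], dim=1)
    return log_duration                                  # [B, 1, T]
```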
777
 
778
 
779
 
780
 
781
 
782
  class VitsAttention(nn.Module):
783
- """Multi-headed attention with relative positional representation."""
784
 
785
  def __init__(self, config):
786
  super().__init__()
@@ -793,36 +351,22 @@ class VitsAttention(nn.Module):
793
  self.scaling = self.head_dim**-0.5
794
 
795
  if (self.head_dim * self.num_heads) != self.embed_dim:
796
- raise ValueError(
797
- f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}"
798
- f" and `num_attention_heads`: {self.num_heads})."
799
- )
800
-
801
  self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
802
  self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
803
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
804
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
805
 
806
- if self.window_size:
807
- # Those provide relative pos embs for k/v interpolated from 2*4+1 to 1027 time frames - duration of txt
808
- self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
809
- self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
810
-
811
  def _shape(self, tensor, seq_len, bsz):
812
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
813
 
814
  def forward(
815
  self,
816
  hidden_states,
817
- key_value_states: Optional[torch.Tensor] = None,
818
- attention_mask: Optional[torch.Tensor] = None,
819
- layer_head_mask: Optional[torch.Tensor] = None,
820
- output_attentions: bool = False,
821
  ):
822
- """Input shape: Batch x Time x Channel"""
823
-
824
- # if key_value_states are provided this layer is used as a cross-attention layer
825
- # for the decoder
826
 
827
  bsz, tgt_len, _ = hidden_states.size()
828
 
@@ -840,36 +384,9 @@ class VitsAttention(nn.Module):
840
 
841
  src_len = key_states.size(1)
842
  attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
843
-
844
-
845
-
846
- if self.window_size is not None:
847
- # 4
848
- # key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)
849
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len) # try fix k.shape[2] to have consistent voice deu
850
- # print(f'{self.emb_rel_k.shape=} {key_relative_embeddings.shape=}\n\nL855')
851
- relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1))
852
- # -- only here (key)
853
- rel_pos_bias = self._relative_position_to_absolute_position(relative_logits)
854
- attn_weights += rel_pos_bias
855
-
856
- if attention_mask is not None:
857
-
858
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
859
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
860
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)
861
  attn_output = torch.bmm(attn_weights,
862
  value_states)
863
-
864
-
865
-
866
- if self.window_size is not None:
867
- # Entering here with self.window_size = 4
868
- value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
869
- relative_weights = self._absolute_position_to_relative_position(attn_weights)
870
- rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
871
- attn_output += rel_pos_bias
872
-
873
  attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
874
  attn_output = attn_output.transpose(1, 2)
875
 
@@ -881,42 +398,6 @@ class VitsAttention(nn.Module):
881
 
882
  return attn_output, None #attn_weights_reshaped
883
 
884
- def _get_relative_embeddings(self, relative_embeddings, length):
885
- pad_length = max(length - (self.window_size + 1), 0)
886
- if pad_length > 0:
887
- relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
888
-
889
- slice_start_position = max((self.window_size + 1) - length, 0)
890
- slice_end_position = slice_start_position + 2 * length - 1
891
- return relative_embeddings[:, slice_start_position:slice_end_position]
892
-
893
- def _relative_position_to_absolute_position(self, x):
894
- batch_heads, length, _ = x.size()
895
-
896
- # Concat columns of pad to shift from relative to absolute indexing.
897
- x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])
898
-
899
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
900
- x_flat = x.view([batch_heads, length * 2 * length])
901
- x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
902
-
903
- # Reshape and slice out the padded elements.
904
- x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
905
- x_final = x_final[:, :length, length - 1 :]
906
- return x_final
907
-
908
- def _absolute_position_to_relative_position(self, x):
909
- batch_heads, length, _ = x.size()
910
-
911
- # Pad along column
912
- x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
913
- x_flat = x.view([batch_heads, length * (2 * length - 1)])
914
-
915
- # Add 0's in the beginning that will skew the elements after reshape
916
- x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
917
- x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]
918
- return x_final
919
-
920
 
921
  class VitsFeedForward(nn.Module):
922
  def __init__(self, config):
@@ -933,25 +414,15 @@ class VitsFeedForward(nn.Module):
933
  else:
934
  self.padding = None
935
 
936
- def forward(self, hidden_states, padding_mask):
937
  hidden_states = hidden_states.permute(0, 2, 1)
938
- padding_mask = padding_mask.permute(0, 2, 1)
939
-
940
- hidden_states = hidden_states * padding_mask
941
  if self.padding is not None:
942
  hidden_states = nn.functional.pad(hidden_states, self.padding)
943
-
944
  hidden_states = self.conv_1(hidden_states)
945
  hidden_states = self.act_fn(hidden_states)
946
-
947
-
948
- hidden_states = hidden_states * padding_mask
949
  if self.padding is not None:
950
  hidden_states = nn.functional.pad(hidden_states, self.padding)
951
-
952
  hidden_states = self.conv_2(hidden_states)
953
- hidden_states = hidden_states * padding_mask
954
-
955
  hidden_states = hidden_states.permute(0, 2, 1)
956
  return hidden_states
957
 
@@ -960,22 +431,19 @@ class VitsEncoderLayer(nn.Module):
960
  def __init__(self, config):
961
  super().__init__()
962
  self.attention = VitsAttention(config)
963
-
964
  self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
965
  self.feed_forward = VitsFeedForward(config)
966
  self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
967
 
968
  def forward(
969
  self,
970
- hidden_states: torch.Tensor,
971
- padding_mask: torch.FloatTensor,
972
- attention_mask: Optional[torch.Tensor] = None,
973
- output_attentions: bool = False,
974
  ):
975
  residual = hidden_states
976
  hidden_states, attn_weights = self.attention(
977
  hidden_states=hidden_states,
978
- attention_mask=attention_mask,
979
  output_attentions=output_attentions,
980
  )
981
 
@@ -983,15 +451,12 @@ class VitsEncoderLayer(nn.Module):
983
  hidden_states = self.layer_norm(residual + hidden_states)
984
 
985
  residual = hidden_states
986
- hidden_states = self.feed_forward(hidden_states, padding_mask)
987
-
988
  hidden_states = self.final_layer_norm(residual + hidden_states)
989
 
990
  outputs = (hidden_states,)
991
 
992
- if output_attentions:
993
- outputs += (attn_weights,)
994
-
995
  return outputs
996
 
997
 
@@ -1005,52 +470,24 @@ class VitsEncoder(nn.Module):
1005
 
1006
  def forward(
1007
  self,
1008
- hidden_states: torch.FloatTensor,
1009
- padding_mask: torch.FloatTensor,
1010
- attention_mask: Optional[torch.Tensor] = None,
1011
- output_attentions: Optional[bool] = None,
1012
- output_hidden_states: Optional[bool] = None,
1013
- return_dict: Optional[bool] = None,
1014
  ):
1015
- all_hidden_states = () if output_hidden_states else None
1016
- all_self_attentions = () if output_attentions else None
1017
-
1018
- # expand attention_mask
1019
- if attention_mask is not None:
1020
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1021
- attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
1022
-
1023
- hidden_states = hidden_states * padding_mask
1024
-
1025
-
1026
-
1027
- for encoder_layer in self.layers:
1028
- if output_hidden_states:
1029
- all_hidden_states = all_hidden_states + (hidden_states,)
1030
-
1031
- layer_outputs = encoder_layer(
1032
- hidden_states,
1033
- attention_mask=attention_mask,
1034
- padding_mask=padding_mask,
1035
- output_attentions=output_attentions,
1036
- )
1037
  hidden_states = layer_outputs[0]
1038
-
1039
- if output_attentions:
1040
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
1041
-
1042
- hidden_states = hidden_states * padding_mask
1043
-
1044
  return BaseModelOutput(
1045
  last_hidden_state=hidden_states,
1046
- hidden_states=all_hidden_states,
1047
- attentions=all_self_attentions,
1048
  )
1049
 
1050
 
1051
  class VitsTextEncoder(nn.Module):
1052
  """
1053
- Transformer encoder that uses relative positional representation instead of absolute positional encoding.
1054
  """
1055
 
1056
  def __init__(self, config):
@@ -1060,75 +497,30 @@ class VitsTextEncoder(nn.Module):
1060
  self.encoder = VitsEncoder(config) # 6 Layers of VitsAttention
1061
  self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
1062
 
1063
- # def get_input_embeddings(self):
1064
- # return self.embed_tokens
1065
-
1066
- # def set_input_embeddings(self, value):
1067
- # self.embed_tokens = value
1068
-
1069
- def forward(
1070
- self,
1071
- input_ids: torch.Tensor,
1072
- padding_mask: torch.FloatTensor,
1073
- attention_mask: Optional[torch.Tensor] = None,
1074
- output_attentions: Optional[bool] = None,
1075
- output_hidden_states: Optional[bool] = None,
1076
- return_dict: Optional[bool] = True,
1077
- ):
1078
- hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
1079
-
1080
- encoder_outputs = self.encoder(
1081
- hidden_states=hidden_states,
1082
- padding_mask=padding_mask,
1083
- attention_mask=attention_mask,
1084
- output_attentions=output_attentions,
1085
- output_hidden_states=output_hidden_states,
1086
- return_dict=return_dict,
1087
- )
1088
 
1089
- last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state
1090
-
1091
- stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask
1092
  prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
1093
 
1094
  return VitsTextEncoderOutput(
1095
  last_hidden_state=last_hidden_state,
1096
  prior_means=prior_means,
1097
- prior_log_variances=prior_log_variances,
1098
- hidden_states=encoder_outputs.hidden_states,
1099
- attentions=encoder_outputs.attentions,
1100
  )
1101
 
1102
 
1103
  class VitsPreTrainedModel(PreTrainedModel):
1104
- """
1105
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1106
- models.
1107
- """
1108
-
1109
  config_class = VitsConfig
1110
  base_model_prefix = "vits"
1111
  main_input_name = "input_ids"
1112
  supports_gradient_checkpointing = True
1113
 
1114
- def _init_weights(self, module):
1115
- """Initialize the weights"""
1116
- if isinstance(module, nn.Linear):
1117
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1118
- if module.bias is not None:
1119
- module.bias.data.zero_()
1120
- elif isinstance(module, nn.LayerNorm):
1121
- module.bias.data.zero_()
1122
- module.weight.data.fill_(1.0)
1123
- elif isinstance(module, nn.Conv1d):
1124
- nn.init.kaiming_normal_(module.weight)
1125
- if module.bias is not None:
1126
- k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
1127
- nn.init.uniform_(module.bias, a=-k, b=k)
1128
- elif isinstance(module, nn.Embedding):
1129
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1130
- if module.padding_idx is not None:
1131
- module.weight.data[module.padding_idx].zero_()
1132
 
1133
 
1134
  class VitsModel(VitsPreTrainedModel):
@@ -1138,27 +530,9 @@ class VitsModel(VitsPreTrainedModel):
1138
  self.text_encoder = VitsTextEncoder(config) # has VitsEncoder that includes 6L of VitsAttention
1139
  self.flow = VitsResidualCouplingBlock(config)
1140
  self.decoder = VitsHifiGan(config)
1141
-
1142
- if config.use_stochastic_duration_prediction:
1143
- self.duration_predictor = VitsStochasticDurationPredictor(config)
1144
- else:
1145
- raise ValueError
1146
- # self.duration_predictor = VitsDurationPredictor(config)
1147
-
1148
- if config.num_speakers > 1:
1149
- self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)
1150
-
1151
-
1152
-
1153
- self.noise_scale = config.noise_scale
1154
- self.noise_scale_duration = config.noise_scale_duration
1155
-
1156
  # Initialize weights and apply final processing
1157
  self.post_init()
1158
 
1159
- def get_encoder(self):
1160
- return self.text_encoder
1161
-
1162
  def forward(
1163
  self,
1164
  input_ids = None,
@@ -1168,69 +542,37 @@ class VitsModel(VitsPreTrainedModel):
1168
  output_hidden_states = None,
1169
  return_dict = None,
1170
  labels = None,
1171
- speed=None,
 
1172
  ):
1173
-
1174
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1175
- output_hidden_states = (
1176
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1177
- )
1178
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1179
-
1180
- if labels is not None:
1181
- raise NotImplementedError("Training of VITS is not supported yet.")
1182
-
1183
  mask_dtype = self.text_encoder.embed_tokens.weight.dtype
1184
  if attention_mask is not None:
1185
  input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
1186
  else:
1187
  input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
1188
-
1189
- if self.config.num_speakers > 1 and speaker_id is not None:
1190
- if not 0 <= speaker_id < self.config.num_speakers:
1191
- raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
1192
- if isinstance(speaker_id, int):
1193
- speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
1194
- speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
1195
- else:
1196
- speaker_embeddings = None
1197
-
1198
- text_encoder_output = self.text_encoder(
1199
- input_ids=input_ids,
1200
- padding_mask=input_padding_mask,
1201
- attention_mask=attention_mask,
1202
- output_attentions=output_attentions,
1203
- output_hidden_states=output_hidden_states,
1204
- return_dict=return_dict,
1205
- )
1206
- hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
1207
- hidden_states = hidden_states.transpose(1, 2)
1208
  input_padding_mask = input_padding_mask.transpose(1, 2)
1209
- prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
1210
- prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
1211
-
1212
- if self.config.use_stochastic_duration_prediction:
1213
- log_duration = self.duration_predictor(
1214
- hidden_states,
1215
- input_padding_mask,
1216
- speaker_embeddings,
1217
- reverse=True,
1218
- noise_scale=self.noise_scale_duration,
1219
- )
1220
  else:
1221
- raise ValueError
1222
- # log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
1223
-
1224
- length_scale = 1.0 / speed
1225
- duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
1226
  predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
1227
-
1228
- # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
1229
  indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
1230
  output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
1231
  output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
1232
-
1233
- # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
1234
  attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
1235
  batch_size, _, output_length, input_length = attn_mask.shape
1236
  cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
@@ -1239,106 +581,30 @@ class VitsModel(VitsPreTrainedModel):
1239
  valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
1240
  padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
1241
  attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
1242
-
1243
- # Expand prior distribution
1244
- prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
1245
- prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
1246
-
1247
- prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
1248
- latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
1249
-
1250
- spectrogram = latents * output_padding_mask
1251
- waveform = self.decoder(spectrogram, speaker_embeddings)
1252
- waveform = waveform.squeeze(1)
1253
- sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
1254
-
1255
- if not return_dict:
1256
- outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
1257
- return outputs
1258
-
1259
- return waveform
1260
-
1261
-
1262
-
1263
 
1264
 
 
 
 
 
1265
 
 
1266
 
1267
 
1268
 
 
1269
 
 
 
1270
 
1271
- # ================================================ tokenization
1272
 
1273
- # coding=utf-8
1274
- # Copyright 2023 The Kakao Enterprise Authors, the MMS-TTS Authors and the HuggingFace Inc. team. All rights reserved.
1275
- #
1276
- # Licensed under the Apache License, Version 2.0 (the "License");
1277
- # you may not use this file except in compliance with the License.
1278
- # You may obtain a copy of the License at
1279
- #
1280
- # http://www.apache.org/licenses/LICENSE-2.0
1281
- #
1282
- # Unless required by applicable law or agreed to in writing, software
1283
- # distributed under the License is distributed on an "AS IS" BASIS,
1284
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1285
- # See the License for the specific language governing permissions and
1286
- # limitations under the License.
1287
- """Tokenization class for VITS."""
1288
-
1289
- import json
1290
- import os
1291
- import re
1292
- from typing import Any, Dict, List, Optional, Tuple, Union
1293
-
1294
- from transformers.tokenization_utils import PreTrainedTokenizer
1295
- from transformers.utils import is_phonemizer_available, is_uroman_available
1296
-
1297
-
1298
- if is_phonemizer_available():
1299
- import phonemizer
1300
-
1301
- if is_uroman_available():
1302
- import uroman as ur
1303
-
1304
-
1305
-
1306
- VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}
1307
-
1308
-
1309
- def has_non_roman_characters(input_string):
1310
- # Find any character outside the ASCII range
1311
- non_roman_pattern = re.compile(r"[^\x00-\x7F]")
1312
-
1313
- # Search the input string for non-Roman characters
1314
- match = non_roman_pattern.search(input_string)
1315
- has_non_roman = match is not None
1316
- return has_non_roman
1317
 
1318
 
1319
  class VitsTokenizer(PreTrainedTokenizer):
1320
- """
1321
- Construct a VITS tokenizer. Also supports MMS-TTS.
1322
-
1323
- This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
1324
- this superclass for more information regarding those methods.
1325
-
1326
- Args:
1327
- vocab_file (`str`):
1328
- Path to the vocabulary file.
1329
- language (`str`, *optional*):
1330
- Language identifier.
1331
- add_blank (`bool`, *optional*, defaults to `True`):
1332
- Whether to insert token id 0 in between the other tokens.
1333
- normalize (`bool`, *optional*, defaults to `True`):
1334
- Whether to normalize the input text by removing all casing and punctuation.
1335
- phonemize (`bool`, *optional*, defaults to `True`):
1336
- Whether to convert the input text into phonemes.
1337
- is_uroman (`bool`, *optional*, defaults to `False`):
1338
- Whether the `uroman` Romanizer needs to be applied to the input text prior to tokenizing.
1339
- """
1340
-
1341
- vocab_files_names = VOCAB_FILES_NAMES
1342
  model_input_names = ["input_ids", "attention_mask"]
1343
 
1344
  def __init__(
@@ -1412,12 +678,8 @@ class VitsTokenizer(PreTrainedTokenizer):
1412
  return text
1413
 
1414
  def prepare_for_tokenization(
1415
- self, text: str, is_split_into_words: bool = False, normalize: Optional[bool] = None, **kwargs
1416
- ) -> Tuple[str, Dict[str, Any]]:
1417
- '''
1418
- Performs any necessary transformations before tokenization.
1419
-
1420
- '''
1421
  normalize = normalize if normalize is not None else self.normalize
1422
 
1423
  if normalize:
@@ -1462,21 +724,18 @@ class VitsTokenizer(PreTrainedTokenizer):
1462
  tokens = list(text)
1463
 
1464
  if self.add_blank:
1465
- interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2 + 1)
1466
- interspersed[1::2] = tokens
1467
- tokens = interspersed
 
 
1468
 
1469
  return tokens
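The deleted `add_blank` branch above interleaved the blank token (id 0) between all characters, the usual VITS convention. A tiny illustration; the `"<blank>"` string below is just a stand-in for `self._convert_id_to_token(0)`:

```python
tokens = list("cat")
interspersed = ["<blank>"] * (len(tokens) * 2 + 1)
interspersed[1::2] = tokens
print(interspersed)  # ['<blank>', 'c', '<blank>', 'a', '<blank>', 't', '<blank>']
```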
1470
 
1471
- def convert_tokens_to_string(self, tokens: List[str]) -> str:
1472
- if self.add_blank and len(tokens) > 1:
1473
- tokens = tokens[1::2]
1474
- return "".join(tokens)
1475
-
1476
  def _convert_token_to_id(self, token):
1477
  """Converts a token (str) in an id using the vocab."""
1478
  return self.encoder.get(token, self.encoder.get(self.unk_token))
1479
 
1480
  def _convert_id_to_token(self, index):
1481
  """Converts an index (integer) in a token (str) using the vocab."""
1482
- return self.decoder.get(index)
 
1
  import math
2
  from dataclasses import dataclass
 
3
  import numpy as np
4
  import torch
 
5
  from torch import nn
 
 
6
  from transformers.modeling_outputs import BaseModelOutput, ModelOutput
7
  from transformers.modeling_utils import PreTrainedModel
8
  from transformers.configuration_utils import PretrainedConfig
9
+ import json
10
+ import os
11
+ import re
12
+ from typing import Any, Dict, List, Optional, Tuple
13
+ from transformers.tokenization_utils import PreTrainedTokenizer
14
+ import phonemizer
15
+ import uroman as ur
16
+ import torch.nn.functional as F
17
+
18
+
19
+ def has_non_roman_characters(input_string):
20
+ # Find any character outside the ASCII range
21
+ non_roman_pattern = re.compile(r"[^\x00-\x7F]")
22
+
23
+ # Search the input string for non-Roman characters
24
+ match = non_roman_pattern.search(input_string)
25
+ has_non_roman = match is not None
26
+ return has_non_roman
27
+
28
 
29
  class VitsConfig(PretrainedConfig):
30
 
 
89
  self.ffn_kernel_size = ffn_kernel_size
90
  self.flow_size = flow_size
91
  self.spectrogram_bins = spectrogram_bins
 
 
92
  self.initializer_range = initializer_range
93
  self.layer_norm_eps = layer_norm_eps
94
+ # self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
95
  self.num_speakers = num_speakers
96
  self.speaker_embedding_size = speaker_embedding_size
97
  self.upsample_initial_channel = upsample_initial_channel
 
105
  self.duration_predictor_flow_bins = duration_predictor_flow_bins
106
  self.duration_predictor_tail_bound = duration_predictor_tail_bound
107
  self.duration_predictor_kernel_size = duration_predictor_kernel_size
 
108
  self.duration_predictor_num_flows = duration_predictor_num_flows
109
  self.duration_predictor_filter_channels = duration_predictor_filter_channels
110
  self.prior_encoder_num_flows = prior_encoder_num_flows
 
112
  self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
113
  self.wavenet_kernel_size = wavenet_kernel_size
114
  self.wavenet_dilation_rate = wavenet_dilation_rate
 
 
115
  self.noise_scale = noise_scale
116
  self.noise_scale_duration = noise_scale_duration
117
  self.sampling_rate = sampling_rate
 
131
  last_hidden_state: torch.FloatTensor = None
132
  prior_means: torch.FloatTensor = None
133
  prior_log_variances: torch.FloatTensor = None
134
+ hidden_states: torch.FloatTensor = None
135
+ attentions: torch.FloatTensor = None
136
+
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  class VitsWaveNet(torch.nn.Module):
 
141
  super().__init__()
142
  self.hidden_size = config.hidden_size
143
  self.num_layers = num_layers
 
144
  self.in_layers = torch.nn.ModuleList()
145
  self.res_skip_layers = torch.nn.ModuleList()
146
+ # if hasattr(nn.utils.parametrizations, "weight_norm"):
147
+ # # raise ValueError
148
+ weight_norm = nn.utils.parametrizations.weight_norm
149
+ # else:
150
+ # raise ValueError
151
+ # # weight_norm = nn.utils.weight_norm
 
 
 
 
 
152
  for i in range(num_layers):
153
  dilation = config.wavenet_dilation_rate**i
154
  padding = (config.wavenet_kernel_size * dilation - dilation) // 2
 
167
  res_skip_channels = 2 * config.hidden_size
168
  else:
169
  res_skip_channels = config.hidden_size
 
170
  res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
171
  res_skip_layer = weight_norm(res_skip_layer, name="weight")
172
  self.res_skip_layers.append(res_skip_layer)
173
 
174
+ def forward(self,
175
+ inputs):
176
  outputs = torch.zeros_like(inputs)
177
  num_channels = torch.IntTensor([self.hidden_size])[0]
 
 
178
  for i in range(self.num_layers):
179
  in_act = self.in_layers[i](inputs)
 
 
180
  # global_states = torch.zeros_like(hidden_states) # style ?
 
181
  # acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
 
182
  # --
183
  # def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
184
  # in_act = input_a # + input_b
185
  t_act = torch.tanh(in_act[:, :num_channels, :])
186
  s_act = torch.sigmoid(in_act[:, num_channels:, :])
187
  acts = t_act * s_act
 
 
 
 
 
 
188
  res_skip_acts = self.res_skip_layers[i](acts)
189
  if i < self.num_layers - 1:
190
  res_acts = res_skip_acts[:, : self.hidden_size, :]
191
+ inputs = inputs + res_acts
192
  outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
193
  else:
194
  outputs = outputs + res_skip_acts
195
+ return outputs
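A minimal standalone sketch of the gated activation used in the loop above (toy tensor sizes; a random tensor stands in for the output of an in-layer conv):

```python
import torch

hidden_size = 4
# an in-layer produces 2*hidden_size channels: first half is gated by tanh, second by sigmoid
in_act = torch.randn(1, 2 * hidden_size, 12)
t_act = torch.tanh(in_act[:, :hidden_size, :])
s_act = torch.sigmoid(in_act[:, hidden_size:, :])
acts = t_act * s_act          # gated activations, shape [1, hidden_size, 12]
print(acts.shape)
```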
196
+
197
+
198
 
 
199
 
 
 
 
 
 
 
 
200
 
201
 
202
 
 
238
  def get_padding(self, kernel_size, dilation=1):
239
  return (kernel_size * dilation - dilation) // 2
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  def forward(self, hidden_states):
242
  for conv1, conv2 in zip(self.convs1, self.convs2):
243
  residual = hidden_states
 
280
  channels = config.upsample_initial_channel // (2 ** (i + 1))
281
  for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
282
  self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
 
283
  self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
284
 
285
+ def forward(self,
286
+ spectrogram):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  hidden_states = self.conv_pre(spectrogram)
 
288
  for i in range(self.num_upsamples):
289
  hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
290
  hidden_states = self.upsampler[i](hidden_states)
 
291
  res_state = self.resblocks[i * self.num_kernels](hidden_states)
292
  for j in range(1, self.num_kernels):
293
  res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
294
  hidden_states = res_state / self.num_kernels
 
295
  hidden_states = nn.functional.leaky_relu(hidden_states)
296
  hidden_states = self.conv_post(hidden_states)
297
  waveform = torch.tanh(hidden_states)
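A small sketch of the multi-kernel averaging in the upsampling loop above; plain `Conv1d` layers stand in for the residual blocks and the sizes are toy assumptions:

```python
import torch
import torch.nn as nn

num_kernels = 3
# stand-ins for HifiGanResidualBlock with kernel sizes 3, 5, 7
resblocks = nn.ModuleList(nn.Conv1d(8, 8, k, padding=k // 2) for k in (3, 5, 7))

hidden_states = torch.randn(1, 8, 100)
res_state = resblocks[0](hidden_states)
for j in range(1, num_kernels):
    res_state = res_state + resblocks[j](hidden_states)   # parallel blocks share the same input
hidden_states = res_state / num_kernels                   # average, not concatenate
print(hidden_states.shape)                                # [1, 8, 100]
```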
 
302
  def __init__(self, config):
303
  super().__init__()
304
  self.half_channels = config.flow_size // 2
 
305
  self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
306
  self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
307
  self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
308
 
309
+ def forward(self,
310
+ x,
311
+ reverse=False):
312
+ first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
313
+ hidden_states = self.conv_pre(first_half)
314
+ hidden_states = self.wavenet(hidden_states)
315
+ mean = self.conv_post(hidden_states)
316
+ second_half = (second_half - mean)
317
+ outputs = torch.cat([first_half, second_half], dim=1)
318
+ return outputs
 
 
 
 
 
 
319
 
320
 
321
  class VitsResidualCouplingBlock(nn.Module):
 
325
  for _ in range(config.prior_encoder_num_flows):
326
  self.flows.append(VitsResidualCouplingLayer(config))
327
 
328
+ def forward(self, x, reverse=False):
329
+ # x L [1, 192, 481]
330
+ for flow in reversed(self.flows):
331
+ x = torch.flip(x, [1]) # flipud CHANNELs
332
+ x = flow(x, reverse=True)
333
+ return x
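A toy sketch of the reverse pass above: channels are flipped before each coupling layer, and each layer is mean-only, so inverting it is a plain subtraction (a 1x1 conv is a stand-in for conv_pre → wavenet → conv_post):

```python
import torch
import torch.nn as nn

half = 3
predict_mean = nn.Conv1d(half, half, 1)          # stand-in for the coupling's mean network

x = torch.randn(1, 2 * half, 10)                 # [batch, flow channels, time]
x = torch.flip(x, [1])                           # flip channel order between flows
first, second = torch.split(x, [half, half], dim=1)
second = second - predict_mean(first)            # reverse=True: undo the predicted shift
x = torch.cat([first, second], dim=1)            # same shape as the input
print(x.shape)
```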
 
 
 
 
 
 
 
 
 
 
334
 
 
 
 
 
 
 
 
 
 
335
 
336
 
337
 
338
 
339
 
340
  class VitsAttention(nn.Module):
341
+ """has no positional info"""
342
 
343
  def __init__(self, config):
344
  super().__init__()
 
351
  self.scaling = self.head_dim**-0.5
352
 
353
  if (self.head_dim * self.num_heads) != self.embed_dim:
354
+ raise ValueError
 
 
 
 
355
  self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
356
  self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
357
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
358
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
359
 
 
 
 
 
 
360
  def _shape(self, tensor, seq_len, bsz):
361
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
362
 
363
  def forward(
364
  self,
365
  hidden_states,
366
+ layer_head_mask = None,
367
+ output_attentions = False,
 
 
368
  ):
369
+
 
 
 
370
 
371
  bsz, tgt_len, _ = hidden_states.size()
372
 
 
384
 
385
  src_len = key_states.size(1)
386
  attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)
388
  attn_output = torch.bmm(attn_weights,
389
  value_states)
 
 
 
 
 
 
 
 
 
 
390
  attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
391
  attn_output = attn_output.transpose(1, 2)
392
 
 
398
 
399
  return attn_output, None #attn_weights_reshaped
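A toy sketch of the scaled dot-product path above; with no positional encoding and no mask the layer is permutation-invariant over the time axis (sizes are assumptions):

```python
import torch

bsz, num_heads, tgt_len, head_dim = 1, 2, 6, 8
query = torch.randn(bsz * num_heads, tgt_len, head_dim) * head_dim ** -0.5
key = torch.randn(bsz * num_heads, tgt_len, head_dim)
value = torch.randn(bsz * num_heads, tgt_len, head_dim)

attn_weights = torch.softmax(torch.bmm(query, key.transpose(1, 2)), dim=-1)
attn_output = torch.bmm(attn_weights, value)                           # [bsz*heads, tgt, head_dim]
attn_output = attn_output.view(bsz, num_heads, tgt_len, head_dim).transpose(1, 2)
print(attn_output.reshape(bsz, tgt_len, num_heads * head_dim).shape)   # [1, 6, 16]
```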
400
 
 
 
 
 
 
 
 
 
 
401
 
402
  class VitsFeedForward(nn.Module):
403
  def __init__(self, config):
 
414
  else:
415
  self.padding = None
416
 
417
+ def forward(self, hidden_states):
418
  hidden_states = hidden_states.permute(0, 2, 1)
 
 
 
419
  if self.padding is not None:
420
  hidden_states = nn.functional.pad(hidden_states, self.padding)
 
421
  hidden_states = self.conv_1(hidden_states)
422
  hidden_states = self.act_fn(hidden_states)
 
 
 
423
  if self.padding is not None:
424
  hidden_states = nn.functional.pad(hidden_states, self.padding)
 
425
  hidden_states = self.conv_2(hidden_states)
 
 
426
  hidden_states = hidden_states.permute(0, 2, 1)
427
  return hidden_states
428
 
 
431
  def __init__(self, config):
432
  super().__init__()
433
  self.attention = VitsAttention(config)
 
434
  self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
435
  self.feed_forward = VitsFeedForward(config)
436
  self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
437
 
438
  def forward(
439
  self,
440
+ hidden_states,
441
+ output_attentions = False,
 
 
442
  ):
443
  residual = hidden_states
444
  hidden_states, attn_weights = self.attention(
445
  hidden_states=hidden_states,
446
+ # attention_mask=attention_mask,
447
  output_attentions=output_attentions,
448
  )
449
 
 
451
  hidden_states = self.layer_norm(residual + hidden_states)
452
 
453
  residual = hidden_states
454
+ hidden_states = self.feed_forward(hidden_states)
455
+
456
  hidden_states = self.final_layer_norm(residual + hidden_states)
457
 
458
  outputs = (hidden_states,)
459
 
 
 
 
460
  return outputs
461
 
462
 
 
470
 
471
  def forward(
472
  self,
473
+ hidden_states,
474
+ output_attentions = None,
475
+ output_hidden_states = None,
476
+ return_dict = None,
 
 
477
  ):
478
+ for _layer in self.layers:
479
+ layer_outputs = _layer(hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  hidden_states = layer_outputs[0]
 
 
 
 
 
 
481
  return BaseModelOutput(
482
  last_hidden_state=hidden_states,
483
+ # hidden_states=all_hidden_states,
484
+ # attentions=all_self_attentions,
485
  )
486
 
487
 
488
  class VitsTextEncoder(nn.Module):
489
  """
490
+ Has VitsEncoder
491
  """
492
 
493
  def __init__(self, config):
 
497
  self.encoder = VitsEncoder(config) # 6 Layers of VitsAttention
498
  self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
499
 
500
+ def forward(self,
501
+ input_ids
502
+ ):
503
+ hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
504
+ last_hidden_state = self.encoder(hidden_states=hidden_states).last_hidden_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
+ stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2)
 
 
507
  prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
508
 
509
  return VitsTextEncoderOutput(
510
  last_hidden_state=last_hidden_state,
511
  prior_means=prior_means,
512
+ # prior_log_variances=prior_log_variances,
513
+ # hidden_states=encoder_outputs.hidden_states,
514
+ # attentions=encoder_outputs.attentions,
515
  )
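A short sketch of the projection above: the encoder output is mapped to 2*flow_size channels and split into per-position prior means and log-variances (the dimensions here are toy assumptions):

```python
import torch
import torch.nn as nn

hidden_size, flow_size, T = 192, 192, 30
project = nn.Conv1d(hidden_size, flow_size * 2, kernel_size=1)

last_hidden_state = torch.randn(1, T, hidden_size)
stats = project(last_hidden_state.transpose(1, 2)).transpose(1, 2)        # [1, T, 2*flow_size]
prior_means, prior_log_variances = torch.split(stats, flow_size, dim=2)
print(prior_means.shape, prior_log_variances.shape)                       # [1, 30, 192] each
```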
516
 
517
 
518
  class VitsPreTrainedModel(PreTrainedModel):
 
 
 
 
 
519
  config_class = VitsConfig
520
  base_model_prefix = "vits"
521
  main_input_name = "input_ids"
522
  supports_gradient_checkpointing = True
523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
 
526
  class VitsModel(VitsPreTrainedModel):
 
530
  self.text_encoder = VitsTextEncoder(config) # has VitsEncoder that includes 6L of VitsAttention
531
  self.flow = VitsResidualCouplingBlock(config)
532
  self.decoder = VitsHifiGan(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  # Initialize weights and apply final processing
534
  self.post_init()
535
 
 
 
 
536
  def forward(
537
  self,
538
  input_ids = None,
 
542
  output_hidden_states = None,
543
  return_dict = None,
544
  labels = None,
545
+ speed = None,
546
+ lang_code = 'deu', # speed oscillation pattern per voice/lang
547
  ):
 
 
 
 
 
 
 
 
 
 
548
  mask_dtype = self.text_encoder.embed_tokens.weight.dtype
549
  if attention_mask is not None:
550
  input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
551
  else:
552
  input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
553
+ out = self.text_encoder(input_ids=input_ids)
554
+ hidden_states = out.last_hidden_state.transpose(1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
  input_padding_mask = input_padding_mask.transpose(1, 2)
556
+ prior_means = out.prior_means
557
+ bs, _, in_len = hidden_states.shape
558
+ # VITS Duration Oscillation
559
+ if lang_code == 'deu':
560
+ pattern = [1, 2, 1] # each voice (lang_code) sounds cooler with a different pattern
561
+ elif lang_code == 'rmc-script_latin':
562
+ pattern = [2, 2, 1, 2, 2] # [2, 2, 2, 1, 2]
563
+ elif lang_code == 'hun':
564
+ # pattern = [1, 2, 2, 1, 1, 1] #sounds cool / has valley-pause
565
+ pattern = [1, 2, 1, 1, 1]
 
566
  else:
567
+ pattern = [1, 2, 1]
568
+ duration = torch.tensor(pattern, device=hidden_states.device).repeat(int(in_len / len(pattern)) + 2)[None, None, :in_len] # perhaps define [1, 2, 1] per voice or language
569
+ duration[:, :, 0] = 4
570
+ duration[:, :, -1] = 3
571
+ # ATTN
572
  predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
 
 
573
  indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
574
  output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
575
  output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
 
 
576
  attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
577
  batch_size, _, output_length, input_length = attn_mask.shape
578
  cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
 
581
  valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
582
  padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
583
  attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
584
+ attn = attn[:, 0, :, :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
 
587
+ attn = attn + 1e-4 * torch.rand_like(attn)
588
+ attn /= attn.sum(2, keepdims=True)
589
+ #print(attn)
590
+ prior_means = torch.matmul(attn, prior_means) # try attn to contain .5/.5 instead of 1/0 so it smoothly interpolates repeated prior_means
591
 
592
+ #prior_means = F.interpolate(prior_means.transpose(1,2), int(1.74 * prior_means.shape[1]), mode='linear').transpose(1,2) # extend for slow speed
593
 
594
 
595
 
596
+ # each prior mean has now been replicated duration[i] times along the time axis
597
 
598
+ latents = self.flow(prior_means.transpose(1, 2), # + torch.randn_like(prior_means) * .94,
599
+ reverse=True)
600
 
601
+ waveform = self.decoder(latents) # [bs, 1, 16000]
602
 
603
+ return waveform[:, 0, :]
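A self-contained sketch (toy sizes, no padding masks, a simplified stand-in for the masked alignment built above) of the duration oscillation: a short repeating pattern assigns an integer duration to each text position, and a hard monotonic alignment repeats each prior mean that many frames before the flow and vocoder run:

```python
import torch
import torch.nn.functional as F

in_len, flow_size = 5, 4
prior_means = torch.randn(1, in_len, flow_size)

pattern = [1, 2, 1]                                   # e.g. the 'deu' pattern
duration = torch.tensor(pattern).repeat(in_len // len(pattern) + 2)[None, None, :in_len]
duration[:, :, 0], duration[:, :, -1] = 4, 3          # longer first / last symbol

out_len = int(duration.sum())
cum = torch.cumsum(duration[0, 0], dim=0)             # cumulative end frame per symbol
frames = torch.arange(out_len)
attn = (frames[None, :, None] < cum[None, None, :]).float()
attn = attn - F.pad(attn, (1, 0))[:, :, :-1]          # keep only the first symbol each frame hits

expanded = torch.matmul(attn, prior_means)            # [1, out_len, flow_size]
print(duration.squeeze().tolist(), expanded.shape)    # each mean repeated duration[i] times
```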
 
 
 
 
 
 
 
 
 
 
604
 
605
 
606
  class VitsTokenizer(PreTrainedTokenizer):
607
+ vocab_files_names = {"vocab_file": "vocab.json"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
  model_input_names = ["input_ids", "attention_mask"]
609
 
610
  def __init__(
 
678
  return text
679
 
680
  def prepare_for_tokenization(
681
+ self, text: str, is_split_into_words: bool = False, normalize = None, **kwargs):
682
+
 
 
 
 
683
  normalize = normalize if normalize is not None else self.normalize
684
 
685
  if normalize:
 
724
  tokens = list(text)
725
 
726
  if self.add_blank:
727
+ # sounds dyslexic if there is no space between letters
728
+ # sounds disconnected if >2 spaces between letters
729
+ interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2) # + 1) # the +1 variant raises a slice-index error if tokens is odd
730
+ interspersed[::2] = tokens
731
+ tokens = interspersed + [self._convert_id_to_token(0)] # append one trailing blank (avoids the ::2 length mismatch when tokens is odd)
732
 
733
  return tokens
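A standalone sketch of the interspersing above, with a placeholder blank token (`intersperse_blank` is a hypothetical helper name; the model uses vocabulary id 0):

```python
def intersperse_blank(tokens, blank="<pad>"):
    # blank on every odd position, plus one extra trailing blank
    interspersed = [blank] * (len(tokens) * 2)
    interspersed[::2] = tokens
    return interspersed + [blank]

print(intersperse_blank(list("abc")))
# ['a', '<pad>', 'b', '<pad>', 'c', '<pad>', '<pad>']
```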
734
 
 
 
 
 
 
735
  def _convert_token_to_id(self, token):
736
  """Converts a token (str) in an id using the vocab."""
737
  return self.encoder.get(token, self.encoder.get(self.unk_token))
738
 
739
  def _convert_id_to_token(self, index):
740
  """Converts an index (integer) in a token (str) using the vocab."""
741
+ return self.decoder.get(index)
README.md CHANGED
@@ -131,7 +131,7 @@ python live_demo.py # type text & plays AudioGen sound & TTS
131
 
132
  # Audiobook
133
 
134
- Create audiobook from `.docx`. Listen to it - YouTube [male voice](https://www.youtube.com/watch?v=5-cpf7u18JE) / [v2](https://www.youtube.com/watch?v=Pzo-kKaNg6s) / [v2.1](https://www.youtube.com/watch?v=X4qlKBBaegM) / [no diffusion](https://www.youtube.com/watch?v=vahKXpd6oLg)
135
 
136
  ```python
137
  # audiobook will be saved in ./tts_audiobooks
 
131
 
132
  # Audiobook
133
 
134
+ Create audiobook from `.docx`. Listen to it - YouTube [male voice](https://www.youtube.com/watch?v=5-cpf7u18JE) / [v2](https://www.youtube.com/watch?v=Pzo-kKaNg6s) / [v2.1](https://www.youtube.com/watch?v=X4qlKBBaegM) / [no diffusion](https://www.youtube.com/watch?v=vahKXpd6oLg) / [Audionar](https://youtu.be/fUGpfq_o_CU) / [F](https://www.youtube.com/watch?v=tlRdRV5nm40)
135
 
136
  ```python
137
  # audiobook will be saved in ./tts_audiobooks
Utils/JDC/__init__.py DELETED
@@ -1 +0,0 @@
1
-
 
 
Utils/JDC/bst.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
- size 21029926
 
 
 
 
Utils/JDC/model.py DELETED
@@ -1,190 +0,0 @@
1
- """
2
- Implementation of model from:
3
- Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
4
- Convolutional Recurrent Neural Networks" (2019)
5
- Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
6
- """
7
- import torch
8
- from torch import nn
9
-
10
- class JDCNet(nn.Module):
11
- """
12
- Joint Detection and Classification Network model for singing voice melody.
13
- """
14
- def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
15
- super().__init__()
16
- self.num_class = num_class
17
-
18
- # input = (b, 1, 31, 513), b = batch size
19
- self.conv_block = nn.Sequential(
20
- nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False), # out: (b, 64, 31, 513)
21
- nn.BatchNorm2d(num_features=64),
22
- nn.LeakyReLU(leaky_relu_slope, inplace=True),
23
- nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
24
- )
25
-
26
- # res blocks
27
- self.res_block1 = ResBlock(in_channels=64, out_channels=128) # (b, 128, 31, 128)
28
- self.res_block2 = ResBlock(in_channels=128, out_channels=192) # (b, 192, 31, 32)
29
- self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
30
-
31
- # pool block
32
- self.pool_block = nn.Sequential(
33
- nn.BatchNorm2d(num_features=256),
34
- nn.LeakyReLU(leaky_relu_slope, inplace=True),
35
- nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
36
- nn.Dropout(p=0.2),
37
- )
38
-
39
- # maxpool layers (for auxiliary network inputs)
40
- # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
41
- self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
42
- # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
43
- self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
44
- # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
45
- self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
46
-
47
- # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
48
- self.detector_conv = nn.Sequential(
49
- nn.Conv2d(640, 256, 1, bias=False),
50
- nn.BatchNorm2d(256),
51
- nn.LeakyReLU(leaky_relu_slope, inplace=True),
52
- nn.Dropout(p=0.2),
53
- )
54
-
55
- # input: (b, 31, 512) - resized from (b, 256, 31, 2)
56
- self.bilstm_classifier = nn.LSTM(
57
- input_size=512, hidden_size=256,
58
- batch_first=True, bidirectional=True) # (b, 31, 512)
59
-
60
- # input: (b, 31, 512) - resized from (b, 256, 31, 2)
61
- self.bilstm_detector = nn.LSTM(
62
- input_size=512, hidden_size=256,
63
- batch_first=True, bidirectional=True) # (b, 31, 512)
64
-
65
- # input: (b * 31, 512)
66
- self.classifier = nn.Linear(in_features=512, out_features=self.num_class) # (b * 31, num_class)
67
-
68
- # input: (b * 31, 512)
69
- self.detector = nn.Linear(in_features=512, out_features=2) # (b * 31, 2) - binary classifier
70
-
71
- # initialize weights
72
- self.apply(self.init_weights)
73
-
74
- def get_feature_GAN(self, x):
75
- seq_len = x.shape[-2]
76
- x = x.float().transpose(-1, -2)
77
-
78
- convblock_out = self.conv_block(x)
79
-
80
- resblock1_out = self.res_block1(convblock_out)
81
- resblock2_out = self.res_block2(resblock1_out)
82
- resblock3_out = self.res_block3(resblock2_out)
83
- poolblock_out = self.pool_block[0](resblock3_out)
84
- poolblock_out = self.pool_block[1](poolblock_out)
85
-
86
- return poolblock_out.transpose(-1, -2)
87
-
88
- def get_feature(self, x):
89
- seq_len = x.shape[-2]
90
- x = x.float().transpose(-1, -2)
91
-
92
- convblock_out = self.conv_block(x)
93
-
94
- resblock1_out = self.res_block1(convblock_out)
95
- resblock2_out = self.res_block2(resblock1_out)
96
- resblock3_out = self.res_block3(resblock2_out)
97
- poolblock_out = self.pool_block[0](resblock3_out)
98
- poolblock_out = self.pool_block[1](poolblock_out)
99
-
100
- return self.pool_block[2](poolblock_out)
101
-
102
- def forward(self, x):
103
- """
104
- Returns:
105
- classification_prediction, detection_prediction
106
- sizes: (b, 31, 722), (b, 31, 2)
107
- """
108
- ###############################
109
- # forward pass for classifier #
110
- ###############################
111
- seq_len = x.shape[-1]
112
- x = x.float().transpose(-1, -2)
113
-
114
- convblock_out = self.conv_block(x)
115
-
116
- resblock1_out = self.res_block1(convblock_out)
117
- resblock2_out = self.res_block2(resblock1_out)
118
- resblock3_out = self.res_block3(resblock2_out)
119
-
120
-
121
- poolblock_out = self.pool_block[0](resblock3_out)
122
- poolblock_out = self.pool_block[1](poolblock_out)
123
- GAN_feature = poolblock_out.transpose(-1, -2)
124
- poolblock_out = self.pool_block[2](poolblock_out)
125
-
126
- # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
127
- classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
128
- classifier_out, _ = self.bilstm_classifier(classifier_out) # ignore the hidden states
129
-
130
- classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
131
- classifier_out = self.classifier(classifier_out)
132
- classifier_out = classifier_out.view((-1, seq_len, self.num_class)) # (b, 31, num_class)
133
-
134
- # sizes: (b, 31, 722), (b, 31, 2)
135
- # classifier output consists of predicted pitch classes per frame
136
- # detector output consists of: (isvoice, notvoice) estimates per frame
137
- return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
138
-
139
- @staticmethod
140
- def init_weights(m):
141
- if isinstance(m, nn.Linear):
142
- nn.init.kaiming_uniform_(m.weight)
143
- if m.bias is not None:
144
- nn.init.constant_(m.bias, 0)
145
- elif isinstance(m, nn.Conv2d):
146
- nn.init.xavier_normal_(m.weight)
147
- elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
148
- for p in m.parameters():
149
- if p.data is None:
150
- continue
151
-
152
- if len(p.shape) >= 2:
153
- nn.init.orthogonal_(p.data)
154
- else:
155
- nn.init.normal_(p.data)
156
-
157
-
158
- class ResBlock(nn.Module):
159
- def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
160
- super().__init__()
161
- self.downsample = in_channels != out_channels
162
-
163
- # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
164
- self.pre_conv = nn.Sequential(
165
- nn.BatchNorm2d(num_features=in_channels),
166
- nn.LeakyReLU(leaky_relu_slope, inplace=True),
167
- nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
168
- )
169
-
170
- # conv layers
171
- self.conv = nn.Sequential(
172
- nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
173
- kernel_size=3, padding=1, bias=False),
174
- nn.BatchNorm2d(out_channels),
175
- nn.LeakyReLU(leaky_relu_slope, inplace=True),
176
- nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
177
- )
178
-
179
- # 1 x 1 convolution layer to match the feature dimensions
180
- self.conv1by1 = None
181
- if self.downsample:
182
- self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
183
-
184
- def forward(self, x):
185
- x = self.pre_conv(x)
186
- if self.downsample:
187
- x = self.conv(x) + self.conv1by1(x)
188
- else:
189
- x = self.conv(x) + x
190
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Utils/PLBERT/util.py CHANGED
@@ -27,7 +27,7 @@ def load_plbert(log_dir):
27
  iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
28
  iters = sorted(iters)[-1]
29
 
30
- checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".pth", map_location='cpu')
31
  state_dict = checkpoint['net']
32
  from collections import OrderedDict
33
  new_state_dict = OrderedDict()
 
27
  iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
28
  iters = sorted(iters)[-1]
29
 
30
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".pth", map_location='cpu', weights_only=True)
31
  state_dict = checkpoint['net']
32
  from collections import OrderedDict
33
  new_state_dict = OrderedDict()
Utils/text_utils.py CHANGED
@@ -35,72 +35,112 @@ class TextCleaner:
35
 
36
  # == Sentence Splitter
37
 
38
- alphabets = "([A-Za-z])"
39
- prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
40
- suffixes = "(Inc|Ltd|Jr|Sr|Co)"
41
- starters = "(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
42
- acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
43
- websites = "[.](com|net|org|io|gov|edu|me)"
44
- digits = "([0-9])"
45
- multiple_dots = r'\.{2,}'
46
-
47
 
48
- def split_into_sentences(text):
49
  """
50
- Split the text into sentences.
51
-
52
- If the text contains substrings "<prd>" or "<stop>", they would lead
53
- to incorrect splitting because they are used as markers for splitting.
54
-
55
- :param text: text to be split into sentences
56
- :type text: str
57
 
58
- :return: list of sentences
59
- :rtype: list[str]
 
60
 
61
- https://stackoverflow.com/questions/4576077/how-can-i-split-a-text-into-sentences
 
62
  """
63
- text = " " + text + " "
64
- text = text.replace("\n", " ")
65
- text = re.sub(prefixes, "\\1<prd>", text)
66
- text = re.sub(websites, "<prd>\\1", text)
67
- text = re.sub(digits + "[.]" + digits, "\\1<prd>\\2", text)
68
- text = re.sub(multiple_dots, lambda match: "<prd>" *
69
- len(match.group(0)) + "<stop>", text)
70
- if "Ph.D" in text:
71
- text = text.replace("Ph.D.", "Ph<prd>D<prd>")
72
- text = re.sub("\s" + alphabets + "[.] ", " \\1<prd> ", text)
73
- text = re.sub(acronyms+" "+starters, "\\1<stop> \\2", text)
74
- text = re.sub(alphabets + "[.]" + alphabets + "[.]" +
75
- alphabets + "[.]", "\\1<prd>\\2<prd>\\3<prd>", text)
76
- text = re.sub(alphabets + "[.]" + alphabets +
77
- "[.]", "\\1<prd>\\2<prd>", text)
78
- text = re.sub(" "+suffixes+"[.] "+starters, " \\1<stop> \\2", text)
79
- text = re.sub(" "+suffixes+"[.]", " \\1<prd>", text)
80
- text = re.sub(" " + alphabets + "[.]", " \\1<prd>", text)
81
- if "”" in text:
82
- text = text.replace(".”", "”.")
83
- if "\"" in text:
84
- text = text.replace(".\"", "\".")
85
- if "!" in text:
86
- text = text.replace("!\"", "\"!")
87
- if "?" in text:
88
- text = text.replace("?\"", "\"?")
89
- text = text.replace(".", ".<stop>")
90
- text = text.replace("?", "?<stop>")
91
- text = text.replace("!", "!<stop>")
92
- text = text.replace("<prd>", ".")
93
- sentences = text.split("<stop>")
94
- sentences = [s.strip() for s in sentences]
95
-
96
- # Split Very long sentences >500 phoneme - StyleTTS2 crashes
97
- # -- even 400 phonemes sometimes OOM in cuda:4
98
- sentences = [
99
- sub_sent+' ' for s in sentences for sub_sent in textwrap.wrap(s, 200, break_long_words=0)]
100
-
101
- # if sentences and not sentences[-1]:
102
- # sentences = sentences[:-1]
103
- return sentences
 
 
 
 
 
 
 
 
 
 
104
 
105
 
106
  def store_ssml(text=None,
 
35
 
36
  # == Sentence Splitter
37
 
38
+ import re
 
 
 
 
 
 
 
 
39
 
40
+ def split_into_sentences(text, max_len=200):
41
  """
42
+ Splits a string into chunks of max_len characters, ensuring each chunk
43
+ terminates with a period if it was split mid-sentence. Prioritizes
44
+ splitting at natural sentence breaks and avoids splitting words.
 
 
 
 
45
 
46
+ Args:
47
+ text (str): The input string.
48
+ max_len (int): The maximum desired length for each chunk.
49
 
50
+ Returns:
51
+ list: A list of strings, where each string is a sentence chunk.
52
  """
53
+ if not text:
54
+ return []
55
+
56
+ # Regex to split text into potential sentence candidates.
57
+ # We still use the lookbehind to keep the punctuation with the sentence.
58
+ sentence_candidates = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
59
+
60
+ # Handle the last part if it doesn't end with a punctuation (e.g., a phrase or incomplete sentence)
61
+ if text and not text.strip().endswith(('.', '!', '?')) and text.strip() not in sentence_candidates:
62
+ # Check if the last candidate already contains the end of the text.
63
+ # This is a heuristic, as re.split can sometimes be tricky with trailing non-matches.
64
+ if not (sentence_candidates and text.strip().endswith(sentence_candidates[-1])):
65
+ remaining_text = text.strip()
66
+ if sentence_candidates:
67
+ # Find the part of the text that wasn't included in sentence_candidates
68
+ last_candidate_start_index = text.rfind(sentence_candidates[-1])
69
+ if last_candidate_start_index != -1:
70
+ remaining_text = text[last_candidate_start_index + len(sentence_candidates[-1]):].strip()
71
+
72
+ if remaining_text and not remaining_text.endswith(('.', '!', '?')):
73
+ sentence_candidates.append(remaining_text)
74
+
75
+
76
+ chunks = []
77
+ current_chunk_elements = [] # Stores individual sentences that form the current chunk
78
+ current_chunk_length = 0
79
+
80
+ for sentence in sentence_candidates:
81
+ # Calculate the length this sentence would add to the current chunk.
82
+ # Add 1 for the space that will separate sentences within a chunk, if needed.
83
+ potential_addition_length = len(sentence) + (1 if current_chunk_elements else 0)
84
+
85
+ # Check if adding this sentence would exceed the maximum length
86
+ if current_chunk_length + potential_addition_length > max_len:
87
+ # First, finalize the current chunk
88
+ if current_chunk_elements:
89
+ final_chunk = " ".join(current_chunk_elements).strip()
90
+ chunks.append(final_chunk)
91
+
92
+ # Reset for the new chunk and handle the current `sentence`.
93
+ # This `sentence` itself might be longer than `max_len`.
94
+ remaining_sentence = sentence
95
+ while len(remaining_sentence) > max_len:
96
+ # Prioritize splitting at a period or a space to avoid splitting words.
97
+ # Search backwards from `max_len - 1` to find the last valid break point.
98
+ split_point = -1
99
+ search_area = remaining_sentence[:max_len]
100
+
101
+ # Option 1: Find the last period in the search area
102
+ last_period_idx = search_area.rfind('.')
103
+ if last_period_idx != -1:
104
+ split_point = last_period_idx
105
+
106
+ # Option 2: If no period, find the last space (to avoid splitting words)
107
+ if split_point == -1:
108
+ last_space_idx = search_area.rfind(' ')
109
+ if last_space_idx != -1:
110
+ split_point = last_space_idx
111
+
112
+ if split_point != -1:
113
+ # If a period or space is found, split there.
114
+ # If it's a period, include it. If it's a space, don't include the space
115
+ # but ensure the chunk ends with a period if it didn't already.
116
+ chunk_to_add = remaining_sentence[:split_point + (1 if remaining_sentence[split_point] == '.' else 0)].strip()
117
+ if not chunk_to_add.endswith('.'):
118
+ chunk_to_add += '.' # Ensure period termination
119
+
120
+ chunks.append(chunk_to_add)
121
+ remaining_sentence = remaining_sentence[split_point + 1:].lstrip() # Update remaining
122
+ else:
123
+ # No natural break (period or space) within max_len.
124
+ # This happens for extremely long words or sequences without spaces.
125
+ # In this rare case, we force split at max_len and append a period.
126
+ chunks.append(remaining_sentence[:max_len].strip() + '.')
127
+ remaining_sentence = remaining_sentence[max_len:].lstrip() # Update remaining
128
+
129
+ # The `remaining_sentence` (now guaranteed to be `<= max_len`)
130
+ # becomes the start of the new `current_chunk`.
131
+ current_chunk_elements = [remaining_sentence]
132
+ current_chunk_length = len(remaining_sentence)
133
+
134
+ else:
135
+ # The current sentence fits within the `max_len`, so add it.
136
+ current_chunk_elements.append(sentence)
137
+ current_chunk_length += potential_addition_length
138
+
139
+ # After iterating through all sentences, add any remaining elements
140
+ # in `current_chunk_elements` as the final chunk.
141
+ if current_chunk_elements:
142
+ chunks.append(" ".join(current_chunk_elements).strip())
143
+ return chunks
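A usage sketch of the splitter above, with a made-up input string:

```python
long_text = ("This is the first sentence. Here is a much longer second sentence that keeps "
             "going for quite a while so the chunker actually has something to split. Short one!")

for i, chunk in enumerate(split_into_sentences(long_text, max_len=80)):
    print(i, len(chunk), repr(chunk))
# every chunk is at most 80 characters and ends at a sentence or word boundary
```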
144
 
145
 
146
  def store_ssml(text=None,
api.py CHANGED
@@ -113,21 +113,11 @@ def _resize(image, width=None, height=None, inter=cv2.INTER_AREA):
113
 
114
  def overlay(x, soundscape=None):
115
  if soundscape is not None:
116
- # AudioGen sound is suffice to be ~10s long
117
  background = sound_generator.generate(soundscape,
118
- # sound duration = TTS dur
119
- duration=len(x)/16000 + .74,
120
- ).detach().cpu().numpy() # bs, 11400 @.74s
121
-
122
- # len_soundscape = len(background)
123
-
124
- # fading = .5 + .5 * np.tanh(4*(np.linspace(10, -10, len_soundscape) + 9.4)) # fade heaviside 1,1,1,1,...,0
125
-
126
- # x = np.concatenate([fading * background, x], 0) # blend TTS with AudioGen
127
- # background /= np.abs(background).max() + 1e-7 # amplify speech to full [-1,1]
128
- # background will be longer by xtra .74s
129
- x = .47 * x + .46 * background[:len(x)]
130
- return x # TTS / AudioGen @ 16kHz
131
 
132
 
133
  def tts_multi_sentence(precomputed_style_vector=None,
@@ -176,7 +166,7 @@ def tts_multi_sentence(precomputed_style_vector=None,
176
 
177
  # volume
178
 
179
- x /= np.abs(x).max() + 1e-7 # amplify speech to full [-1,1]
180
 
181
  return overlay(x, soundscape=soundscape)
182
 
@@ -211,7 +201,7 @@ def serve_wav():
211
  _shorten(r.get('native')[0]),
212
  affective=r.get('affective')[0],
213
  voice=r.get('voice')[0],
214
- speed=float(r.get('speed')[0]), # For Non-English MMS TTS
215
  soundscape=r.get('soundscape')[0] if r.get(
216
  'soundscape') is not None else None,
217
  )
 
113
 
114
  def overlay(x, soundscape=None):
115
  if soundscape is not None:
 
116
  background = sound_generator.generate(soundscape,
117
+ duration=len(x)/16000 + .74, # duration seconds
118
+ ).detach().cpu().numpy()
119
+ x = .6 * x + .4 * background[:len(x)]
120
+ return x
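A minimal sketch of the 0.6 / 0.4 mix above (assumes mono float arrays at 16 kHz; the generated background is ~0.74 s longer than the speech and is trimmed to its length; the helper name is hypothetical):

```python
import numpy as np

def mix_speech_background(speech, background, w_speech=0.6, w_background=0.4):
    background = background[:len(speech)]            # drop the extra AudioGen tail
    return w_speech * speech + w_background * background

speech = np.random.uniform(-1, 1, 16000).astype(np.float32)              # 1 s of fake speech
background = np.random.uniform(-1, 1, 16000 + 11840).astype(np.float32)  # ~0.74 s longer
print(mix_speech_background(speech, background).shape)                   # (16000,)
```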
 
 
 
 
 
 
 
 
 
121
 
122
 
123
  def tts_multi_sentence(precomputed_style_vector=None,
 
166
 
167
  # volume
168
 
169
+ x /= 1.12 * np.abs(x).max() + 1e-7 # normalise speech to just below [-1, 1]; soundscapes are not amplified / normalised
170
 
171
  return overlay(x, soundscape=soundscape)
172
 
 
201
  _shorten(r.get('native')[0]),
202
  affective=r.get('affective')[0],
203
  voice=r.get('voice')[0],
204
+ speed=None, # obsolete due to oscillating MMS TTS VITS duration per language
205
  soundscape=r.get('soundscape')[0] if r.get(
206
  'soundscape') is not None else None,
207
  )
audiobook.py CHANGED
@@ -8,20 +8,23 @@ import subprocess
8
  import numpy as np
9
  import soundfile
10
  import docx # package = python-docx
11
- import audresample
12
  import urllib
13
  from pathlib import Path
14
  from moviepy.editor import *
15
 
16
- FS = 24000
17
  ROOT_DIR = './tts_audiobooks/voices/'
18
  Path(ROOT_DIR).mkdir(parents=True,
19
  exist_ok=True)
20
  voices = [
21
- # 'en_US/vctk_low#p228', # https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#67854dcbd3e6beb1a78f7f20
22
  # 'af_ZA_google-nwu_0184', # https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6
23
- 'en_US/vctk_low#p326', # Native voice
24
- # 'jv_ID_google-gmu_06207',
 
 
 
25
  ] # select any voice from - https://audeering.github.io/shift/
26
 
27
  #urllib.request.urlretrieve("https://github.com/audeering/shift/raw/refs/heads/main/assets/INCLUSION_IN_MUSEUMS_audiobook.docx", "audiobook_TTS.docx")
@@ -54,7 +57,7 @@ for vox in voices:
54
 
55
  total = []
56
  chapter = []
57
-
58
  final_paragraph_for_saving_last_chapter = d.paragraphs[-1]
59
  final_paragraph_for_saving_last_chapter.text = 'CHAPTER: END OF AUDIOBOOK'
60
 
@@ -69,12 +72,6 @@ for vox in voices:
69
  if t.startswith('CHAPTER:'):
70
 
71
 
72
-
73
- # silence for end chapter
74
-
75
- chapter.append(np.zeros(int(.24 * FS),
76
- dtype=np.float32))
77
-
78
  # chapter.wav
79
 
80
  audio = np.concatenate(chapter)
@@ -116,17 +113,14 @@ for vox in voices:
116
  [
117
  "python",
118
  "tts.py",
119
- "--text",
120
- "_tmp.txt", #t, # paragraph text tts and append to voice_chapter.wav
121
- # "--affect",
122
- #'--image', '_tmp_banner.png',
123
- # '--scene', 'calm sounds of castle',
124
  '--voice', vox,
125
  '--out_file', '_tmp' # save on _tmp load audio and concat to total
126
  ])
127
 
128
- audio, _fs = soundfile.read('out/_tmp.wav')
129
- audio = audresample.resample(audio.astype(np.float32), 24000, 16000)[0, :]
130
  # print('CHAPTER\n\n\n\n____', audio.shape,'____\n')
131
  chapter.append(audio)
132
 
@@ -140,9 +134,6 @@ for vox in voices:
140
 
141
  if not last_paragraph_was_silence: # skip multiple empty paragraphs - silence is added only once
142
 
143
- chapter.append(np.zeros(int(.1 * FS),
144
- dtype=np.float32))
145
-
146
  last_paragraph_was_silence = True
147
 
148
  # save full .wav audiobook - for this voice
@@ -157,11 +148,7 @@ for vox in voices:
157
 
158
  # pic TTS voice
159
 
160
- voice_pic = np.zeros((574, 1024, 3), dtype=np.uint8)
161
-
162
- shift_logo = cv2.imread('assets/shift_banner.png')
163
-
164
- voice_pic[:100, :400, :] = shift_logo[:100, :400, :]
165
 
166
  # voice name
167
  # frame_tts = np.zeros((104, 1920, 3), dtype=np.uint8)
 
8
  import numpy as np
9
  import soundfile
10
  import docx # package = python-docx
11
+
12
  import urllib
13
  from pathlib import Path
14
  from moviepy.editor import *
15
 
16
+ FS = 16000
17
  ROOT_DIR = './tts_audiobooks/voices/'
18
  Path(ROOT_DIR).mkdir(parents=True,
19
  exist_ok=True)
20
  voices = [
21
+ # 'en_US/vctk_low#p228', # https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#67854dcbd3e6beb1a78f7f20
22
  # 'af_ZA_google-nwu_0184', # https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6
23
+ # 'en_US/vctk_low#p326',
24
+ #'en_US/vctk_low#p292',
25
+ # 'jv_ID_google-gmu_06207',
26
+ # 'fr_FR_m-ailabs_bernard'
27
+ 'en_US_m-ailabs_mary_ann'
28
  ] # select any voice from - https://audeering.github.io/shift/
29
 
30
  #urllib.request.urlretrieve("https://github.com/audeering/shift/raw/refs/heads/main/assets/INCLUSION_IN_MUSEUMS_audiobook.docx", "audiobook_TTS.docx")
 
57
 
58
  total = []
59
  chapter = []
60
+
61
  final_paragraph_for_saving_last_chapter = d.paragraphs[-1]
62
  final_paragraph_for_saving_last_chapter.text = 'CHAPTER: END OF AUDIOBOOK'
63
 
 
72
  if t.startswith('CHAPTER:'):
73
 
74
 
 
 
 
 
 
 
75
  # chapter.wav
76
 
77
  audio = np.concatenate(chapter)
 
113
  [
114
  "python",
115
  "tts.py",
116
+ "--text",
117
+ "_tmp.txt",
118
+ '--soundscape', 'birds formig' if chapter_counter < 2 else '',
 
 
119
  '--voice', vox,
120
  '--out_file', '_tmp' # save on _tmp load audio and concat to total
121
  ])
122
 
123
+ audio, _fs = soundfile.read('out/_tmp.wav') # already 16 kHz
 
124
  # print('CHAPTER\n\n\n\n____', audio.shape,'____\n')
125
  chapter.append(audio)
126
 
 
134
 
135
  if not last_paragraph_was_silence: # skip multiple empty paragraphs - silence is added only once
136
 
 
 
 
137
  last_paragraph_was_silence = True
138
 
139
  # save full .wav audiobook - for this voice
 
148
 
149
  # pic TTS voice
150
 
151
+ voice_pic = np.zeros((1920, 1080, 3), dtype=np.uint8)
 
 
 
 
152
 
153
  # voice name
154
  # frame_tts = np.zeros((104, 1920, 3), dtype=np.uint8)
demo.py CHANGED
@@ -1,68 +1,40 @@
1
  import numpy as np
2
  import soundfile
3
- import msinference
4
  from audiocraft.builders import AudioGen
5
 
6
- def tts_entry(text='A quick brown fox jumps over the lazy dog. Sweet dreams are made of this, I traveled the world and the seven seas.',
7
- voice='en_US/vctk_low#p326', #'en_US/vctk_low#p276', # 'deu', 'af_ZA_google-nwu_1919', 'serbian', 'isl',
8
- speed=1.14,
9
- affect = True, # False = higher clarity voice
10
- soundscape = 'dogs barg in dungeons n dragons'
11
- ):
12
- '''16 KHz
13
-
14
- voice : 'en_US/vctk_low#p276' # Native English voices -> https://audeering.github.io/shift/
15
-
16
- or
17
-
18
- voice : 'af_ZA_google-nwu_1919' # Non-Native English voices -> https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6
19
 
20
- or
21
-
22
- voice : 'deu' # Foreign languages -> https://huggingface.co/dkounadis/artificial-styletts2/blob/main/Utils/all_langs.csv
 
 
 
23
  '''
24
 
25
- # StyleTTS2 - find voice from folder
26
-
27
  if ('en_US/' in voice) or ('en_UK/' in voice):
28
- a = '' if affect else '_v2'
29
- style_vector = msinference.compute_style('assets/wavs/style_vector' + a + '/' + voice.replace(
30
  '/', '_').replace('#', '_').replace(
31
  'cmu-arctic', 'cmu_arctic').replace(
32
  '_low', '') + '.wav')
33
 
34
- x = msinference.inference(text,
35
- style_vector)
36
-
37
- # find voice from mimic-3 folder with styles
38
-
39
  elif '_' in voice:
40
  style_vector = msinference.compute_style('assets/wavs/mimic3_foreign_4x/' + voice.replace(
41
  '/', '_').replace('#', '_').replace(
42
  'cmu-arctic', 'cmu_arctic').replace(
43
  '_low', '') + '.wav')
44
 
45
- x = msinference.inference(text,
46
- style_vector)
47
-
48
-
49
- # Fallback - MMS TTS - Non-English voice / langs
50
-
51
  else:
52
- x = msinference.foreign(text=text,
53
- lang=voice,
54
- speed=speed) # volume normalis.
55
-
56
- # volume
57
-
58
- x /= np.abs(x).max() + 1e-7 # amplify speech to full [-1,1]
59
-
60
  if soundscape is not None:
61
  sound_gen = AudioGen().to('cuda:0').eval()
62
- background = sound_gen.generate(soundscape,
63
- duration=len(x)/16000 + .74, # sound duration in seconds
64
  ).detach().cpu().numpy()
65
- x = .5 * x + .47 * background[:len(x)]
66
  return x
67
 
 
68
  soundfile.write(f'demo.wav', tts_entry(), 16000)
 
1
  import numpy as np
2
  import soundfile
3
+ import msinference # api.py also splits text into sentences to avoid OOM
4
  from audiocraft.builders import AudioGen
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ def tts_entry(text='A quick brown fox jumps over the lazy dog. Sweet dreams are made of this, I traveled the world and the seven seas.',
8
+ voice='en_US/m-ailabs_low#mary_ann', #fr_FR_m-ailabs_bernard', #'deu', #'serbian', #'romanian', #'deu', #'en_US/vctk_low#p326', #'en_US/vctk_low#p276', # 'deu', 'af_ZA_google-nwu_1919', 'serbian', 'isl',
9
+ soundscape = 'birds river'):
10
+ '''voice = 'en_US/vctk_low#p276' # Native English Voices > https://audeering.github.io/shift/
11
+ = 'af_ZA_google-nwu_1919' # Non Native English Voices > https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6
12
+ = 'deu' # Other languages > https://huggingface.co/dkounadis/artificial-styletts2/blob/main/Utils/all_langs.csv
13
  '''
14
 
 
 
15
  if ('en_US/' in voice) or ('en_UK/' in voice):
16
+ style_vector = msinference.compute_style('assets/wavs/style_vector/' + voice.replace(
 
17
  '/', '_').replace('#', '_').replace(
18
  'cmu-arctic', 'cmu_arctic').replace(
19
  '_low', '') + '.wav')
20
 
21
+ x = msinference.inference(text, style_vector)
 
 
 
 
22
  elif '_' in voice:
23
  style_vector = msinference.compute_style('assets/wavs/mimic3_foreign_4x/' + voice.replace(
24
  '/', '_').replace('#', '_').replace(
25
  'cmu-arctic', 'cmu_arctic').replace(
26
  '_low', '') + '.wav')
27
 
28
+ x = msinference.inference(text, style_vector)
 
 
 
 
 
29
  else:
30
+ x = msinference.foreign(text=text, lang=voice)
31
+ x /= 1.02 * np.abs(x).max() + 1e-7 # volume amplify full [-1,1]
 
 
 
 
 
 
32
  if soundscape is not None:
33
  sound_gen = AudioGen().to('cuda:0').eval()
34
+ background = sound_gen.generate(soundscape, duration=len(x)/16000 + .74, # sound duration seconds
 
35
  ).detach().cpu().numpy()
36
+ x = .6 * x + .4 * background[:len(x)]
37
  return x
38
 
39
+
40
  soundfile.write(f'demo.wav', tts_entry(), 16000)
live_demo.py CHANGED
@@ -16,7 +16,6 @@ def send_to_server(args):
16
  'affective': True,
17
  'image': None,
18
  'video': None,
19
- 'speed': 1.14,
20
  'native': None,
21
  }
22
 
@@ -24,16 +23,15 @@ def send_to_server(args):
24
 
25
 
26
  args = SimpleNamespace()
27
- args.voice = 'fr_FR_m-ailabs_bernard' # 'en_US/m-ailabs_low#judy_bieber'
28
- args.speed = 1.14
29
  os.system('cls' if os.name == 'nt' else 'clear')
30
  while True:
31
  _str = input("\n\n\n\nDescribe Any Sound: \n\n\n\n")
32
-
33
-
34
- _str += 'Lorem ipsum dolor sit amet, consetetur elixir sed diam nonumy eirmod tempor invidunt labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Soutet clita kasd gubergren.'
35
-
36
  args.soundscape = _str
 
 
 
37
  args.text = '_tmp.txt' # input -> .txt (implementation thought for audiobooks in API)
38
 
39
  with open(args.text, 'w') as f:
 
16
  'affective': True,
17
  'image': None,
18
  'video': None,
 
19
  'native': None,
20
  }
21
 
 
23
 
24
 
25
  args = SimpleNamespace()
26
+ args.voice = 'en_US/m-ailabs_low#judy_bieber'
 
27
  os.system('cls' if os.name == 'nt' else 'clear')
28
  while True:
29
  _str = input("\n\n\n\nDescribe Any Sound: \n\n\n\n")
30
+
 
 
 
31
  args.soundscape = _str
32
+
33
+ _str += 'A quick brown fox jumps over the lazy dog. Sweet dreams are made of this, I traveled the world and the seven seas.'
34
+
35
  args.text = '_tmp.txt' # input -> .txt (implementation thought for audiobooks in API)
36
 
37
  with open(args.text, 'w') as f:
models.py CHANGED
@@ -1,93 +1,98 @@
1
- #coding:utf-8
2
 
3
  import os
4
- import math
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
- from torch.nn.utils import weight_norm, spectral_norm
 
9
  # from Utils.ASR.models import ASRCNN
10
- from Utils.JDC.model import JDCNet
11
  from Modules.hifigan import _tile, AdainResBlk1d
12
- import yaml
13
 
 
14
 
15
- class LearnedDownSample(nn.Module):
16
- def __init__(self, layer_type, dim_in):
 
 
 
 
 
 
17
  super().__init__()
18
- self.layer_type = layer_type
19
-
20
- if self.layer_type == 'none':
21
- raise ValueError
22
- # self.conv = nn.Identity()
23
- elif self.layer_type == 'timepreserve':
24
- raise ValueError
25
- # self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
26
- elif self.layer_type == 'half':
27
- self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
28
- else:
29
- raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
 
 
 
 
 
 
 
 
 
30
 
31
  def forward(self, x):
32
- return self.conv(x)
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
- class DownSample(nn.Module):
36
- def __init__(self, layer_type):
37
  super().__init__()
38
- self.layer_type = layer_type
39
-
 
40
  def forward(self, x):
41
- if self.layer_type == 'none':
42
- return x
43
- elif self.layer_type == 'timepreserve':
44
- return F.avg_pool2d(x, (2, 1))
45
- elif self.layer_type == 'half':
46
- if x.shape[-1] % 2 != 0:
47
- x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
48
- return F.avg_pool2d(x, 2)
49
- else:
50
- raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
51
-
52
-
53
-
54
 
55
 
56
  class ResBlk(nn.Module):
57
- def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
58
- normalize=False, downsample='none'):
59
  super().__init__()
60
- self.actv = actv
61
- self.normalize = normalize
62
- self.downsample = DownSample(downsample)
63
- self.downsample_res = LearnedDownSample(downsample, dim_in)
64
  self.learned_sc = dim_in != dim_out
65
- self._build_weights(dim_in, dim_out)
66
-
67
- def _build_weights(self, dim_in, dim_out):
68
  self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
69
  self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
70
- if self.normalize:
71
- self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
72
- self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
73
  if self.learned_sc:
74
- self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
 
75
 
76
  def _shortcut(self, x):
77
  if self.learned_sc:
78
  x = self.conv1x1(x)
79
- if self.downsample:
80
- x = self.downsample(x)
81
- return x
82
 
83
  def _residual(self, x):
84
- if self.normalize:
85
- x = self.norm1(x)
86
  x = self.actv(x)
87
  x = self.conv1(x)
88
  x = self.downsample_res(x)
89
- if self.normalize:
90
- x = self.norm2(x)
91
  x = self.actv(x)
92
  x = self.conv2(x)
93
  return x
@@ -101,113 +106,41 @@ class StyleEncoder(nn.Module):
101
 
102
  # for both acoustic & prosodic ref_s/p
103
 
104
- def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
 
 
 
105
  super().__init__()
106
- blocks = []
107
- blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
108
-
109
- repeat_num = 4
110
- for _ in range(repeat_num):
111
- dim_out = min(dim_in*2, max_conv_dim)
112
- blocks += [ResBlk(dim_in, dim_out, downsample='half')]
113
  dim_in = dim_out
114
-
115
- blocks += [nn.LeakyReLU(0.2)]
116
- blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
117
-
118
- # blocks += [nn.AdaptiveAvgPool2d(1)] # THIS AVERAGES THE TIME-FRAMES OF SPEAKER STYLE
119
-
120
- blocks += [nn.LeakyReLU(0.2)]
121
  self.shared = nn.Sequential(*blocks)
122
-
123
  self.unshared = nn.Linear(dim_out, style_dim)
124
 
125
  def forward(self, x):
126
- h = self.shared(x) # [bs, 512, 1, 11]
127
-
128
- h = h.mean(3, keepdims=True) # UN COMMENT FOR TIME INVARIANT GLOBAL SPEAKER STYLE
129
- # h = .7 * h + .25 * h.mean(3, keepdims=True)
130
- h = h.transpose(1, 3)
131
- s = self.unshared(h)
132
-
133
-
134
  return s
135
 
136
 
137
  class LinearNorm(torch.nn.Module):
138
- def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
139
- super(LinearNorm, self).__init__()
140
  self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
141
 
142
- torch.nn.init.xavier_uniform_(
143
- self.linear_layer.weight,
144
- gain=torch.nn.init.calculate_gain(w_init_gain))
145
-
146
  def forward(self, x):
147
  return self.linear_layer(x)
148
 
149
 
150
- class ResBlk1d(nn.Module):
151
- def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
152
- normalize=False, downsample='none', dropout_p=0.2):
153
- super().__init__()
154
- self.actv = actv
155
- self.normalize = normalize
156
- self.downsample_type = downsample
157
- self.learned_sc = dim_in != dim_out
158
- self._build_weights(dim_in, dim_out)
159
- self.dropout_p = dropout_p
160
-
161
- if self.downsample_type == 'none':
162
- self.pool = nn.Identity()
163
- else:
164
- self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
165
-
166
- def _build_weights(self, dim_in, dim_out):
167
- self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
168
- self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
169
- if self.normalize:
170
- self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
171
- self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
172
- if self.learned_sc:
173
- self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
174
-
175
- def downsample(self, x):
176
- if self.downsample_type == 'none':
177
- return x
178
- else:
179
- if x.shape[-1] % 2 != 0:
180
- x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
181
- return F.avg_pool1d(x, 2)
182
-
183
- def _shortcut(self, x):
184
- if self.learned_sc:
185
- x = self.conv1x1(x)
186
- x = self.downsample(x)
187
- return x
188
-
189
- def _residual(self, x):
190
- if self.normalize:
191
- x = self.norm1(x)
192
- x = self.actv(x)
193
- x = F.dropout(x, p=self.dropout_p, training=self.training)
194
-
195
- x = self.conv1(x)
196
- x = self.pool(x)
197
- if self.normalize:
198
- x = self.norm2(x)
199
-
200
- x = self.actv(x)
201
- x = F.dropout(x, p=self.dropout_p, training=self.training)
202
-
203
- x = self.conv2(x)
204
- return x
205
-
206
- def forward(self, x):
207
- x = self._shortcut(x) + self._residual(x)
208
- return x / math.sqrt(2) # unit variance
209
-
210
-
211
  class LayerNorm(nn.Module):
212
  def __init__(self, channels, eps=1e-5):
213
  super().__init__()
@@ -222,168 +155,151 @@ class LayerNorm(nn.Module):
222
  x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
223
  return x.transpose(1, -1)
224
 
 
225
  class TextEncoder(nn.Module):
226
- def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
227
  super().__init__()
228
  self.embedding = nn.Embedding(n_symbols, channels)
229
-
230
  padding = (kernel_size - 1) // 2
231
  self.cnn = nn.ModuleList()
232
  for _ in range(depth):
233
  self.cnn.append(nn.Sequential(
234
  weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
235
  LayerNorm(channels),
236
- actv,
237
- nn.Dropout(0.2),
238
- ))
239
- # self.cnn = nn.Sequential(*self.cnn)
240
-
241
- self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
242
 
243
- def forward(self, x, input_lengths):
244
  x = self.embedding(x) # [B, T, emb]
245
- x = x.transpose(1, 2) # [B, emb, T]
246
  for c in self.cnn:
247
- x = c(x)
248
- x = x.transpose(1, 2) # [B, T, chn]
249
- input_lengths = input_lengths.cpu().numpy()
250
- x = nn.utils.rnn.pack_padded_sequence(
251
- x, input_lengths,
252
- batch_first=True,
253
- enforce_sorted=False)
254
- self.lstm.flatten_parameters()
255
  x, _ = self.lstm(x)
256
- x, _ = nn.utils.rnn.pad_packed_sequence(
257
- x, batch_first=True)
258
- x = x.transpose(-1, -2)
259
  return x
260
-
 
261
  class AdaLayerNorm(nn.Module):
262
-
263
- # only instantianted in DurationPredictor()
264
-
265
  def __init__(self, style_dim, channels=None, eps=1e-5):
266
  super().__init__()
267
  self.eps = eps
268
  self.fc = nn.Linear(style_dim, 1024)
269
 
270
  def forward(self, x, s):
271
- h = self.fc(s.transpose(1, 2)) # has to be transposed due to interpolate needing the last dim to be frames
272
  gamma = h[:, :, :512]
273
  beta = h[:, :, 512:1024]
274
-
275
- x = F.layer_norm(x.transpose(1, 2), (512, ), eps=self.eps)
276
  x = (1 + gamma) * x + beta
277
  return x # [1, 75, 512]
278
 
 
279
  class ProsodyPredictor(nn.Module):
280
 
281
- def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
282
- super().__init__()
283
-
284
- self.text_encoder = DurationEncoder(sty_dim=style_dim,
285
- d_model=d_hid,
286
- nlayers=nlayers,
287
- dropout=dropout)
288
 
289
- self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
 
 
 
 
290
  self.duration_proj = LinearNorm(d_hid, max_dur)
291
-
292
- self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
293
- self.F0 = nn.ModuleList()
294
- self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
295
- self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
296
- self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
297
-
298
- self.N = nn.ModuleList()
299
- self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
300
- self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
301
- self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
302
-
303
  self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
304
  self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
305
-
306
  def F0Ntrain(self, x, s):
307
 
308
- x, _ = self.shared(x.transpose(1, 2)) # [bs, time, ch] LSTM
309
 
310
  x = x.transpose(1, 2) # [bs, ch, time]
311
-
312
-
313
  F0 = x
314
-
315
  for block in self.F0:
316
  # print(f'LOOP {F0.shape=} {s.shape=}\n')
317
  # )N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
318
- F0 = block(F0, s) # This is an AdainResBlk1d expects conv1d dimensions
 
319
  F0 = self.F0_proj(F0)
320
-
321
  N = x
322
-
323
  for block in self.N:
324
  N = block(N, s)
325
  N = self.N_proj(N)
326
-
327
  return F0, N
 
 
 
329
  class DurationEncoder(nn.Module):
330
 
331
- def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
332
  super().__init__()
333
  self.lstms = nn.ModuleList()
334
  for _ in range(nlayers):
335
- self.lstms.append(nn.LSTM(d_model + sty_dim,
336
- d_model // 2,
337
- num_layers=1,
338
- batch_first=True,
339
- bidirectional=True,
340
- dropout=dropout))
341
  self.lstms.append(AdaLayerNorm(sty_dim, d_model))
342
-
343
-
344
- self.dropout = dropout
345
- self.d_model = d_model
346
- self.sty_dim = sty_dim
347
 
348
- def forward(self, x, style, text_lengths):
349
 
350
- # style = style[:, :, 0, :].transpose(2, 1) # [bs, 128, 11]
351
-
352
- style = _tile(style, length=x.shape[2]) # replicate style vector to duration of txt - F.interpolate or cyclic/tile
353
 
354
- x = torch.cat([x, style], axis=1) # [bs, 640, 75]
 
 
 
355
 
356
- input_lengths = text_lengths.cpu().numpy()
357
-
358
  for block in self.lstms:
359
  if isinstance(block, AdaLayerNorm):
360
- # not LST enters here
361
- x = block(x, style) # [bs, 75, 512]
362
- x = torch.cat([x.transpose(1, 2), style], axis=1) # [bs, 512, 75]
363
 
364
  else:
365
- # print(f'{x.shape=} ENTER LSTM') # [bs, 640, 75] LSTM reduce ch 640 -> 512
366
- x = x.transpose(-1, -2)
367
- x = nn.utils.rnn.pack_padded_sequence(
368
- x, input_lengths, batch_first=True, enforce_sorted=False)
369
- block.flatten_parameters()
370
- x, _ = block(x)
371
- x, _ = nn.utils.rnn.pad_packed_sequence(
372
- x, batch_first=True)
373
- x = F.dropout(x, p=self.dropout, training=self.training)
374
- x = x.transpose(-1, -2)
375
- return x.transpose(-1, -2)
376
 
377
-
378
-
379
-
380
- def load_F0_models(path):
381
- # load F0 model
382
-
383
- F0_model = JDCNet(num_class=1, seq_len=192)
384
- path = path.replace('.t7', '.pth')
385
- params = torch.load(path, map_location='cpu')['net']
386
- F0_model.load_state_dict(params)
387
- _ = F0_model.train()
388
-
389
- return F0_model
 
1
+ # coding:utf-8
2
 
3
  import os
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
+ from torch.nn.utils import spectral_norm
8
+ from torch.nn.utils.parametrizations import weight_norm
9
  # from Utils.ASR.models import ASRCNN
10
+ # from Utils.JDC.model import JDCNet
11
  from Modules.hifigan import _tile, AdainResBlk1d
12
+ import math
13
 
14
+ class MelSpec(torch.nn.Module):
15
 
16
+ def __init__(self,
17
+ sample_rate=17402, # default is 16000 in https://github.com/fakerybakery/styletts2-cli/blob/main/msinference.py; ~17400 vocalises better, also for "en_US/vctk_p274"
18
+ n_fft=2048,
19
+ win_length=1200,
20
+ hop_length=300,
21
+ n_mels=80
22
+ ):
23
+ '''Mel spectrogram computed without a torchaudio dependency.'''
24
  super().__init__()
25
+ self.n_fft = n_fft
26
+ self.win_length = win_length if win_length is not None else n_fft
27
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
28
+ # --
29
+ f_min = 0.0
30
+ f_max = float(sample_rate // 2)
31
+ all_freqs = torch.linspace(0, sample_rate // 2, n_fft//2+1)
32
+ m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
33
+ m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
34
+ m_pts = torch.linspace(m_min, m_max, n_mels + 2)
35
+ f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
36
+ f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
37
+ slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)
38
+ zero = torch.zeros(1)
39
+ down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
40
+ up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
41
+ fb = torch.max(zero, torch.min(down_slopes, up_slopes))
42
+ # --
43
+ self.register_buffer('fb', fb)
44
+ window = torch.hann_window(self.win_length)
45
+ self.register_buffer('window', window)
46
 
47
  def forward(self, x):
48
+ spec_f = torch.stft(x,
49
+ self.n_fft,
50
+ self.hop_length,
51
+ self.win_length,
52
+ self.window,
53
+ center=True,
54
+ pad_mode="reflect",
55
+ normalized=False,
56
+ onesided=True,
57
+ return_complex=True) # [bs, 1025, 56]
58
+ mel_specgram = torch.matmul(spec_f.abs().pow(2).transpose(1, 2), self.fb).transpose(1, 2)
59
+ return mel_specgram[:, None, :, :] # [bs, 1, 80, time]
60
 
61
 
62
+ class LearnedDownSample(nn.Module):
63
+ def __init__(self, dim_in):
64
  super().__init__()
65
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(
66
+ 3, 3), stride=(2, 2), groups=dim_in, padding=1))
67
+
68
  def forward(self, x):
69
+ return self.conv(x)
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
 
72
  class ResBlk(nn.Module):
73
+ def __init__(self,
74
+ dim_in, dim_out):
75
  super().__init__()
76
+ self.actv = nn.LeakyReLU(0.2) # 0.07 also sounds nice
77
+ self.downsample_res = LearnedDownSample(dim_in)
 
 
78
  self.learned_sc = dim_in != dim_out
 
 
 
79
  self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
80
  self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
 
 
 
81
  if self.learned_sc:
82
+ self.conv1x1 = spectral_norm(
83
+ nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
84
 
85
  def _shortcut(self, x):
86
  if self.learned_sc:
87
  x = self.conv1x1(x)
88
+ if x.shape[3] % 2 != 0: # [bs, 128, Freq, Time]
89
+ x = torch.cat([x, x[:, :, :, -1:]], dim=3)
90
+ return F.interpolate(x, scale_factor=.5, mode='nearest-exact') # alternative: F.avg_pool2d(x, 2)
91
 
92
  def _residual(self, x):
 
 
93
  x = self.actv(x)
94
  x = self.conv1(x)
95
  x = self.downsample_res(x)
 
 
96
  x = self.actv(x)
97
  x = self.conv2(x)
98
  return x
 
106
 
107
  # for both acoustic & prosodic ref_s/p
108
 
109
+ def __init__(self,
110
+ dim_in=64,
111
+ style_dim=128,
112
+ max_conv_dim=512):
113
  super().__init__()
114
+ blocks = [spectral_norm(nn.Conv2d(1, dim_in, 3, stride=1, padding=1))]
115
+ for _ in range(4):
116
+ dim_out = min(dim_in * 2,
117
+ max_conv_dim)
118
+ blocks += [ResBlk(dim_in, dim_out)]
 
 
119
  dim_in = dim_out
120
+ blocks += [nn.LeakyReLU(0.24), # without this activation no speech is produced
121
+ spectral_norm(nn.Conv2d(dim_out, dim_out, 5, stride=1, padding=0)),
122
+ nn.LeakyReLU(0.2) # 0.3 sounds nice
123
+ ]
 
 
 
124
  self.shared = nn.Sequential(*blocks)
 
125
  self.unshared = nn.Linear(dim_out, style_dim)
126
 
127
  def forward(self, x):
128
+ x = self.shared(x)
129
+ x = x.mean(3, keepdims=True) # mean over time; comment this line out for a time-varying style vector
130
+ x = x.transpose(1, 3)
131
+ s = self.unshared(x)
 
 
 
 
132
  return s
133
 
134
 
135
  class LinearNorm(torch.nn.Module):
136
+ def __init__(self, in_dim, out_dim, bias=True):
137
+ super().__init__()
138
  self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
139
 
 
 
 
 
140
  def forward(self, x):
141
  return self.linear_layer(x)
142
 
143
 
144
  class LayerNorm(nn.Module):
145
  def __init__(self, channels, eps=1e-5):
146
  super().__init__()
 
155
  x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
156
  return x.transpose(1, -1)
157
 
158
+
159
  class TextEncoder(nn.Module):
160
+ def __init__(self, channels, kernel_size, depth, n_symbols):
161
  super().__init__()
162
  self.embedding = nn.Embedding(n_symbols, channels)
 
163
  padding = (kernel_size - 1) // 2
164
  self.cnn = nn.ModuleList()
165
  for _ in range(depth):
166
  self.cnn.append(nn.Sequential(
167
  weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
168
  LayerNorm(channels),
169
+ nn.LeakyReLU(0.24))
170
+ )
171
+ self.lstm = nn.LSTM(channels, channels//2, 1,
172
+ batch_first=True, bidirectional=True)
 
 
173
 
174
+ def forward(self, x):
175
  x = self.embedding(x) # [B, T, emb]
176
+ x = x.transpose(1, 2)
177
  for c in self.cnn:
178
+ x = c(x)
179
+ x = x.transpose(1, 2)
 
 
 
 
 
 
180
  x, _ = self.lstm(x)
 
 
 
181
  return x
182
+
183
+
184
  class AdaLayerNorm(nn.Module):
185
+
 
 
186
  def __init__(self, style_dim, channels=None, eps=1e-5):
187
  super().__init__()
188
  self.eps = eps
189
  self.fc = nn.Linear(style_dim, 1024)
190
 
191
  def forward(self, x, s):
192
+ h = self.fc(s)
193
  gamma = h[:, :, :512]
194
  beta = h[:, :, 512:1024]
195
+ x = F.layer_norm(x, (512, ), eps=self.eps)
 
196
  x = (1 + gamma) * x + beta
197
  return x # [1, 75, 512]
198
 
199
+
200
  class ProsodyPredictor(nn.Module):
201
 
202
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50):
203
+ super().__init__()
 
 
 
 
 
204
 
205
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
206
+ d_model=d_hid,
207
+ nlayers=nlayers) # called outside forward
208
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2,
209
+ 1, batch_first=True, bidirectional=True)
210
  self.duration_proj = LinearNorm(d_hid, max_dur)
211
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid //
212
+ 2, 1, batch_first=True, bidirectional=True)
213
+ self.F0 = nn.ModuleList([
214
+ AdainResBlk1d(d_hid, d_hid, style_dim),
215
+ AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True),
216
+ AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim),
217
+ ])
218
+ self.N = nn.ModuleList([
219
+ AdainResBlk1d(d_hid, d_hid, style_dim),
220
+ AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True),
221
+ AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim)
222
+ ])
223
  self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
224
  self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
225
+
226
  def F0Ntrain(self, x, s):
227
 
228
+ x, _ = self.shared(x) # [bs, time, ch] LSTM
229
 
230
  x = x.transpose(1, 2) # [bs, ch, time]
231
+
 
232
  F0 = x
233
+
234
  for block in self.F0:
235
  # print(f'LOOP {F0.shape=} {s.shape=}\n')
236
  # )N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
237
+ # AdainResBlk1d expects conv1d dimensions [bs, ch, time]
238
+ F0 = block(F0, s)
239
  F0 = self.F0_proj(F0)
240
+
241
  N = x
242
+
243
  for block in self.N:
244
  N = block(N, s)
245
  N = self.N_proj(N)
246
+
247
  return F0, N
248
+
249
+ def forward(self, d_en=None, s=None):
250
+ blend = self.text_encoder(d_en, s)
251
+ x, _ = self.lstm(blend)
252
+ dur = self.duration_proj(x) # [bs, 150, 50]
253
+
254
+ _, input_length, classifier_50 = dur.shape
255
+
256
+ dur = dur[0, :, :]
257
+ dur = torch.sigmoid(dur).sum(1)
258
+ dur = dur.round().clamp(min=1).to(torch.int64)
259
+ aln_trg = torch.zeros(1,
260
+ dur.sum(),
261
+ input_length,
262
+ device=s.device)
263
+ c_frame = 0
264
+ for i in range(input_length):
265
+ aln_trg[:, c_frame:c_frame + dur[i], i] = 1
266
+ c_frame += dur[i]
267
+ en = torch.bmm(aln_trg, blend)
268
+ F0_pred, N_pred = self.F0Ntrain(en, s)
269
+ return aln_trg, F0_pred, N_pred
270
+
271
 
272
  class DurationEncoder(nn.Module):
273
 
274
+ def __init__(self, sty_dim=128, d_model=512, nlayers=3):
275
  super().__init__()
276
  self.lstms = nn.ModuleList()
277
  for _ in range(nlayers):
278
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
279
+ d_model // 2,
280
+ num_layers=1,
281
+ batch_first=True,
282
+ bidirectional=True
283
+ ))
284
  self.lstms.append(AdaLayerNorm(sty_dim, d_model))
 
 
 
 
 
285
 
 
286
 
287
+ def forward(self, x, style):
 
 
288
 
289
+ _, _, input_lengths = x.shape # [bs, 512, time]
290
+
291
+ style = _tile(style, length=x.shape[2]).transpose(1, 2)
292
+ x = x.transpose(1, 2)
293
 
 
 
294
  for block in self.lstms:
295
  if isinstance(block, AdaLayerNorm):
296
+
297
+ x = block(x, style) # AdaLayerNorm; the preceding LSTM already left x as [bs, time, ch]
 
298
 
299
  else:
300
+ x = torch.cat([x, style], axis=2)
301
+ # LSTM
 
 
 
 
 
 
 
 
 
302
 
303
+ x, _ = block(x) # LSTM: expects [bs, time, ch], outputs [bs, time, 2*hidden] (bidirectional)
304
+
305
+ return torch.cat([x, style], axis=2) # consumed by predictor.lstm()
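DurationEncoder.forward above tiles the style vector along the time axis before concatenating it to the text features. A sketch of that blend (standalone; tile_style is a hypothetical stand-in for the _tile helper imported from Modules.hifigan, whose implementation is not shown in this diff):

    import torch

    def tile_style(style, length):
        # hypothetical stand-in for _tile(): repeat the style along time to `length` frames
        reps = -(-length // style.shape[-1])      # ceil division
        return style.repeat(1, 1, reps)[:, :, :length]

    x = torch.randn(1, 512, 75)                   # text features [bs, d_model, time]
    style = torch.randn(1, 128, 1)                # style vector [bs, sty_dim, 1]
    blend = torch.cat([x.transpose(1, 2),
                       tile_style(style, x.shape[2]).transpose(1, 2)], dim=2)
    print(blend.shape)                            # torch.Size([1, 75, 640]) = d_model + sty_dim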
 
 
 
 
 
 
 
 
 
 
msinference.py CHANGED
@@ -3,25 +3,28 @@ import sys
3
  import tempfile
4
  import re
5
  import os
6
- from num2words import num2words
7
  from collections import OrderedDict
8
  from Modules.hifigan import Decoder
9
  from Utils.PLBERT.util import load_plbert
10
  import phonemizer
11
  import torch
12
  from cached_path import cached_path
13
- # import nltk
14
  import audresample
15
- # nltk.download('punkt')
 
 
16
  import numpy as np
17
  import yaml
18
- import torchaudio
19
  import librosa
20
- from models import ProsodyPredictor, TextEncoder, StyleEncoder, load_F0_models
21
  from nltk.tokenize import word_tokenize
22
  from Utils.text_utils import transliterate_number
23
  import textwrap
24
- # IPA Phonemizer: https://github.com/bootphon/phonemizer
 
 
 
25
 
26
  _pad = "$"
27
  _punctuation = ';:,.!?¡¿—…"«»“” '
@@ -56,45 +59,33 @@ class TextCleaner:
56
 
57
  textclenaer = TextCleaner()
58
 
59
-
60
- to_mel = torchaudio.transforms.MelSpectrogram(
61
- n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
62
- mean, std = -4, 4
63
-
64
-
65
  def alpha_num(f):
66
  f = re.sub(' +', ' ', f) # delete spaces
67
  f = re.sub(r'[^A-Z a-z0-9 ]+', '', f) # del non alpha num
68
  return f
69
 
70
-
71
- def preprocess(wave):
72
- wave_tensor = torch.from_numpy(wave).float()
73
- mel_tensor = to_mel(wave_tensor)
74
- mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
75
- return mel_tensor
76
-
77
 
78
  def compute_style(path):
79
- wave, sr = librosa.load(path, sr=24000)
80
- audio, index = librosa.effects.trim(wave, top_db=30)
81
  if sr != 24000:
82
- audio = librosa.resample(audio, sr, 24000)
83
- mel_tensor = preprocess(audio).to(device)
84
-
85
  with torch.no_grad():
86
- ref_s = style_encoder(mel_tensor.unsqueeze(1))
87
- ref_p = predictor_encoder(mel_tensor.unsqueeze(1)) # [bs, 11, 1, 128]
 
 
 
88
 
89
- s = torch.cat([ref_s, ref_p], dim=3) # [bs, 11, 1, 256]
90
-
91
- s = s[:, :, 0, :].transpose(1, 2) # [1, 128, 11]
92
- return s # [1, 128, 11]
93
 
 
94
 
95
- device = 'cpu'
96
- if torch.cuda.is_available():
97
- device = 'cuda'
98
 
99
  global_phonemizer = phonemizer.backend.EspeakBackend(
100
  language='en-us', preserve_punctuation=True, with_stress=True)
@@ -104,10 +95,6 @@ global_phonemizer = phonemizer.backend.EspeakBackend(
104
  args = yaml.safe_load(open(str('Utils/config.yml')))
105
  ASR_config = args['ASR_config']
106
 
107
- F0_path = args['F0_path']
108
- pitch_extractor = load_F0_models(F0_path).eval().to(device)
109
-
110
-
111
  bert = load_plbert(args['PLBERT_dir']).eval().to(device)
112
 
113
  decoder = Decoder(dim_in=512,
@@ -128,8 +115,7 @@ text_encoder = TextEncoder(channels=512,
128
  predictor = ProsodyPredictor(style_dim=128,
129
  d_hid=512,
130
  nlayers=3, # OFFICIAL config.nlayers=5;
131
- max_dur=50,
132
- dropout=.2).eval().to(device)
133
 
134
  style_encoder = StyleEncoder(dim_in=64,
135
  style_dim=128,
@@ -141,9 +127,10 @@ bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)
141
 
142
  # params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
143
  params_whole = torch.load(str(cached_path(
144
- "hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
145
  params = params_whole['net']
146
-
 
147
 
148
  def _del_prefix(d):
149
  # del ".module"
@@ -163,95 +150,41 @@ predictor_encoder.load_state_dict(_del_prefix(
163
  params['predictor_encoder']), strict=True)
164
  style_encoder.load_state_dict(_del_prefix(
165
  params['style_encoder']), strict=True)
166
- pitch_extractor.load_state_dict(_del_prefix(
167
- params['pitch_extractor']), strict=True)
168
-
169
- # def _shift(x):
170
- # # [bs, samples] shift circular each batch elem of sound
171
- # n = x.shape[1]
172
- # for i, batch_elem in enumerate(x):
173
- # offset = np.random.randint(.24 * n, max(1, .74 * n)) # high should be above >= 0 TBD
174
- # x[i, ...] = torch.roll(batch_elem, offset, dims=1) # batch_elem = [400000, ]
175
- # return x
176
-
177
 
178
  def inference(text,
179
- ref_s,
180
- use_gruut=False):
181
-
182
- text = transliterate_number(text, lang='en').strip()
183
-
184
  ps = global_phonemizer.phonemize([text])
185
- # print(f'PHONEMIZER: {ps=}\n\n') #PHONEMIZER: ps=['ɐbˈɛbæbləm ']
186
  ps = word_tokenize(ps[0])
187
- # # print(f'TOKENIZER: {ps=}\n\n') #OKENIZER: ps=['ɐbˈɛbæbləm']
188
  ps = ' '.join(ps)
189
  tokens = textclenaer(ps)
190
- # print(f'TEXTCLEAN: {ps=}\n\n') #TEXTCLEAN: ps='ɐbˈɛbæbləm'
191
  tokens.insert(0, 0)
192
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
193
- # print(f'TOKENSFINAL: {ps=}\n\n')
194
-
195
  with torch.no_grad():
196
- input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
197
-
198
- hidden_states = text_encoder(tokens, input_lengths)
199
-
200
- bert_dur = bert(tokens, attention_mask=None)
201
  d_en = bert_encoder(bert_dur).transpose(-1, -2)
202
- ref = ref_s[:, :128, :] # [bs, 128, 11]
203
- s = ref_s[:, 128:, :]
204
- d = predictor.text_encoder(d_en, s, input_lengths)
205
- d = d.transpose(1, 2)
206
- # -------------------------------- pred_aln_trg = clones bert frames as duration
207
-
208
- d = predictor.text_encoder(d_en,
209
- s,
210
- input_lengths)
211
-
212
- x, _ = predictor.lstm(d)
213
-
214
- duration = predictor.duration_proj(x)
215
-
216
- duration = torch.sigmoid(duration).sum(axis=-1)
217
- pred_dur = torch.round(duration.squeeze()).clamp(min=1)
218
 
219
- pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
220
- c_frame = 0
221
- for i in range(pred_aln_trg.size(0)):
222
- pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
223
- c_frame += int(pred_dur[i].data)
224
 
225
- en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
226
- asr_new = torch.zeros_like(en)
227
- asr_new[:, :, 0] = en[:, :, 0]
228
- asr_new[:, :, 1:] = en[:, :, 0:-1]
229
- en = asr_new
 
 
230
 
231
- F0_pred, N_pred = predictor.F0Ntrain(en, s)
 
232
 
233
- asr = (hidden_states @ pred_aln_trg.unsqueeze(0).to(device))
234
-
235
- asr_new = torch.zeros_like(asr)
236
- asr_new[:, :, 0] = asr[:, :, 0]
237
- asr_new[:, :, 1:] = asr[:, :, 0:-1]
238
- asr = asr_new
239
- # -
240
-
241
- x = decoder(asr=asr,
242
- F0_curve=F0_pred,
243
- N=N_pred,
244
- s=ref)
245
-
246
- x = x.cpu().numpy()[0, 0, :-400] # weird pulse at the end of sentences
247
-
248
- # StyleTTS2 is 24kHz -> Resample to 16kHz ofAudioGen / MMS
249
 
250
  if x.shape[0] > 10:
251
- x /= np.abs(x).max() + 1e-7
252
  x = audresample.resample(signal=x.astype(np.float32),
253
  original_rate=24000,
254
- target_rate=16000)[0, :] # reshapes (64,) -> (1,64)
255
 
256
  else:
257
  print('\n\n\n\n\nEMPTY TTS\n\n\n\n\n\nn', x.shape)
@@ -346,17 +279,14 @@ def foreign(text=None, # split sentences here so we can prepend a txt for germ
346
  elif 'rom' in lang:
347
 
348
  lang_code = 'ron'
349
- speed = 1.24 if speed is None else speed
350
 
351
- elif 'ger' in lang:
352
 
353
  lang_code = 'deu'
354
- speed = 1.14 if speed is None else speed
355
 
356
  elif 'alban' in lang:
357
 
358
  lang_code = 'sqi'
359
- speed = 1.04 if speed is None else speed
360
 
361
  else:
362
 
@@ -364,38 +294,38 @@ def foreign(text=None, # split sentences here so we can prepend a txt for germ
364
 
365
  # load VITS
366
 
367
- net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
368
- tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
 
369
 
 
 
 
 
 
 
 
370
 
371
 
372
  total_audio = []
373
 
374
  # Split long sentences if deu to control voice switch - for other languages let text no-split
375
  if not isinstance(text, list):
376
- if lang_code == 'deu':
377
- # Split Very long sentences >500 phoneme - StyleTTS2 crashes # -- even 400 phonemes sometimes OOM in cuda:4
378
- # However prosody is nicer on non-split for MMS TTS
379
- # prepend txt snippet
380
- text = [
381
- sub_sent+' ' for sub_sent in textwrap.wrap(text, 200, break_long_words=0)]
382
- # assert that it chooses unique voice
383
- else:
384
- # allow longer non split text
385
- text = [
386
- sub_sent+' ' for sub_sent in textwrap.wrap(text, 640, break_long_words=0)]
387
- # for non deu MMS TTS lang.
388
 
389
  for _t in text:
390
 
391
  _t = _t.lower()
392
 
393
- # apply this in api.py -> tts_multi_sentence before switching between Styletts2
394
- print('\n\n\n\nBEF TRansliteration', _t,'\n\n\n\n\n')
395
- _t = transliterate_number(_t, lang=lang_code)
396
- print('AFT nums', _t,'\n____________________________________________')
397
 
398
- # However if we transliterate here also the demo sees the transliteration
 
 
 
 
 
399
 
400
  if lang_code == 'rmc-script_latin':
401
 
@@ -417,7 +347,7 @@ def foreign(text=None, # split sentences here so we can prepend a txt for germ
417
 
418
  x = net_g(input_ids=inputs.input_ids.to(device),
419
  attention_mask=inputs.attention_mask.to(device),
420
- speed=speed + .44 * np.random.rand() # variable speed for different sentence
421
  )[0, :]
422
 
423
  # crop the 1st audio - is PREFIX text 156000 samples to chose deu voice / VitsAttention()
@@ -428,8 +358,6 @@ def foreign(text=None, # split sentences here so we can prepend a txt for germ
428
 
429
  x = torch.cat(total_audio).cpu().numpy()
430
 
431
- x /= np.abs(x).max() + 1e-7
432
-
433
- # print(x.shape, x.min(), x.max(), hps.data.sampling_rate)
434
 
435
  return x # 16kHz - only resample StyleTTS2 from 24kHz -> 16kHz
 
3
  import tempfile
4
  import re
5
  import os
 
6
  from collections import OrderedDict
7
  from Modules.hifigan import Decoder
8
  from Utils.PLBERT.util import load_plbert
9
  import phonemizer
10
  import torch
11
  from cached_path import cached_path
12
+ import nltk
13
  import audresample
14
+ nltk.download('punkt', download_dir='./') # comment out once downloaded
15
+ nltk.download('punkt_tab', download_dir='./')
16
+ nltk.data.path.append('.')
17
  import numpy as np
18
  import yaml
 
19
  import librosa
20
+ from models import ProsodyPredictor, TextEncoder, StyleEncoder, MelSpec
21
  from nltk.tokenize import word_tokenize
22
  from Utils.text_utils import transliterate_number
23
  import textwrap
24
+
25
+ device = 'cpu'
26
+ if torch.cuda.is_available():
27
+ device = 'cuda'
28
 
29
  _pad = "$"
30
  _punctuation = ';:,.!?¡¿—…"«»“” '
 
59
 
60
  textclenaer = TextCleaner()
61
 
 
 
 
 
 
 
62
  def alpha_num(f):
63
  f = re.sub(' +', ' ', f) # delete spaces
64
  f = re.sub(r'[^A-Z a-z0-9 ]+', '', f) # del non alpha num
65
  return f
66
 
67
+ mel_spec = MelSpec().to(device)
 
 
 
 
 
 
68
 
69
  def compute_style(path):
70
+ x, sr = librosa.load(path, sr=24000)
71
+ x, _ = librosa.effects.trim(x, top_db=30)
72
  if sr != 24000:
73
+ x = librosa.resample(x, orig_sr=sr, target_sr=24000)
74
+
 
75
  with torch.no_grad():
76
+ x = torch.from_numpy(x[None, :]).to(device=device, dtype=torch.float)
77
+
78
+ mel_tensor = (torch.log(1e-5 + mel_spec(x)) + 4) / 4
79
+
80
+ #mel_tensor = preprocess(audio).to(device)
81
 
82
+ ref_s = style_encoder(mel_tensor)
83
+ ref_p = predictor_encoder(mel_tensor) # [bs, 11, 1, 128]
 
 
84
 
85
+ s = torch.cat([ref_s, ref_p], dim=3) # [bs, 11, 1, 256]
86
 
87
+ s = s[:, :, 0, :].transpose(1, 2) # [1, 128, 11]
88
+ return s # [1, 128, 11]
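compute_style above packs the acoustic and prosodic reference embeddings into a single 256-channel tensor; downstream, inference() splits it back at channel 128. A small shape sketch of that convention (toy tensors only):

    import torch

    ref_s = torch.randn(1, 11, 1, 128)            # acoustic reference (toy shape)
    ref_p = torch.randn(1, 11, 1, 128)            # prosodic reference
    s = torch.cat([ref_s, ref_p], dim=3)[:, :, 0, :].transpose(1, 2)   # [1, 256, 11]

    acoustic = s[:, :128, :]                      # passed to the decoder as ref_s[:, :128, :]
    prosodic = s[:, 128:, :]                      # passed to the ProsodyPredictor as ref_s[:, 128:, :]
    print(acoustic.shape, prosodic.shape)         # torch.Size([1, 128, 11]) twice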
 
89
 
90
  global_phonemizer = phonemizer.backend.EspeakBackend(
91
  language='en-us', preserve_punctuation=True, with_stress=True)
 
95
  args = yaml.safe_load(open(str('Utils/config.yml')))
96
  ASR_config = args['ASR_config']
97
 
 
 
 
 
98
  bert = load_plbert(args['PLBERT_dir']).eval().to(device)
99
 
100
  decoder = Decoder(dim_in=512,
 
115
  predictor = ProsodyPredictor(style_dim=128,
116
  d_hid=512,
117
  nlayers=3, # OFFICIAL config.nlayers=5;
118
+ max_dur=50).eval().to(device)
 
119
 
120
  style_encoder = StyleEncoder(dim_in=64,
121
  style_dim=128,
 
127
 
128
  # params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
129
  params_whole = torch.load(str(cached_path(
130
+ "hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu', weights_only=True)
131
  params = params_whole['net']
132
+ #params['decoder'].pop('module.generator.m_source.l_linear.weight')
133
+ #params['decoder'].pop('module.generator.m_source.l_linear.bias') # SourceHNSf
134
 
135
  def _del_prefix(d):
136
  # del ".module"
 
150
  params['predictor_encoder']), strict=True)
151
  style_encoder.load_state_dict(_del_prefix(
152
  params['style_encoder']), strict=True)
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  def inference(text,
155
+ ref_s):
156
+ # text = transliterate_number(text, lang='en').strip() # transliteration is only applied in foreign(); perhaps append an extra '.' after '?' or ';'
 
 
 
157
  ps = global_phonemizer.phonemize([text])
 
158
  ps = word_tokenize(ps[0])
 
159
  ps = ' '.join(ps)
160
  tokens = textclenaer(ps)
 
161
  tokens.insert(0, 0)
162
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
 
 
163
  with torch.no_grad():
164
+ hidden_states = text_encoder(tokens)
165
+ bert_dur = bert(tokens, attention_mask=torch.ones_like(tokens))
 
 
 
166
  d_en = bert_encoder(bert_dur).transpose(-1, -2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ aln_trg, F0_pred, N_pred = predictor(d_en=d_en, s=ref_s[:, 128:, :])
 
 
 
 
169
 
170
+ asr = torch.bmm(aln_trg, hidden_states)
171
+ asr = asr.transpose(1, 2)
172
+ asr = torch.cat([asr[:, :, 0:1], asr[:, :, 0:-1]], 2)
173
+ x = decoder(asr=asr, # [1, 512, 201]
174
+ F0_curve=F0_pred, # [1, 1, 402] 2x time
175
+ N=N_pred, # [1, 1, 402] 2x time
176
+ s=ref_s[:, :128, :]) # [1, 256, 1]
177
 
178
+ x = x.cpu().numpy()[0, 0, :]
179
+ x[-400:] = 0 # zero the noisy pulse produced at the end of unterminated sentences (missing final punctuation); possibly voice-dependent
180
 
181
+ # StyleTTS2 outputs 24 kHz -> resample to 16 kHz to match AudioGen / MMS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  if x.shape[0] > 10:
184
+
185
  x = audresample.resample(signal=x.astype(np.float32),
186
  original_rate=24000,
187
+ target_rate=16000)[0, :] # audresample reshapes (64,) -> (1,64) | Volume Normalisation applies in api.py:tts_multi_sentence()
188
 
189
  else:
190
  print('\n\n\n\n\nEMPTY TTS\n\n\n\n\n\nn', x.shape)
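The torch.cat over the last dimension in inference() above shifts the aligned hidden states one frame to the right, repeating the first frame and dropping the last. A toy illustration:

    import torch

    asr = torch.arange(1., 5.).view(1, 1, 4)                       # toy frames [1, 2, 3, 4]
    shifted = torch.cat([asr[:, :, 0:1], asr[:, :, 0:-1]], dim=2)
    print(shifted)                                                 # tensor([[[1., 1., 2., 3.]]])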
 
279
  elif 'rom' in lang:
280
 
281
  lang_code = 'ron'
 
282
 
283
+ elif 'ger' in lang or 'deu' in lang or 'allem' in lang:
284
 
285
  lang_code = 'deu'
 
286
 
287
  elif 'alban' in lang:
288
 
289
  lang_code = 'sqi'
 
290
 
291
  else:
292
 
 
294
 
295
  # load VITS
296
 
297
+ # net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
298
+ # tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
299
+ global cached_lang_code, cached_net_g, cached_tokenizer
300
 
301
+ if 'cached_lang_code' not in globals() or cached_lang_code != lang_code:
302
+ cached_lang_code = lang_code
303
+ cached_net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
304
+ cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
305
+
306
+ net_g = cached_net_g
307
+ tokenizer = cached_tokenizer
308
 
309
 
310
  total_audio = []
311
 
312
  # Split long sentences if deu to control voice switch - for other languages let text no-split
313
  if not isinstance(text, list):
314
+ # Split very long sentences (~440 characters per chunk)
315
+ text = [sub_sent+' ' for sub_sent in textwrap.wrap(text, 440, break_long_words=0)]
 
 
 
 
 
 
 
 
 
 
316
 
317
  for _t in text:
318
 
319
  _t = _t.lower()
320
 
321
+ # NUMBERS
 
 
 
322
 
323
+ try:
324
+ _t = transliterate_number(_t, lang=lang_code)
325
+ except NotImplementedError:
326
+ print(f'Transliterate Numbers - NotImplemented for {lang_code=}', _t, '\n____________________________________________')
327
+
328
+ # PRONUNCIATION
329
 
330
  if lang_code == 'rmc-script_latin':
331
 
 
347
 
348
  x = net_g(input_ids=inputs.input_ids.to(device),
349
  attention_mask=inputs.attention_mask.to(device),
350
+ lang_code=lang_code,
351
  )[0, :]
352
 
353
  # crop the 1st audio - is PREFIX text 156000 samples to chose deu voice / VitsAttention()
 
358
 
359
  x = torch.cat(total_audio).cpu().numpy()
360
 
361
+ # x /= np.abs(x).max() + 1e-7 # volume normalisation is applied in api.py:tts_multi_sentence() or demo.py
 
 
362
 
363
  return x # 16kHz - only resample StyleTTS2 from 24kHz -> 16kHz
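The globals-based cache added to foreign() above avoids re-loading the MMS VITS weights when consecutive calls use the same language. The same idea can be expressed with functools.lru_cache; the sketch below uses a hypothetical stand-in loader rather than the real VitsModel / VitsTokenizer calls:

    from functools import lru_cache

    def _load_pair(lang_code):
        # hypothetical stand-in for the VitsModel / VitsTokenizer loads above
        return f'net_g[{lang_code}]', f'tokenizer[{lang_code}]'

    @lru_cache(maxsize=2)                         # keep the two most recently used languages resident
    def load_vits(lang_code):
        return _load_pair(lang_code)

    print(load_vits('deu') is load_vits('deu'))   # True: the second call is served from the cache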
requirements.txt CHANGED
@@ -18,4 +18,4 @@ srt
18
  nltk
19
  phonemizer
20
  docx
21
- torchaudio
 
18
  nltk
19
  phonemizer
20
  docx
21
+ uroman
tts.py CHANGED
@@ -91,13 +91,13 @@ def command_line_args():
91
 
92
  def send_to_server(args):
93
  url = "http://192.168.88.209:5000"
94
-
95
  # Args
96
 
97
  payload = {
98
  'affective': args.affective,
99
  'voice': args.voice,
100
- 'soundscape': args.soundscape,
101
  'native': args.native,
102
  'text': args.text,
103
  'image': args.image,
 
91
 
92
  def send_to_server(args):
93
  url = "http://192.168.88.209:5000"
94
+
95
  # Args
96
 
97
  payload = {
98
  'affective': args.affective,
99
  'voice': args.voice,
100
+ 'soundscape': args.soundscape if args.soundscape != '' else None,
101
  'native': args.native,
102
  'text': args.text,
103
  'image': args.image,