oscillate vits duration
Browse files
- Modules/hifigan.py +29 -82
- Modules/vits/models.py +121 -862
- README.md +1 -1
- Utils/JDC/__init__.py +0 -1
- Utils/JDC/bst.pth +0 -3
- Utils/JDC/model.py +0 -190
- Utils/PLBERT/util.py +1 -1
- Utils/text_utils.py +101 -61
- api.py +6 -16
- audiobook.py +14 -27
- demo.py +15 -43
- live_demo.py +5 -7
- models.py +167 -251
- msinference.py +67 -139
- requirements.txt +1 -1
- tts.py +2 -2
Modules/hifigan.py
CHANGED
@@ -2,12 +2,12 @@ import torch
 import torch.nn.functional as F
 import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 import math
 import numpy as np
 
 
-
+
 
 
 def get_padding(kernel_size, dilation=1):
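
The only functional change in this hunk is the import: `torch.nn.utils.weight_norm` was deprecated (PyTorch 2.1) in favor of the parametrization-based `torch.nn.utils.parametrizations.weight_norm`, which is also why the old `remove_weight_norm`/`apply_weight_norm` helpers get deleted throughout this commit. A minimal sketch of the new API (module shapes here are illustrative):

    import torch.nn as nn
    from torch.nn.utils import parametrize
    from torch.nn.utils.parametrizations import weight_norm

    # weight is reparametrized as g * v / ||v|| and exposed via module.parametrizations
    conv = weight_norm(nn.Conv1d(16, 16, kernel_size=3, padding=1))

    # undoing it now goes through parametrize, not remove_weight_norm
    parametrize.remove_parametrizations(conv, "weight")
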
@@ -93,80 +93,38 @@ class AdaINResBlock1(torch.nn.Module):
         x = xt + x
         return x
 
-    def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
 
 
-
-
-
-
-
-
-        super(SineGen, self).__init__()
-        self.harmonic_num = harmonic_num
-        self.sampling_rate = samp_rate
-        self.voiced_threshold = voiced_threshold
-        self.upsample_scale = upsample_scale
-
-    def _f02sine(self, f0_values):
-        # --
-        # 134 HIFI
-        # torch.Size([1, 145200, 9])
-        # torch.Size([1, 145200, 9]) torch.Size([1, 145200, 9]) HIFi
+class SourceModuleHnNSF(torch.nn.Module):
+
+    def __init__(self):
+
+        super().__init__()
+        self.harmonic_num = 8
+        self.l_linear = torch.nn.Linear(self.harmonic_num + 1, 1)
+        self.upsample_scale = 300
+
+    def forward(self, x):
+        # --
+        x = torch.multiply(x, torch.FloatTensor(
+            [[range(1, self.harmonic_num + 2)]]).to(x.device))  # [1, 145200, 9]
+
         # modulo of negative f0_values => -21 % 10 = 9 as -3*10 + 9 = 21 NOTICE THAT f0_values IS SIGNED
-        rad_values =
-
-        rad_values =
+        rad_values = x / 25647  # ).clamp(0, 1)
+        # rad_values = torch.where(torch.logical_or(rad_values < 0, rad_values > 1), 0.5, rad_values)
+        rad_values = rad_values % 1  # % of neg values
+        rad_values = F.interpolate(rad_values.transpose(1, 2),
                                    scale_factor=1/self.upsample_scale,
-                                   mode=
+                                   mode='linear').transpose(1, 2)
 
         # 1.89 sounds also nice has woofer at punctuation
         phase = torch.cumsum(rad_values, dim=1) * 1.84 * np.pi
-        phase =
-
-
-
-        # print('____________________________________\nF0 F0\n', f0.abs().mean(), f0.mean(), f0.max(), f0.min()) # male voices sound less muffed via higher scaler in sine_waves
-        # f0 is already full length - [1, 142600, 1]
-
-        amplif = .0104 if f0.abs().mean() < 100 else .009  # vary amplif based on f0.abs().mean() - voice sensitive
-
-        fn = torch.multiply(f0, torch.FloatTensor(
-            [[range(1, self.harmonic_num + 2)]]).to(f0.device))  # [1, 145200, 9]
-
-        # .007 # very important effect DEFAULT=0.1 very sensitive to speaker - heuristically
-        sine_waves = self._f02sine(fn) * amplif  # .009
-
-        uv = (f0 > self.voiced_threshold).type(torch.float32)
-
-        return sine_waves * uv
-
-
-class SourceModuleHnNSF(torch.nn.Module):
-
-    def __init__(self, harmonic_num=8):
-
-        super(SourceModuleHnNSF, self).__init__()
-        self.l_sin_gen = SineGen()
-        # harmonic=8 is hard fixed due to this nn.Linear()
-        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
-        self.l_tanh = torch.nn.Tanh()
-
-    def forward(self, x):
-        # print(' HNnSF', x.shape) # why this is [1, 300, 1, 535800]
-        sine_wavs = self.l_sin_gen(x)
-        # This linear sums all 9 harmonics
-        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
-        return sine_merge
+        phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+                              scale_factor=self.upsample_scale, mode='linear').transpose(1, 2)
+        x = .009 * phase.sin()
+        # --
+        x = self.l_linear(x).tanh()
+        return x
 
 
 class Generator(torch.nn.Module):
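
This hunk folds the deleted SineGen/SourceModuleHnNSF pair into one module: F0 is multiplied by harmonics 1..9, turned into per-sample phase increments, integrated with a cumulative sum at a 300x-reduced rate, re-upsampled, and the nine sines are mixed down to one channel by the linear layer. The constants (the 25647 divisor, the 1.84*pi factor, the .009 amplitude) are the commit's own tuned values. A standalone sketch of the same cumulative-phase idea (function name and packaging are mine):

    import torch
    import torch.nn.functional as F

    def harmonic_source(f0, harmonic_num=8, upsample_scale=300):
        # f0: [batch, time, 1] fundamental frequency, already at waveform resolution
        harmonics = f0 * torch.arange(1, harmonic_num + 2, device=f0.device)  # [B, T, 9]
        rad = (harmonics / 25647) % 1                # phase increment per sample, in cycles
        rad = F.interpolate(rad.transpose(1, 2), scale_factor=1 / upsample_scale,
                            mode='linear').transpose(1, 2)
        phase = torch.cumsum(rad, dim=1) * 1.84 * torch.pi   # integrate at the low rate
        phase = F.interpolate(phase.transpose(1, 2) * upsample_scale,
                              scale_factor=upsample_scale, mode='linear').transpose(1, 2)
        return .009 * phase.sin()                    # [B, T, 9] harmonic sine bank
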
@@ -239,7 +197,7 @@ class Generator(torch.nn.Module):
             x_source = self.noise_res[i](x_source, s)
 
             x = self.ups[i](x)
-
+
             x = x + x_source
 
             xs = None
@@ -250,22 +208,12 @@ class Generator(torch.nn.Module):
                 else:
                     xs += self.resblocks[i*self.num_kernels+j](x, s)
             x = xs / self.num_kernels
-            x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)
+            # x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)  # noisy
         x = self.conv_post(x)
         x = torch.tanh(x)
 
         return x
 
-    def remove_weight_norm(self):
-        print('Removing weight norm...')
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
-
-
 class AdainResBlk1d(nn.Module):
 
     # also used in ProsodyPredictor()
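
The line disabled here as "noisy" is the Snake periodic activation, snake_a(x) = x + (1/a) * sin(a*x)**2, proposed in "Neural Networks Fail to Learn Periodic Functions and How to Fix It" (Ziyin et al., 2020) and common in BigVGAN-style vocoders as a periodic inductive bias; with it commented out, each upsampling stage now ends in the plain conv + tanh alone.
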
@@ -324,7 +272,7 @@ class UpSample1d(nn.Module):
         if self.layer_type == 'none':
             return x
         else:
-            return F.interpolate(x, scale_factor=2, mode='nearest')
+            return F.interpolate(x, scale_factor=2, mode='nearest-exact')
 
 
 class Decoder(nn.Module):
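
The 'nearest' -> 'nearest-exact' switch changes only the source-index rounding: 'nearest-exact' matches the rounding used by OpenCV and scikit-image, while legacy 'nearest' keeps PyTorch's historical off-by-half behaviour. At an exact 2x scale, as here, the two modes select the same samples; they diverge at fractional scales, e.g.:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[1., 2., 3., 4.]]])  # [batch, channels, width]
    F.interpolate(x, scale_factor=1.5, mode='nearest')        # [[[1., 1., 2., 3., 3., 4.]]]
    F.interpolate(x, scale_factor=1.5, mode='nearest-exact')  # [[[1., 2., 2., 3., 4., 4.]]]
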
@@ -361,11 +309,10 @@ class Decoder(nn.Module):
 
     def forward(self, asr=None, F0_curve=None, N=None, s=None):
 
-
+
         F0 = self.F0_conv(F0_curve)
         N = self.N_conv(N)
 
-        # print(asr.shape, F0.shape, N.shape, 'TF')
 
         x = torch.cat([asr, F0, N], axis=1)
 
Modules/vits/models.py
CHANGED
@@ -1,15 +1,30 @@
 import math
 from dataclasses import dataclass
-from typing import Any, Optional, Tuple, Union
 import numpy as np
 import torch
-import torch.utils.checkpoint
 from torch import nn
-
-from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.modeling_outputs import BaseModelOutput, ModelOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
+import json
+import os
+import re
+from typing import Any, Dict, List, Optional, Tuple
+from transformers.tokenization_utils import PreTrainedTokenizer
+import phonemizer
+import uroman as ur
+import torch.nn.functional as F
+
+
+def has_non_roman_characters(input_string):
+    # Find any character outside the ASCII range
+    non_roman_pattern = re.compile(r"[^\x00-\x7F]")
+
+    # Search the input string for non-Roman characters
+    match = non_roman_pattern.search(input_string)
+    has_non_roman = match is not None
+    return has_non_roman
+
 
 class VitsConfig(PretrainedConfig):
 
@@ -74,11 +89,9 @@ class VitsConfig(PretrainedConfig):
         self.ffn_kernel_size = ffn_kernel_size
         self.flow_size = flow_size
         self.spectrogram_bins = spectrogram_bins
-
-
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
+        # self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
         self.num_speakers = num_speakers
         self.speaker_embedding_size = speaker_embedding_size
         self.upsample_initial_channel = upsample_initial_channel

@@ -92,7 +105,6 @@ class VitsConfig(PretrainedConfig):
         self.duration_predictor_flow_bins = duration_predictor_flow_bins
         self.duration_predictor_tail_bound = duration_predictor_tail_bound
         self.duration_predictor_kernel_size = duration_predictor_kernel_size
-
         self.duration_predictor_num_flows = duration_predictor_num_flows
         self.duration_predictor_filter_channels = duration_predictor_filter_channels
         self.prior_encoder_num_flows = prior_encoder_num_flows

@@ -100,8 +112,6 @@ class VitsConfig(PretrainedConfig):
         self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
         self.wavenet_kernel_size = wavenet_kernel_size
         self.wavenet_dilation_rate = wavenet_dilation_rate
-
-
         self.noise_scale = noise_scale
         self.noise_scale_duration = noise_scale_duration
         self.sampling_rate = sampling_rate
@@ -121,183 +131,9 @@ class VitsTextEncoderOutput(ModelOutput):
     last_hidden_state: torch.FloatTensor = None
     prior_means: torch.FloatTensor = None
     prior_log_variances: torch.FloatTensor = None
-    hidden_states:
-    attentions:
+    hidden_states: torch.FloatTensor = None
+    attentions: torch.FloatTensor = None
+
-
-def _unconstrained_rational_quadratic_spline(
-    inputs,
-    unnormalized_widths,
-    unnormalized_heights,
-    unnormalized_derivatives,
-    reverse=False,
-    tail_bound=5.0,
-    min_bin_width=1e-3,
-    min_bin_height=1e-3,
-    min_derivative=1e-3,
-):
-    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
-    outside_interval_mask = ~inside_interval_mask
-
-    outputs = torch.zeros_like(inputs)
-
-    constant = np.log(np.exp(1 - min_derivative) - 1)
-
-    unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
-    unnormalized_derivatives[..., 0] = constant
-    unnormalized_derivatives[..., -1] = constant
-
-    outputs[outside_interval_mask] = inputs[outside_interval_mask]
-
-    outputs[inside_interval_mask] = _rational_quadratic_spline(
-        inputs=inputs[inside_interval_mask],
-        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
-        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
-        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
-        reverse=reverse,
-        tail_bound=tail_bound,
-        min_bin_width=min_bin_width,
-        min_bin_height=min_bin_height,
-        min_derivative=min_derivative,
-    )
-    return outputs
-
-
-def _rational_quadratic_spline(
-    inputs,
-    unnormalized_widths,
-    unnormalized_heights,
-    unnormalized_derivatives,
-    reverse,
-    tail_bound,
-    min_bin_width,
-    min_bin_height,
-    min_derivative,
-):
-    """
-    This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
-    function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`.
-
-    Args:
-        inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
-            Second half of the hidden-states input to the Vits convolutional flow module.
-        unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
-            First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
-            layer in the convolutional flow module
-        unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
-            Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
-            layer in the convolutional flow module
-        unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
-            Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
-            layer in the convolutional flow module
-        reverse (`bool`):
-            Whether the model is being run in reverse mode.
-        tail_bound (`float`):
-            Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
-            transform behaves as an identity function.
-        min_bin_width (`float`):
-            Minimum bin value across the width dimension for the piecewise rational quadratic function.
-        min_bin_height (`float`):
-            Minimum bin value across the height dimension for the piecewise rational quadratic function.
-        min_derivative (`float`):
-            Minimum bin value across the derivatives for the piecewise rational quadratic function.
-    Returns:
-        outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
-            Hidden-states as transformed by the piecewise rational quadratic function.
-        log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
-            Logarithm of the absolute value of the determinants corresponding to the `outputs`.
-    """
-    upper_bound = tail_bound
-    lower_bound = -tail_bound
-
-    if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
-        raise ValueError("Input to a transform is not within its domain")
-
-    num_bins = unnormalized_widths.shape[-1]
-
-    if min_bin_width * num_bins > 1.0:
-        raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
-    if min_bin_height * num_bins > 1.0:
-        raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")
-
-    widths = nn.functional.softmax(unnormalized_widths, dim=-1)
-    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
-    cumwidths = torch.cumsum(widths, dim=-1)
-    cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
-    cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
-    cumwidths[..., 0] = lower_bound
-    cumwidths[..., -1] = upper_bound
-    widths = cumwidths[..., 1:] - cumwidths[..., :-1]
-
-    derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)
-
-    heights = nn.functional.softmax(unnormalized_heights, dim=-1)
-    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
-    cumheights = torch.cumsum(heights, dim=-1)
-    cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
-    cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
-    cumheights[..., 0] = lower_bound
-    cumheights[..., -1] = upper_bound
-    heights = cumheights[..., 1:] - cumheights[..., :-1]
-
-    bin_locations = cumheights if reverse else cumwidths
-    bin_locations[..., -1] += 1e-6
-    bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
-    bin_idx = bin_idx[..., None]
-
-    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
-    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
-
-    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
-    delta = heights / widths
-    input_delta = delta.gather(-1, bin_idx)[..., 0]
-
-    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
-    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
-
-    input_heights = heights.gather(-1, bin_idx)[..., 0]
-
-    intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
-    if not reverse:
-        raise ValueError
-        # theta = (inputs - input_cumwidths) / input_bin_widths
-        # theta_one_minus_theta = theta * (1 - theta)
-
-        # numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
-        # denominator = input_delta + intermediate1 * theta_one_minus_theta
-        # outputs = input_cumheights + numerator / denominator
-
-        # derivative_numerator = input_delta.pow(2) * (
-        #     input_derivatives_plus_one * theta.pow(2)
-        #     + 2 * input_delta * theta_one_minus_theta
-        #     + input_derivatives * (1 - theta).pow(2)
-        # )
-        # log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
-        # return outputs, log_abs_det
-    else:
-        # find the roots of a quadratic equation
-        intermediate2 = inputs - input_cumheights
-        intermediate3 = intermediate2 * intermediate1
-        a = input_heights * (input_delta - input_derivatives) + intermediate3
-        b = input_heights * input_derivatives - intermediate3
-        c = -input_delta * intermediate2
-
-        discriminant = b.pow(2) - 4 * a * c
-        if not (discriminant >= 0).all():
-            raise RuntimeError(f"invalid discriminant {discriminant}")
-
-        root = (2 * c) / (-b - torch.sqrt(discriminant))
-        outputs = root * input_bin_widths + input_cumwidths
-
-        # theta_one_minus_theta = root * (1 - root)
-        # denominator = input_delta + intermediate1 * theta_one_minus_theta
-        # derivative_numerator = input_delta.pow(2) * (
-        #     input_derivatives_plus_one * root.pow(2)
-        #     + 2 * input_delta * theta_one_minus_theta
-        #     + input_derivatives * (1 - root).pow(2)
-        # )
-        # log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
-        return outputs  #, -log_abs_det
 
 
 class VitsWaveNet(torch.nn.Module):
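
For reference on the deleted spline: inverting one rational-quadratic segment reduces to a quadratic in the normalized position theta in [0, 1]. With bin width w_k, height h_k, slope delta_k = h_k / w_k, knot derivatives d_k, d_{k+1}, left corner (x_k, y_k) and target y, the deleted code's coefficients are

    a = h_k (delta_k - d_k) + (y - y_k)(d_k + d_{k+1} - 2 delta_k)
    b = h_k d_k - (y - y_k)(d_k + d_{k+1} - 2 delta_k)
    c = -delta_k (y - y_k)

and it takes the numerically stable root form, which avoids cancellation between -b and the square root:

    theta = 2c / (-b - sqrt(b^2 - 4ac)),    x = x_k + theta * w_k
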
@@ -305,20 +141,14 @@ class VitsWaveNet(torch.nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.num_layers = num_layers
-
         self.in_layers = torch.nn.ModuleList()
         self.res_skip_layers = torch.nn.ModuleList()
-
-
-
-
-        if config.speaker_embedding_size != 0:
-            cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
-            self.cond_layer = weight_norm(cond_layer, name="weight")
-
         for i in range(num_layers):
             dilation = config.wavenet_dilation_rate**i
             padding = (config.wavenet_kernel_size * dilation - dilation) // 2
@@ -337,53 +167,36 @@ class VitsWaveNet(torch.nn.Module):
                 res_skip_channels = 2 * config.hidden_size
             else:
                 res_skip_channels = config.hidden_size
-
             res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
             res_skip_layer = weight_norm(res_skip_layer, name="weight")
             self.res_skip_layers.append(res_skip_layer)
 
-    def forward(self,
         outputs = torch.zeros_like(inputs)
         num_channels = torch.IntTensor([self.hidden_size])[0]
-
-
         for i in range(self.num_layers):
             in_act = self.in_layers[i](inputs)
-
-
             # global_states = torch.zeros_like(hidden_states) # style ?
-
             # acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
-
             # --
             # def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
             # in_act = input_a # + input_b
             t_act = torch.tanh(in_act[:, :num_channels, :])
             s_act = torch.sigmoid(in_act[:, num_channels:, :])
             acts = t_act * s_act
-
-
-            #
-
-
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.num_layers - 1:
                 res_acts = res_skip_acts[:, : self.hidden_size, :]
-                inputs =
                 outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
             else:
                 outputs = outputs + res_skip_acts
 
-        return outputs * padding_mask
 
-    def remove_weight_norm(self):
-        if self.speaker_embedding_size != 0:
-            torch.nn.utils.remove_weight_norm(self.cond_layer)
-        for layer in self.in_layers:
-            torch.nn.utils.remove_weight_norm(layer)
-        for layer in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(layer)
 
 
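
VitsWaveNet keeps the classic WaveNet gated unit, with the global-conditioning add commented out: the dilated conv emits 2x hidden channels, split into a tanh "content" half and a sigmoid "gate" half. A minimal sketch of the fused activation this loop inlines:

    import torch

    def gated_activation(in_act, num_channels):
        # in_act: [batch, 2 * num_channels, time] output of the dilated conv
        t_act = torch.tanh(in_act[:, :num_channels, :])     # content
        s_act = torch.sigmoid(in_act[:, num_channels:, :])  # gate in (0, 1)
        return t_act * s_act
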
@@ -425,22 +238,6 @@ class HifiGanResidualBlock(nn.Module):
     def get_padding(self, kernel_size, dilation=1):
         return (kernel_size * dilation - dilation) // 2
 
-    def apply_weight_norm(self):
-        weight_norm = nn.utils.weight_norm
-        if hasattr(nn.utils.parametrizations, "weight_norm"):
-            weight_norm = nn.utils.parametrizations.weight_norm
-
-        for layer in self.convs1:
-            weight_norm(layer)
-        for layer in self.convs2:
-            weight_norm(layer)
-
-    def remove_weight_norm(self):
-        for layer in self.convs1:
-            nn.utils.remove_weight_norm(layer)
-        for layer in self.convs2:
-            nn.utils.remove_weight_norm(layer)
-
     def forward(self, hidden_states):
         for conv1, conv2 in zip(self.convs1, self.convs2):
             residual = hidden_states
@@ -483,44 +280,18 @@ class VitsHifiGan(nn.Module):
             channels = config.upsample_initial_channel // (2 ** (i + 1))
             for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                 self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
-
         self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
 
-
-
-
-    def apply_weight_norm(self):
-        weight_norm = nn.utils.weight_norm
-        if hasattr(nn.utils.parametrizations, "weight_norm"):
-            weight_norm = nn.utils.parametrizations.weight_norm
-
-        for layer in self.upsampler:
-            weight_norm(layer)
-        for layer in self.resblocks:
-            layer.apply_weight_norm()
-
-    def remove_weight_norm(self):
-        for layer in self.upsampler:
-            nn.utils.remove_weight_norm(layer)
-        for layer in self.resblocks:
-            layer.remove_weight_norm()
-
-    def forward(
-        self,
-        spectrogram,
-        global_conditioning=None):
-
         hidden_states = self.conv_pre(spectrogram)
-
         for i in range(self.num_upsamples):
             hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
             hidden_states = self.upsampler[i](hidden_states)
-
             res_state = self.resblocks[i * self.num_kernels](hidden_states)
             for j in range(1, self.num_kernels):
                 res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
             hidden_states = res_state / self.num_kernels
-
         hidden_states = nn.functional.leaky_relu(hidden_states)
         hidden_states = self.conv_post(hidden_states)
         waveform = torch.tanh(hidden_states)
@@ -531,27 +302,20 @@ class VitsResidualCouplingLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.half_channels = config.flow_size // 2
-
         self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
         self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
         self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
 
-    def forward(self,
-
-
-
-
-
-            log_determinant = torch.sum(log_stddev, [1, 2])
-            return outputs, log_determinant
-        else:
-            second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
-            outputs = torch.cat([first_half, second_half], dim=1)
-            return outputs, None
 
 
 class VitsResidualCouplingBlock(nn.Module):
@@ -561,226 +325,20 @@ class VitsResidualCouplingBlock(nn.Module):
         for _ in range(config.prior_encoder_num_flows):
             self.flows.append(VitsResidualCouplingLayer(config))
 
-    def forward(self,
-
-
-
-
-        for flow in reversed(self.flows):
-            inputs = torch.flip(inputs, [1])
-            inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
-        return inputs
-
-
-class VitsDilatedDepthSeparableConv(nn.Module):
-    def __init__(self, config, dropout_rate=0.0):
-        super().__init__()
-        kernel_size = config.duration_predictor_kernel_size
-        channels = config.hidden_size
-        self.num_layers = config.depth_separable_num_layers
-
-        self.convs_dilated = nn.ModuleList()
-        self.convs_pointwise = nn.ModuleList()
-        self.norms_1 = nn.ModuleList()
-        self.norms_2 = nn.ModuleList()
-        for i in range(self.num_layers):
-            dilation = kernel_size**i
-            padding = (kernel_size * dilation - dilation) // 2
-            self.convs_dilated.append(
-                nn.Conv1d(
-                    in_channels=channels,
-                    out_channels=channels,
-                    kernel_size=kernel_size,
-                    groups=channels,
-                    dilation=dilation,
-                    padding=padding,
-                )
-            )
-            self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
-            self.norms_1.append(nn.LayerNorm(channels))
-            self.norms_2.append(nn.LayerNorm(channels))
-
-    def forward(self, inputs, padding_mask, global_conditioning=None):
-        if global_conditioning is not None:
-            inputs = inputs + global_conditioning
-
-        for i in range(self.num_layers):
-            hidden_states = self.convs_dilated[i](inputs * padding_mask)
-            hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
-            hidden_states = nn.functional.gelu(hidden_states)
-            hidden_states = self.convs_pointwise[i](hidden_states)
-            hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
-            hidden_states = nn.functional.gelu(hidden_states)
-
-            inputs = inputs + hidden_states
-
-        return inputs * padding_mask
-
-
-class VitsConvFlow(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.filter_channels = config.hidden_size
-        self.half_channels = config.depth_separable_channels // 2
-        self.num_bins = config.duration_predictor_flow_bins
-        self.tail_bound = config.duration_predictor_tail_bound
-
-        self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
-        self.conv_dds = VitsDilatedDepthSeparableConv(config)
-        self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)
-
-    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
-        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
-
-        hidden_states = self.conv_pre(first_half)
-        hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
-        hidden_states = self.conv_proj(hidden_states) * padding_mask
-
-        batch_size, channels, length = first_half.shape
-        hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)
-
-        unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
-        unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
-        unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]
-
-        second_half = _unconstrained_rational_quadratic_spline(
-            second_half,
-            unnormalized_widths,
-            unnormalized_heights,
-            unnormalized_derivatives,
-            reverse=reverse,
-            tail_bound=self.tail_bound,
-        )
-
-        outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
-
-        return outputs, None
-
-
-class VitsElementwiseAffine(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.channels = config.depth_separable_channels
-        self.translate = nn.Parameter(torch.zeros(self.channels, 1))
-        self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))
 
-    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
-        if not reverse:
-            raise ValueError
-            # outputs = self.translate + torch.exp(self.log_scale) * inputs
-            # outputs = outputs * padding_mask
-            # log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
-            # return outputs, log_determinant
-        else:
-            outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
-            return outputs, None
-
-
-class VitsStochasticDurationPredictor(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        embed_dim = config.speaker_embedding_size
-        filter_channels = config.hidden_size
-
-        self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
-        self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
-        self.conv_dds = VitsDilatedDepthSeparableConv(config)
-
-        if embed_dim != 0:
-            self.cond = nn.Conv1d(embed_dim, filter_channels, 1)
-
-        self.flows = nn.ModuleList()
-        self.flows.append(VitsElementwiseAffine(config))
-        for _ in range(config.duration_predictor_num_flows):
-            self.flows.append(VitsConvFlow(config))
-
-        # self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
-        # self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
-        # self.post_conv_dds = VitsDilatedDepthSeparableConv(
-        #     config,
-        #     dropout_rate=config.duration_predictor_dropout,
-        # )
-
-        # self.post_flows = nn.ModuleList()
-        # self.post_flows.append(VitsElementwiseAffine(config))
-        # for _ in range(config.duration_predictor_num_flows):
-        #     self.post_flows.append(VitsConvFlow(config))
-
-    def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
-        inputs = torch.detach(inputs)
-        inputs = self.conv_pre(inputs)
-
-        if global_conditioning is not None:
-            raise ValueError
-            # global_conditioning = torch.detach(global_conditioning)
-            # inputs = inputs + self.cond(global_conditioning)
-
-        inputs = self.conv_dds(inputs, padding_mask)
-        inputs = self.conv_proj(inputs) * padding_mask
-
-        if not reverse:
-            raise ValueError
-            # hidden_states = self.post_conv_pre(durations)
-            # hidden_states = self.post_conv_dds(hidden_states, padding_mask)
-            # hidden_states = self.post_conv_proj(hidden_states) * padding_mask
-
-            # random_posterior = (
-            #     torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
-            #     * padding_mask
-            # )
-            # log_determinant_posterior_sum = 0
-            # latents_posterior = random_posterior
-            # for flow in self.post_flows:
-            #     latents_posterior, log_determinant = flow(
-            #         latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
-            #     )
-            #     latents_posterior = torch.flip(latents_posterior, [1])
-            #     log_determinant_posterior_sum += log_determinant
-
-            # first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)
-
-            # log_determinant_posterior_sum += torch.sum(
-            #     (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
-            # )
-            # logq = (
-            #     torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
-            #     - log_determinant_posterior_sum
-            # )
-
-            # first_half = (durations - torch.sigmoid(first_half)) * padding_mask
-            # first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
-            # log_determinant_sum = torch.sum(-first_half, [1, 2])
-
-            # latents = torch.cat([first_half, second_half], dim=1)
-            # for flow in self.flows:
-            #     latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
-            #     latents = torch.flip(latents, [1])
-            #     log_determinant_sum += log_determinant
-
-            # nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
-            # return nll + logq
-        else:
-            flows = list(reversed(self.flows))
-            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
-
-            latents = (
-                torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
-                * noise_scale
-            )
-            for flow in flows:
-                latents = torch.flip(latents, [1])
-                latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)
-
-            log_duration, _ = torch.split(latents, [1, 1], dim=1)
-            return log_duration
 
 
 class VitsAttention(nn.Module):
-    """
 
     def __init__(self, config):
         super().__init__()
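
At inference the stochastic duration predictor deleted above runs only in reverse: sample a 2-channel noise tensor, push it backwards through the reversed flow stack (dropping the one flow that is unused in reverse), and read log-durations off the first channel. A sketch of just that control flow, assuming `flows` is the module list from the deleted class:

    import torch

    def sample_log_duration(flows, inputs, padding_mask, noise_scale=1.0):
        # inputs: projected text hidden states [batch, channels, text_len]
        flows = list(reversed(flows))
        flows = flows[:-2] + [flows[-1]]      # "remove a useless vflow"
        latents = torch.randn(inputs.size(0), 2, inputs.size(2),
                              device=inputs.device, dtype=inputs.dtype) * noise_scale
        for flow in flows:
            latents = torch.flip(latents, [1])   # swap channel roles between flows
            latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)
        log_duration, _ = torch.split(latents, [1, 1], dim=1)
        return log_duration                   # [batch, 1, text_len]
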
@@ -793,36 +351,22 @@ class VitsAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
 
         if (self.head_dim * self.num_heads) != self.embed_dim:
-            raise ValueError
-                f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}"
-                f" and `num_attention_heads`: {self.num_heads})."
-            )
 
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
 
-        if self.window_size:
-            # Those provide relative pos embs for k/v interpolated from 2*4+1 to 1027 time frames - duration of txt
-            self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
-            self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
-
     def _shape(self, tensor, seq_len, bsz):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
         self,
         hidden_states,
-
-
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
     ):
-
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
 
         bsz, tgt_len, _ = hidden_states.size()
@@ -840,36 +384,9 @@ class VitsAttention(nn.Module):
 
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-
-
-        if self.window_size is not None:
-            # 4
-            # key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)
-            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)  # try fix k.shape[2] to have consistent voice deu
-            # print(f'{self.emb_rel_k.shape=} {key_relative_embeddings.shape=}\n\nL855')
-            relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1))
-            # -- only here (key)
-            rel_pos_bias = self._relative_position_to_absolute_position(relative_logits)
-            attn_weights += rel_pos_bias
-
-        if attention_mask is not None:
-
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.bmm(attn_weights,
                                 value_states)
-
-
-
-        if self.window_size is not None:
-            # Entering here with self.window_size = 4
-            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
-            relative_weights = self._absolute_position_to_relative_position(attn_weights)
-            rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
-            attn_output += rel_pos_bias
-
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
@@ -881,42 +398,6 @@ class VitsAttention(nn.Module):
 
         return attn_output, None  # attn_weights_reshaped
 
-    def _get_relative_embeddings(self, relative_embeddings, length):
-        pad_length = max(length - (self.window_size + 1), 0)
-        if pad_length > 0:
-            relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
-
-        slice_start_position = max((self.window_size + 1) - length, 0)
-        slice_end_position = slice_start_position + 2 * length - 1
-        return relative_embeddings[:, slice_start_position:slice_end_position]
-
-    def _relative_position_to_absolute_position(self, x):
-        batch_heads, length, _ = x.size()
-
-        # Concat columns of pad to shift from relative to absolute indexing.
-        x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])
-
-        # Concat extra elements so to add up to shape (len+1, 2*len-1).
-        x_flat = x.view([batch_heads, length * 2 * length])
-        x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
-
-        # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
-        x_final = x_final[:, :length, length - 1 :]
-        return x_final
-
-    def _absolute_position_to_relative_position(self, x):
-        batch_heads, length, _ = x.size()
-
-        # Pad along column
-        x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
-        x_flat = x.view([batch_heads, length * (2 * length - 1)])
-
-        # Add 0's in the beginning that will skew the elements after reshape
-        x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
-        x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]
-        return x_final
-
 
 class VitsFeedForward(nn.Module):
     def __init__(self, config):
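
The deleted `_relative_position_to_absolute_position` is the standard pad-and-reshape trick for windowed relative attention (as in Music Transformer): it converts [heads, L, 2L-1] relative logits into [heads, L, L] absolute ones with two pads and a view instead of a gather. The same trick in isolation:

    import torch
    import torch.nn.functional as F

    def rel_to_abs(x):
        # x: [batch_heads, length, 2 * length - 1] relative-position logits
        batch_heads, length, _ = x.size()
        x = F.pad(x, [0, 1])                                 # one extra column
        x_flat = x.view(batch_heads, length * 2 * length)
        x_flat = F.pad(x_flat, [0, length - 1])              # skew rows on reshape
        x_final = x_flat.view(batch_heads, length + 1, 2 * length - 1)
        return x_final[:, :length, length - 1:]              # [batch_heads, length, length]
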
@@ -933,25 +414,15 @@ class VitsFeedForward(nn.Module):
         else:
             self.padding = None
 
-    def forward(self, hidden_states
         hidden_states = hidden_states.permute(0, 2, 1)
-        padding_mask = padding_mask.permute(0, 2, 1)
-
-        hidden_states = hidden_states * padding_mask
         if self.padding is not None:
             hidden_states = nn.functional.pad(hidden_states, self.padding)
-
         hidden_states = self.conv_1(hidden_states)
         hidden_states = self.act_fn(hidden_states)
-
-
-        hidden_states = hidden_states * padding_mask
         if self.padding is not None:
             hidden_states = nn.functional.pad(hidden_states, self.padding)
-
         hidden_states = self.conv_2(hidden_states)
-        hidden_states = hidden_states * padding_mask
-
         hidden_states = hidden_states.permute(0, 2, 1)
         return hidden_states
@@ -960,22 +431,19 @@ class VitsEncoderLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.attention = VitsAttention(config)
-
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.feed_forward = VitsFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def forward(
         self,
-        hidden_states
-
-        attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
     ):
         residual = hidden_states
         hidden_states, attn_weights = self.attention(
             hidden_states=hidden_states,
-            attention_mask=attention_mask,
             output_attentions=output_attentions,
         )
@@ -983,15 +451,12 @@ class VitsEncoderLayer(nn.Module):
         hidden_states = self.layer_norm(residual + hidden_states)
 
         residual = hidden_states
-        hidden_states = self.feed_forward(hidden_states
-
         hidden_states = self.final_layer_norm(residual + hidden_states)
 
         outputs = (hidden_states,)
 
-        if output_attentions:
-            outputs += (attn_weights,)
-
         return outputs
 
@@ -1005,52 +470,24 @@ class VitsEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states
-
-
-
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
     ):
-
-
-
-        # expand attention_mask
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-
-        hidden_states = hidden_states * padding_mask
-
-
-
-        for encoder_layer in self.layers:
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            layer_outputs = encoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                padding_mask=padding_mask,
-                output_attentions=output_attentions,
-            )
             hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
-        hidden_states = hidden_states * padding_mask
-
         return BaseModelOutput(
             last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
         )
 
 
 class VitsTextEncoder(nn.Module):
     """
-
     """
 
     def __init__(self, config):
@@ -1060,75 +497,30 @@ class VitsTextEncoder(nn.Module):
         self.encoder = VitsEncoder(config)  # 6 Layers of VitsAttention
         self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
 
-
-
-
-
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        padding_mask: torch.FloatTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = True,
-    ):
-        hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
-
-        encoder_outputs = self.encoder(
-            hidden_states=hidden_states,
-            padding_mask=padding_mask,
-            attention_mask=attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
 
-
-
-        stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask
         prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
 
         return VitsTextEncoderOutput(
             last_hidden_state=last_hidden_state,
             prior_means=prior_means,
-            prior_log_variances=prior_log_variances,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
         )
 
 
 class VitsPreTrainedModel(PreTrainedModel):
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
     config_class = VitsConfig
     base_model_prefix = "vits"
     main_input_name = "input_ids"
     supports_gradient_checkpointing = True
 
-    def _init_weights(self, module):
-        """Initialize the weights"""
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, nn.Conv1d):
-            nn.init.kaiming_normal_(module.weight)
-            if module.bias is not None:
-                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
-                nn.init.uniform_(module.bias, a=-k, b=k)
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
 
 
 class VitsModel(VitsPreTrainedModel):
@@ -1138,27 +530,9 @@ class VitsModel(VitsPreTrainedModel):
         self.text_encoder = VitsTextEncoder(config)  # has VitsEncoder that includes 6L of VitsAttention
         self.flow = VitsResidualCouplingBlock(config)
         self.decoder = VitsHifiGan(config)
-
-        if config.use_stochastic_duration_prediction:
-            self.duration_predictor = VitsStochasticDurationPredictor(config)
-        else:
-            raise ValueError
-            # self.duration_predictor = VitsDurationPredictor(config)
-
-        if config.num_speakers > 1:
-            self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)
-
-
-
-        self.noise_scale = config.noise_scale
-        self.noise_scale_duration = config.noise_scale_duration
-
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_encoder(self):
-        return self.text_encoder
-
     def forward(
         self,
         input_ids = None,
@@ -1168,69 +542,37 @@ class VitsModel(VitsPreTrainedModel):
         output_hidden_states = None,
         return_dict = None,
         labels = None,
-        speed=None,
     ):
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if labels is not None:
-            raise NotImplementedError("Training of VITS is not supported yet.")
-
         mask_dtype = self.text_encoder.embed_tokens.weight.dtype
         if attention_mask is not None:
             input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
         else:
             input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
-
-
-        if not 0 <= speaker_id < self.config.num_speakers:
-            raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
-        if isinstance(speaker_id, int):
-            speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
-            speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
-        else:
-            speaker_embeddings = None
-
-        text_encoder_output = self.text_encoder(
-            input_ids=input_ids,
-            padding_mask=input_padding_mask,
-            attention_mask=attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
-        hidden_states = hidden_states.transpose(1, 2)
         input_padding_mask = input_padding_mask.transpose(1, 2)
-        prior_means =
-
-
-        if
-
-
-
-
-        )
         else:
-
-
-
-
             predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
-
-        # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
         indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
         output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
         output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
-
         # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
         attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
         batch_size, _, output_length, input_length = attn_mask.shape
         cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
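
This hunk and the next one turn per-token durations into a hard monotonic alignment: clamp the summed durations to get output lengths, build an output padding mask, then compare each output frame index against the cumulative durations so every frame selects exactly one token (the `valid_indices` / `padded_indices` differencing below). An equivalent single-sample sketch of the expansion, without the masked batch bookkeeping:

    import torch

    def expand_by_duration(token_states, duration):
        # token_states: [text_len, channels]; duration: [text_len] integer frame counts
        cum = torch.cumsum(duration, 0)              # end frame of each token
        frames = torch.arange(int(duration.sum()))
        # frame t belongs to the first token whose cumulative duration exceeds t
        token_idx = torch.searchsorted(cum, frames, right=True)
        return token_states[token_idx]               # [total_frames, channels]
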
@@ -1239,106 +581,30 @@ class VitsModel(VitsPreTrainedModel):
         valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
         padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
         attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
-
-        # Expand prior distribution
-        prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
-        prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
-
-        prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
-        latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
-
-        spectrogram = latents * output_padding_mask
-        waveform = self.decoder(spectrogram, speaker_embeddings)
-        waveform = waveform.squeeze(1)
-        sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
-
-        if not return_dict:
-            outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
-            return outputs
-
-        return waveform
-
-
 
 
-        #
 
-
-# Copyright 2023 The Kakao Enterprise Authors, the MMS-TTS Authors and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tokenization class for VITS."""
-
-import json
-import os
-import re
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-from transformers.tokenization_utils import PreTrainedTokenizer
-from transformers.utils import is_phonemizer_available, is_uroman_available
-
-
-if is_phonemizer_available():
-    import phonemizer
-
-if is_uroman_available():
-    import uroman as ur
-
-
-VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}
-
-
-def has_non_roman_characters(input_string):
-    # Find any character outside the ASCII range
-    non_roman_pattern = re.compile(r"[^\x00-\x7F]")
-
-    # Search the input string for non-Roman characters
-    match = non_roman_pattern.search(input_string)
-    has_non_roman = match is not None
-    return has_non_roman
 
 
 class VitsTokenizer(PreTrainedTokenizer):
-    """
-    Construct a VITS tokenizer. Also supports MMS-TTS.
-
-    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
-    this superclass for more information regarding those methods.
-
-    Args:
-        vocab_file (`str`):
-            Path to the vocabulary file.
-        language (`str`, *optional*):
-            Language identifier.
-        add_blank (`bool`, *optional*, defaults to `True`):
-            Whether to insert token id 0 in between the other tokens.
-        normalize (`bool`, *optional*, defaults to `True`):
-            Whether to normalize the input text by removing all casing and punctuation.
-        phonemize (`bool`, *optional*, defaults to `True`):
-            Whether to convert the input text into phonemes.
-        is_uroman (`bool`, *optional*, defaults to `False`):
-            Whether the `uroman` Romanizer needs to be applied to the input text prior to tokenizing.
-    """
-
-    vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
 
     def __init__(
@@ -1412,12 +678,8 @@ class VitsTokenizer(PreTrainedTokenizer):
         return text
 
     def prepare_for_tokenization(
-        self, text: str, is_split_into_words: bool = False, normalize
-
-        '''
-        Performs any necessary transformations before tokenization.
-
-        '''
         normalize = normalize if normalize is not None else self.normalize
 
         if normalize:

@@ -1462,21 +724,18 @@ class VitsTokenizer(PreTrainedTokenizer):
         tokens = list(text)
 
         if self.add_blank:
-
-
-
 
         return tokens
 
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self.add_blank and len(tokens) > 1:
-            tokens = tokens[1::2]
-        return "".join(tokens)
-
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         return self.encoder.get(token, self.encoder.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        return self.decoder.get(index)
|
 import math
 from dataclasses import dataclass
 import numpy as np
 import torch
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput, ModelOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
+import json
+import os
+import re
+from typing import Any, Dict, List, Optional, Tuple
+from transformers.tokenization_utils import PreTrainedTokenizer
+import phonemizer
+import uroman as ur
+import torch.nn.functional as F
+
+
+def has_non_roman_characters(input_string):
+    # Find any character outside the ASCII range
+    non_roman_pattern = re.compile(r"[^\x00-\x7F]")
+
+    # Search the input string for non-Roman characters
+    match = non_roman_pattern.search(input_string)
+    has_non_roman = match is not None
+    return has_non_roman
+
 
 class VitsConfig(PretrainedConfig):
 
         self.ffn_kernel_size = ffn_kernel_size
         self.flow_size = flow_size
         self.spectrogram_bins = spectrogram_bins
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
+        # self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
         self.num_speakers = num_speakers
         self.speaker_embedding_size = speaker_embedding_size
         self.upsample_initial_channel = upsample_initial_channel
 
         self.duration_predictor_flow_bins = duration_predictor_flow_bins
         self.duration_predictor_tail_bound = duration_predictor_tail_bound
         self.duration_predictor_kernel_size = duration_predictor_kernel_size
         self.duration_predictor_num_flows = duration_predictor_num_flows
         self.duration_predictor_filter_channels = duration_predictor_filter_channels
         self.prior_encoder_num_flows = prior_encoder_num_flows
 
         self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
         self.wavenet_kernel_size = wavenet_kernel_size
         self.wavenet_dilation_rate = wavenet_dilation_rate
         self.noise_scale = noise_scale
         self.noise_scale_duration = noise_scale_duration
         self.sampling_rate = sampling_rate
 
     last_hidden_state: torch.FloatTensor = None
     prior_means: torch.FloatTensor = None
     prior_log_variances: torch.FloatTensor = None
+    hidden_states: torch.FloatTensor = None
+    attentions: torch.FloatTensor = None
 class VitsWaveNet(torch.nn.Module):
 
         super().__init__()
         self.hidden_size = config.hidden_size
         self.num_layers = num_layers
         self.in_layers = torch.nn.ModuleList()
         self.res_skip_layers = torch.nn.ModuleList()
+        # if hasattr(nn.utils.parametrizations, "weight_norm"):
+        weight_norm = nn.utils.parametrizations.weight_norm
+        # else:
+        #     weight_norm = nn.utils.weight_norm
 
         for i in range(num_layers):
             dilation = config.wavenet_dilation_rate**i
             padding = (config.wavenet_kernel_size * dilation - dilation) // 2
 
                 res_skip_channels = 2 * config.hidden_size
             else:
                 res_skip_channels = config.hidden_size
 
             res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
             res_skip_layer = weight_norm(res_skip_layer, name="weight")
             self.res_skip_layers.append(res_skip_layer)
 
+    def forward(self, inputs):
         outputs = torch.zeros_like(inputs)
         num_channels = torch.IntTensor([self.hidden_size])[0]
         for i in range(self.num_layers):
             in_act = self.in_layers[i](inputs)
             # global_states = torch.zeros_like(hidden_states)  # style ?
             # acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
             # --
             # def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
             #     in_act = input_a  # + input_b
             t_act = torch.tanh(in_act[:, :num_channels, :])
             s_act = torch.sigmoid(in_act[:, num_channels:, :])
             acts = t_act * s_act
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.num_layers - 1:
                 res_acts = res_skip_acts[:, : self.hidden_size, :]
+                inputs = inputs + res_acts
                 outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
             else:
                 outputs = outputs + res_skip_acts
+        return outputs
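The forward pass above inlines the gated activation that upstream `transformers` computes via `fused_add_tanh_sigmoid_multiply`; with the global conditioning tensor dropped, the gate reduces to a plain tanh/sigmoid split over the channel axis. A minimal sketch of that gate (the `hidden_size=192` below is only an illustrative value, not taken from the config in this diff):

```python
import torch

def gated_activation(in_act: torch.Tensor, num_channels: int) -> torch.Tensor:
    # Split a [B, 2*C, T] activation into a tanh "filter" half and a sigmoid
    # "gate" half and multiply them, as in VitsWaveNet.forward above.
    t_act = torch.tanh(in_act[:, :num_channels, :])
    s_act = torch.sigmoid(in_act[:, num_channels:, :])
    return t_act * s_act

in_act = torch.randn(1, 2 * 192, 100)  # output of one dilated in_layer conv
acts = gated_activation(in_act, 192)   # -> [1, 192, 100]
```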
     def get_padding(self, kernel_size, dilation=1):
         return (kernel_size * dilation - dilation) // 2
 
     def forward(self, hidden_states):
         for conv1, conv2 in zip(self.convs1, self.convs2):
             residual = hidden_states
 
             channels = config.upsample_initial_channel // (2 ** (i + 1))
             for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                 self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
 
         self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
 
+    def forward(self, spectrogram):
         hidden_states = self.conv_pre(spectrogram)
         for i in range(self.num_upsamples):
             hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
             hidden_states = self.upsampler[i](hidden_states)
             res_state = self.resblocks[i * self.num_kernels](hidden_states)
             for j in range(1, self.num_kernels):
                 res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
             hidden_states = res_state / self.num_kernels
         hidden_states = nn.functional.leaky_relu(hidden_states)
         hidden_states = self.conv_post(hidden_states)
         waveform = torch.tanh(hidden_states)
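In the decoder loop above, each upsampling stage feeds the same hidden states through `num_kernels` parallel residual blocks and averages the results, which is HiFi-GAN's multi-receptive-field fusion. A small sketch of that fusion, with the residual blocks abstracted to simple convolutions (stand-ins for illustration, not the real `HifiGanResidualBlock`):

```python
import torch

def mrf_fuse(hidden, resblocks):
    # Sum the outputs of parallel residual blocks over one upsample stage,
    # then average, as in the generator forward above.
    out = resblocks[0](hidden)
    for block in resblocks[1:]:
        out = out + block(hidden)
    return out / len(resblocks)

h = torch.randn(1, 256, 200)
blocks = [torch.nn.Conv1d(256, 256, k, padding=k // 2) for k in (3, 7, 11)]
y = mrf_fuse(h, blocks)  # same [1, 256, 200] shape as each block's output
```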
     def __init__(self, config):
         super().__init__()
         self.half_channels = config.flow_size // 2
         self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
         self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
         self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
 
+    def forward(self, x, reverse=False):
+        first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
+        hidden_states = self.conv_pre(first_half)
+        hidden_states = self.wavenet(hidden_states)
+        mean = self.conv_post(hidden_states)
+        second_half = second_half - mean
+        outputs = torch.cat([first_half, second_half], dim=1)
+        return outputs
 
 
 class VitsResidualCouplingBlock(nn.Module):
 
         for _ in range(config.prior_encoder_num_flows):
             self.flows.append(VitsResidualCouplingLayer(config))
 
+    def forward(self, x, reverse=False):
+        # x: [1, 192, 481]
+        for flow in reversed(self.flows):
+            x = torch.flip(x, [1])  # flip channels
+            x = flow(x, reverse=True)
+        return x
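Since only the reverse direction of the flow is needed at inference, each coupling layer above just subtracts the WaveNet-predicted mean from the second half of the channels, and the block iterates the flows backwards with a plain channel flip in place of a separate flip layer. A toy sketch of this mean-only inverse coupling, with the mean network stubbed out (names here are illustrative, not the module's API):

```python
import torch

def inverse_mean_coupling(x: torch.Tensor, predict_mean) -> torch.Tensor:
    # Invert z = [x1, x2 + m(x1)] back to [x1, x2]; mean-only coupling is
    # volume-preserving, so no log-determinant term is needed.
    half = x.shape[1] // 2
    first, second = x[:, :half], x[:, half:]
    return torch.cat([first, second - predict_mean(first)], dim=1)

z = torch.randn(1, 192, 481)                   # [B, flow_size, frames], shapes as in the comment above
x = inverse_mean_coupling(torch.flip(z, [1]),  # channel flip stands in for the flip between flows
                          predict_mean=lambda h: torch.zeros_like(h))  # stub for the WaveNet mean
```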
 class VitsAttention(nn.Module):
+    """has no positional info"""
 
     def __init__(self, config):
         super().__init__()
 
         self.scaling = self.head_dim**-0.5
 
         if (self.head_dim * self.num_heads) != self.embed_dim:
+            raise ValueError
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
 
     def _shape(self, tensor, seq_len, bsz):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
         self,
         hidden_states,
+        layer_head_mask=None,
+        output_attentions=False,
     ):
+
         bsz, tgt_len, _ = hidden_states.size()
 
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.bmm(attn_weights,
                                 value_states)
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
 
         return attn_output, None  # attn_weights_reshaped
 
 class VitsFeedForward(nn.Module):
     def __init__(self, config):
 
         else:
             self.padding = None
 
+    def forward(self, hidden_states):
         hidden_states = hidden_states.permute(0, 2, 1)
         if self.padding is not None:
             hidden_states = nn.functional.pad(hidden_states, self.padding)
         hidden_states = self.conv_1(hidden_states)
         hidden_states = self.act_fn(hidden_states)
         if self.padding is not None:
             hidden_states = nn.functional.pad(hidden_states, self.padding)
         hidden_states = self.conv_2(hidden_states)
         hidden_states = hidden_states.permute(0, 2, 1)
         return hidden_states
 
     def __init__(self, config):
         super().__init__()
         self.attention = VitsAttention(config)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.feed_forward = VitsFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def forward(
         self,
+        hidden_states,
+        output_attentions=False,
     ):
         residual = hidden_states
         hidden_states, attn_weights = self.attention(
             hidden_states=hidden_states,
+            # attention_mask=attention_mask,
             output_attentions=output_attentions,
         )
 
         hidden_states = self.layer_norm(residual + hidden_states)
 
         residual = hidden_states
+        hidden_states = self.feed_forward(hidden_states)
+
         hidden_states = self.final_layer_norm(residual + hidden_states)
 
         outputs = (hidden_states,)
 
         return outputs
 
     def forward(
         self,
+        hidden_states,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
     ):
+        for _layer in self.layers:
+            layer_outputs = _layer(hidden_states)
             hidden_states = layer_outputs[0]
         return BaseModelOutput(
             last_hidden_state=hidden_states,
+            # hidden_states=all_hidden_states,
+            # attentions=all_self_attentions,
         )
 
 class VitsTextEncoder(nn.Module):
     """
+    Has VitsEncoder
     """
 
     def __init__(self, config):
 
         self.encoder = VitsEncoder(config)  # 6 layers of VitsAttention
         self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
 
+    def forward(self, input_ids):
+        hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
+        last_hidden_state = self.encoder(hidden_states=hidden_states).last_hidden_state
+
+        stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2)
         prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
 
         return VitsTextEncoderOutput(
             last_hidden_state=last_hidden_state,
             prior_means=prior_means,
+            # prior_log_variances=prior_log_variances,
+            # hidden_states=encoder_outputs.hidden_states,
+            # attentions=encoder_outputs.attentions,
         )
 
 class VitsPreTrainedModel(PreTrainedModel):
     config_class = VitsConfig
     base_model_prefix = "vits"
     main_input_name = "input_ids"
     supports_gradient_checkpointing = True
 
 
 class VitsModel(VitsPreTrainedModel):
 
         self.text_encoder = VitsTextEncoder(config)  # has VitsEncoder that includes 6 layers of VitsAttention
         self.flow = VitsResidualCouplingBlock(config)
         self.decoder = VitsHifiGan(config)
         # Initialize weights and apply final processing
         self.post_init()
 
     def forward(
         self,
         input_ids=None,
 
         output_hidden_states=None,
         return_dict=None,
         labels=None,
+        speed=None,
+        lang_code='deu',  # speed oscillation pattern per voice/lang
     ):
         mask_dtype = self.text_encoder.embed_tokens.weight.dtype
         if attention_mask is not None:
             input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
         else:
             input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
+        out = self.text_encoder(input_ids=input_ids)
+        hidden_states = out.last_hidden_state.transpose(1, 2)
         input_padding_mask = input_padding_mask.transpose(1, 2)
+        prior_means = out.prior_means
+        bs, _, in_len = hidden_states.shape
+        # VITS duration oscillation
+        if lang_code == 'deu':
+            pattern = [1, 2, 1]  # each voice (lang_code) sounds cooler with a different pattern
+        elif lang_code == 'rmc-script_latin':
+            pattern = [2, 2, 1, 2, 2]  # [2, 2, 2, 1, 2]
+        elif lang_code == 'hun':
+            # pattern = [1, 2, 2, 1, 1, 1]  # sounds cool / has valley-pause
+            pattern = [1, 2, 1, 1, 1]
         else:
+            pattern = [1, 2, 1]
+        duration = torch.tensor(pattern, device=hidden_states.device).repeat(int(in_len / len(pattern)) + 2)[None, None, :in_len]  # perhaps define [1, 2, 1] per voice or language
+        duration[:, :, 0] = 4
+        duration[:, :, -1] = 3
+        # ATTN
         predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
         indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
         output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
         output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
         attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
         batch_size, _, output_length, input_length = attn_mask.shape
         cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
 
         valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
         padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
         attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
+        attn = attn[:, 0, :, :]
 
+        attn = attn + 1e-4 * torch.rand_like(attn)
+        attn /= attn.sum(2, keepdims=True)
+        # print(attn)
+        prior_means = torch.matmul(attn, prior_means)  # try attn to contain .5/.5 instead of 1/0 so it smoothly interpolates repeated prior_means
 
+        # prior_means = F.interpolate(prior_means.transpose(1, 2), int(1.74 * prior_means.shape[1]), mode='linear').transpose(1, 2)  # extend for slow speed
 
+        # prior means have now been replicated x duration of each prior mean
+        latents = self.flow(prior_means.transpose(1, 2),  # + torch.randn_like(prior_means) * .94,
+                            reverse=True)
+        waveform = self.decoder(latents)  # [bs, 1, 16000]
+        return waveform[:, 0, :]
 
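This forward pass is the heart of the commit: instead of running MMS-VITS's learned duration predictor, every token gets a fixed, repeating duration from a per-language pattern, and the attention matrix built from the cumulative durations hard-aligns tokens to output frames (the `1e-4` noise then blurs the 0/1 alignment slightly so repeated prior means interpolate). A standalone sketch of the same expansion using `repeat_interleave`, which is equivalent to the hard alignment before the noise is added (function names here are illustrative):

```python
import torch

def oscillating_durations(n_tokens: int, pattern=(1, 2, 1)) -> torch.Tensor:
    # Tile a fixed duration pattern over the token axis, as in VitsModel.forward above.
    d = torch.tensor(pattern).repeat(n_tokens // len(pattern) + 2)[:n_tokens].clone()
    d[0], d[-1] = 4, 3  # longer first/last token, matching the hard-coded boundary durations
    return d

def expand_by_duration(prior_means: torch.Tensor, durations: torch.Tensor) -> torch.Tensor:
    # Repeat each token's prior mean `duration` times -> frame-level priors [B, sum(d), C].
    return torch.repeat_interleave(prior_means, durations, dim=1)

means = torch.randn(1, 10, 192)  # [B, tokens, flow_size]
frames = expand_by_duration(means, oscillating_durations(10))
print(frames.shape)              # torch.Size([1, sum(durations), 192])
```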
 class VitsTokenizer(PreTrainedTokenizer):
+    vocab_files_names = {"vocab_file": "vocab.json"}
     model_input_names = ["input_ids", "attention_mask"]
 
     def __init__(
 
         return text
 
     def prepare_for_tokenization(
+            self, text: str, is_split_into_words: bool = False, normalize=None, **kwargs):
         normalize = normalize if normalize is not None else self.normalize
 
         if normalize:
 
         tokens = list(text)
 
         if self.add_blank:
+            # sounds dyslexic if no space between letters
+            # sounds disconnected if >2 spaces between letters
+            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2)  # (* 2 + 1) raises a slice index error if tokens is odd
+            interspersed[::2] = tokens
+            tokens = interspersed + [self._convert_id_to_token(0)]  # append one last blank (the ::2 slice mismatches if tokens is odd)
 
         return tokens
 
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         return self.encoder.get(token, self.encoder.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index)
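The `add_blank` branch above intersperses the blank token (id 0) between characters, which VITS/MMS checkpoints expect; per the comments, no blank sounds "dyslexic" and more than two sound disconnected. The same slicing trick on plain strings:

```python
def intersperse_blanks(tokens, blank="_"):
    # Even slots receive the tokens, odd slots stay blank, plus one trailing
    # blank, mirroring prepare_for_tokenization above (avoids the odd-length
    # slice mismatch of a (len * 2 + 1)-sized buffer).
    out = [blank] * (len(tokens) * 2)
    out[::2] = tokens
    return out + [blank]

print(intersperse_blanks(list("abc")))  # ['a', '_', 'b', '_', 'c', '_', '_']
```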
README.md
CHANGED
@@ -131,7 +131,7 @@ python live_demo.py # type text & plays AudioGen sound & TTS
 
 # Audiobook
 
-Create audiobook from `.docx`. Listen to it - YouTube [male voice](https://www.youtube.com/watch?v=5-cpf7u18JE) / [v2](https://www.youtube.com/watch?v=Pzo-kKaNg6s) / [v2.1](https://www.youtube.com/watch?v=X4qlKBBaegM) / [no diffusio](https://www.youtube.com/watch?v=vahKXpd6oLg)
+Create audiobook from `.docx`. Listen to it - YouTube [male voice](https://www.youtube.com/watch?v=5-cpf7u18JE) / [v2](https://www.youtube.com/watch?v=Pzo-kKaNg6s) / [v2.1](https://www.youtube.com/watch?v=X4qlKBBaegM) / [no diffusion](https://www.youtube.com/watch?v=vahKXpd6oLg) / [Audionar](https://youtu.be/fUGpfq_o_CU) / [F](https://www.youtube.com/watch?v=tlRdRV5nm40)
 
 ```python
 # audiobook will be saved in ./tts_audiobooks
Utils/JDC/__init__.py
DELETED
@@ -1 +0,0 @@

Utils/JDC/bst.pth
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
size 21029926

Utils/JDC/model.py
DELETED
@@ -1,190 +0,0 @@
"""
Implementation of model from:
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
Convolutional Recurrent Neural Networks" (2019)
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
"""
import torch
from torch import nn

class JDCNet(nn.Module):
    """
    Joint Detection and Classification Network model for singing voice melody.
    """
    def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
        super().__init__()
        self.num_class = num_class

        # input = (b, 1, 31, 513), b = batch size
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False),  # out: (b, 64, 31, 513)
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
        )

        # res blocks
        self.res_block1 = ResBlock(in_channels=64, out_channels=128)   # (b, 128, 31, 128)
        self.res_block2 = ResBlock(in_channels=128, out_channels=192)  # (b, 192, 31, 32)
        self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)

        # pool block
        self.pool_block = nn.Sequential(
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
            nn.Dropout(p=0.2),
        )

        # maxpool layers (for auxiliary network inputs)
        # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
        # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
        # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))

        # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
        self.detector_conv = nn.Sequential(
            nn.Conv2d(640, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Dropout(p=0.2),
        )

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        self.bilstm_classifier = nn.LSTM(
            input_size=512, hidden_size=256,
            batch_first=True, bidirectional=True)  # (b, 31, 512)

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        self.bilstm_detector = nn.LSTM(
            input_size=512, hidden_size=256,
            batch_first=True, bidirectional=True)  # (b, 31, 512)

        # input: (b * 31, 512)
        self.classifier = nn.Linear(in_features=512, out_features=self.num_class)  # (b * 31, num_class)

        # input: (b * 31, 512)
        self.detector = nn.Linear(in_features=512, out_features=2)  # (b * 31, 2) - binary classifier

        # initialize weights
        self.apply(self.init_weights)

    def get_feature_GAN(self, x):
        seq_len = x.shape[-2]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return poolblock_out.transpose(-1, -2)

    def get_feature(self, x):
        seq_len = x.shape[-2]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return self.pool_block[2](poolblock_out)

    def forward(self, x):
        """
        Returns:
            classification_prediction, detection_prediction
            sizes: (b, 31, 722), (b, 31, 2)
        """
        ###############################
        # forward pass for classifier #
        ###############################
        seq_len = x.shape[-1]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)

        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)
        GAN_feature = poolblock_out.transpose(-1, -2)
        poolblock_out = self.pool_block[2](poolblock_out)

        # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
        classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
        classifier_out, _ = self.bilstm_classifier(classifier_out)  # ignore the hidden states

        classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
        classifier_out = self.classifier(classifier_out)
        classifier_out = classifier_out.view((-1, seq_len, self.num_class))  # (b, 31, num_class)

        # sizes: (b, 31, 722), (b, 31, 2)
        # classifier output consists of predicted pitch classes per frame
        # detector output consists of: (isvoice, notvoice) estimates per frame
        return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
            for p in m.parameters():
                if p.data is None:
                    continue

                if len(p.shape) >= 2:
                    nn.init.orthogonal_(p.data)
                else:
                    nn.init.normal_(p.data)


class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
        super().__init__()
        self.downsample = in_channels != out_channels

        # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
        self.pre_conv = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
        )

        # conv layers
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        )

        # 1 x 1 convolution layer to match the feature dimensions
        self.conv1by1 = None
        if self.downsample:
            self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        x = self.pre_conv(x)
        if self.downsample:
            x = self.conv(x) + self.conv1by1(x)
        else:
            x = self.conv(x) + x
        return x
Utils/PLBERT/util.py
CHANGED
@@ -27,7 +27,7 @@ def load_plbert(log_dir):
 
     iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
     iters = sorted(iters)[-1]
 
-    checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".pth", map_location='cpu')
+    checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".pth", map_location='cpu', weights_only=True)
     state_dict = checkpoint['net']
     from collections import OrderedDict
     new_state_dict = OrderedDict()
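For context on the one-line change above: `torch.load` with the default `weights_only=False` unpickles arbitrary Python objects, so a tampered `.pth` file can execute code at load time; `weights_only=True` restricts deserialization to tensors and plain containers. A hedged sketch (the checkpoint path below is hypothetical):

```python
import torch

# Safe-load a checkpoint dict; this raises if the file pickles anything
# beyond tensors/primitive containers instead of silently executing it.
checkpoint = torch.load("Utils/PLBERT/step_1000000.pth",  # hypothetical path
                        map_location="cpu", weights_only=True)
state_dict = checkpoint["net"]
```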
Utils/text_utils.py
CHANGED
@@ -35,72 +35,112 @@ class TextCleaner:
 
 # == Sentence Splitter
 
-prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
-suffixes = "(Inc|Ltd|Jr|Sr|Co)"
-starters = "(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
-acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
-websites = "[.](com|net|org|io|gov|edu|me)"
-digits = "([0-9])"
-multiple_dots = r'\.{2,}'
-
-def split_into_sentences(text):
-    """
-    ... to incorrect splitting because they are used as markers for splitting.
-
-    :param text: text to be split into sentences
-    :type text: str
-    """
-    ...
+import re
+
+def split_into_sentences(text, max_len=200):
+    """
+    Splits a string into chunks of max_len characters, ensuring each chunk
+    terminates with a period if it was split mid-sentence. Prioritizes
+    splitting at natural sentence breaks and avoids splitting words.
+
+    Args:
+        text (str): The input string.
+        max_len (int): The maximum desired length for each chunk.
+
+    Returns:
+        list: A list of strings, where each string is a sentence chunk.
+    """
+    if not text:
+        return []
+
+    # Regex to split text into potential sentence candidates.
+    # We still use the lookbehind to keep the punctuation with the sentence.
+    sentence_candidates = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
+
+    # Handle the last part if it doesn't end with a punctuation (e.g., a phrase or incomplete sentence)
+    if text and not text.strip().endswith(('.', '!', '?')) and text.strip() not in sentence_candidates:
+        # Check if the last candidate already contains the end of the text.
+        # This is a heuristic, as re.split can sometimes be tricky with trailing non-matches.
+        if not (sentence_candidates and text.strip().endswith(sentence_candidates[-1])):
+            remaining_text = text.strip()
+            if sentence_candidates:
+                # Find the part of the text that wasn't included in sentence_candidates
+                last_candidate_start_index = text.rfind(sentence_candidates[-1])
+                if last_candidate_start_index != -1:
+                    remaining_text = text[last_candidate_start_index + len(sentence_candidates[-1]):].strip()
+
+            if remaining_text and not remaining_text.endswith(('.', '!', '?')):
+                sentence_candidates.append(remaining_text)
+
+    chunks = []
+    current_chunk_elements = []  # Stores individual sentences that form the current chunk
+    current_chunk_length = 0
+
+    for sentence in sentence_candidates:
+        # Calculate the length this sentence would add to the current chunk.
+        # Add 1 for the space that will separate sentences within a chunk, if needed.
+        potential_addition_length = len(sentence) + (1 if current_chunk_elements else 0)
+
+        # Check if adding this sentence would exceed the maximum length
+        if current_chunk_length + potential_addition_length > max_len:
+            # First, finalize the current chunk
+            if current_chunk_elements:
+                final_chunk = " ".join(current_chunk_elements).strip()
+                chunks.append(final_chunk)
+
+            # Reset for the new chunk and handle the current `sentence`.
+            # This `sentence` itself might be longer than `max_len`.
+            remaining_sentence = sentence
+            while len(remaining_sentence) > max_len:
+                # Prioritize splitting at a period or a space to avoid splitting words.
+                # Search backwards from `max_len - 1` to find the last valid break point.
+                split_point = -1
+                search_area = remaining_sentence[:max_len]
+
+                # Option 1: Find the last period in the search area
+                last_period_idx = search_area.rfind('.')
+                if last_period_idx != -1:
+                    split_point = last_period_idx
+
+                # Option 2: If no period, find the last space (to avoid splitting words)
+                if split_point == -1:
+                    last_space_idx = search_area.rfind(' ')
+                    if last_space_idx != -1:
+                        split_point = last_space_idx
+
+                if split_point != -1:
+                    # If a period or space is found, split there.
+                    # If it's a period, include it. If it's a space, don't include the space
+                    # but ensure the chunk ends with a period if it didn't already.
+                    chunk_to_add = remaining_sentence[:split_point + (1 if remaining_sentence[split_point] == '.' else 0)].strip()
+                    if not chunk_to_add.endswith('.'):
+                        chunk_to_add += '.'  # Ensure period termination
+
+                    chunks.append(chunk_to_add)
+                    remaining_sentence = remaining_sentence[split_point + 1:].lstrip()  # Update remaining
+                else:
+                    # No natural break (period or space) within max_len.
+                    # This happens for extremely long words or sequences without spaces.
+                    # In this rare case, we force split at max_len and append a period.
+                    chunks.append(remaining_sentence[:max_len].strip() + '.')
+                    remaining_sentence = remaining_sentence[max_len:].lstrip()  # Update remaining
+
+            # The `remaining_sentence` (now guaranteed to be `<= max_len`)
+            # becomes the start of the new `current_chunk`.
+            current_chunk_elements = [remaining_sentence]
+            current_chunk_length = len(remaining_sentence)
+
+        else:
+            # The current sentence fits within the `max_len`, so add it.
+            current_chunk_elements.append(sentence)
+            current_chunk_length += potential_addition_length
+
+    # After iterating through all sentences, add any remaining elements
+    # in `current_chunk_elements` as the final chunk.
+    if current_chunk_elements:
+        chunks.append(" ".join(current_chunk_elements).strip())
+    return chunks
 
 
 def store_ssml(text=None,
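A quick usage sketch of the new chunker; the exact boundaries depend on `max_len`, but every returned chunk stays at or under `max_len` characters and ends at a sentence or word boundary:

```python
from Utils.text_utils import split_into_sentences

text = ("The museum opens at nine. It closes at five! "
        "Visitors are asked to leave coats at the entrance.")
for chunk in split_into_sentences(text, max_len=60):
    print(len(chunk), repr(chunk))
# -> two chunks: the first merges the two short sentences (44 chars),
#    the second is the remaining 51-char sentence.
```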
api.py
CHANGED
@@ -113,21 +113,11 @@ def _resize(image, width=None, height=None, inter=cv2.INTER_AREA):
 
 def overlay(x, soundscape=None):
     if soundscape is not None:
-        # AudioGen sound is suffice to be ~10s long
         background = sound_generator.generate(soundscape,
-        #
-
-        # len_soundscape = len(background)
-        # fading = .5 + .5 * np.tanh(4*(np.linspace(10, -10, len_soundscape) + 9.4))  # fade heaviside 1,1,1,1,...,0
-        # x = np.concatenate([fading * background, x], 0)  # blend TTS with AudioGen
-        # background /= np.abs(background).max() + 1e-7  # amplify speech to full [-1,1]
-        # background will be longer by xtra .74s
-        x = .47 * x + .46 * background[:len(x)]
-    return x  # TTS / AudioGen @ 16kHz
+                                              duration=len(x)/16000 + .74,  # duration in seconds
+                                              ).detach().cpu().numpy()
+        x = .6 * x + .4 * background[:len(x)]
+    return x
 
 
 def tts_multi_sentence(precomputed_style_vector=None,
@@ -176,7 +166,7 @@ def tts_multi_sentence(precomputed_style_vector=None,
 
     # volume
 
-    x /= np.abs(x).max() + 1e-7  # amplify speech to full [-1,1]
+    x /= 1.12 * np.abs(x).max() + 1e-7  # amplify speech to near [-1, 1]; no amplification / normalisation on soundscapes
 
     return overlay(x, soundscape=soundscape)
@@ -211,7 +201,7 @@ def serve_wav():
 
         _shorten(r.get('native')[0]),
         affective=r.get('affective')[0],
         voice=r.get('voice')[0],
-        speed=
+        speed=None,  # obsolete due to oscillating MMS TTS VITS duration per language
         soundscape=r.get('soundscape')[0] if r.get(
             'soundscape') is not None else None,
     )
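The new `overlay` requests a background slightly longer than the speech (the extra .74 s absorbs AudioGen's generation granularity, per the old comment) and blends with fixed .6/.4 weights. A sketch of just the mixing step on plain arrays:

```python
import numpy as np

def mix(speech: np.ndarray, background: np.ndarray,
        w_speech: float = .6, w_bg: float = .4) -> np.ndarray:
    # Trim the (longer) background to the speech length, then blend linearly.
    # With weights summing to 1 and both signals in [-1, 1], the mix stays in [-1, 1].
    return w_speech * speech + w_bg * background[:len(speech)]

speech = np.random.uniform(-1, 1, 16000).astype(np.float32)        # 1 s @ 16 kHz
background = np.random.uniform(-1, 1, 16000 + 512).astype(np.float32)
out = mix(speech, background)  # still 16000 samples
```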
audiobook.py
CHANGED
@@ -8,20 +8,23 @@ import subprocess
 
 import numpy as np
 import soundfile
 import docx  # package = python-docx
 
 import urllib
 from pathlib import Path
 from moviepy.editor import *
 
-FS =
+FS = 16000
 ROOT_DIR = './tts_audiobooks/voices/'
 Path(ROOT_DIR).mkdir(parents=True,
                      exist_ok=True)
 voices = [
-    # 'en_US/vctk_low#p228',
-    'en_US/vctk_low#p326',
-    #
+    # 'en_US/vctk_low#p228',  # https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#67854dcbd3e6beb1a78f7f20
     # 'af_ZA_google-nwu_0184',  # https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6
+    # 'en_US/vctk_low#p326',
+    # 'en_US/vctk_low#p292',
+    # 'jv_ID_google-gmu_06207',
+    # 'fr_FR_m-ailabs_bernard',
+    'en_US_m-ailabs_mary_ann'
 ]  # select any voice from - https://audeering.github.io/shift/
 
 #urllib.request.urlretrieve("https://github.com/audeering/shift/raw/refs/heads/main/assets/INCLUSION_IN_MUSEUMS_audiobook.docx", "audiobook_TTS.docx")
@@ -54,7 +57,7 @@ for vox in voices:
 
     total = []
     chapter = []
 
     final_paragraph_for_saving_last_chapter = d.paragraphs[-1]
     final_paragraph_for_saving_last_chapter.text = 'CHAPTER: END OF AUDIOBOOK'
@@ -69,12 +72,6 @@ for vox in voices:
 
         if t.startswith('CHAPTER:'):
 
-            # silence for end chapter
-            chapter.append(np.zeros(int(.24 * FS),
-                                    dtype=np.float32))
 
             # chapter.wav
 
             audio = np.concatenate(chapter)
@@ -116,17 +113,14 @@ for vox in voices:
                 [
                     "python",
                     "tts.py",
                     "--text",
                     "_tmp.txt",
-                    # '--image', '_tmp_banner.png',
-                    # '--scene', 'calm sounds of castle',
+                    '--soundscape', 'birds formig' if chapter_counter < 2 else '',
                     '--voice', vox,
                     '--out_file', '_tmp'  # save on _tmp load audio and concat to total
                 ])
 
-            audio, _fs = soundfile.read('out/_tmp.wav')
-            audio = audresample.resample(audio.astype(np.float32), 24000, 16000)[0, :]
+            audio, _fs = soundfile.read('out/_tmp.wav')  # already 16 kHz
             # print('CHAPTER\n\n\n\n____', audio.shape,'____\n')
             chapter.append(audio)
@@ -140,9 +134,6 @@ for vox in voices:
 
         if not last_paragraph_was_silence:  # skip multiple empty paragraphs - silence is added only once
 
-            chapter.append(np.zeros(int(.1 * FS),
-                                    dtype=np.float32))
 
             last_paragraph_was_silence = True
 
     # save full .wav audiobook - for this voice
@@ -157,11 +148,7 @@ for vox in voices:
 
     # pic TTS voice
 
-    voice_pic = np.zeros((
-    shift_logo = cv2.imread('assets/shift_banner.png')
-    voice_pic[:100, :400, :] = shift_logo[:100, :400, :]
+    voice_pic = np.zeros((1920, 1080, 3), dtype=np.uint8)
 
     # voice name
     # frame_tts = np.zeros((104, 1920, 3), dtype=np.uint8)
demo.py
CHANGED
@@ -1,68 +1,40 @@
 import numpy as np
 import soundfile
-import msinference
+import msinference  # api.py has also split into sentences for OOM
 from audiocraft.builders import AudioGen
 
-def tts_entry(text='A quick brown fox jumps over the lazy dog. Sweet dreams are made of this, I traveled the world and the seven seas.',
-              voice='en_US/vctk_low#p326',  # 'en_US/vctk_low#p276', 'deu', 'af_ZA_google-nwu_1919', 'serbian', 'isl',
-              speed=1.14,
-              affect=True,  # False = higher clarity voice
-              soundscape='dogs barg in dungeons n dragons'
-              ):
-    '''16 KHz
-
-    voice : 'en_US/vctk_low#p276'  # Native English voices -> https://audeering.github.io/shift/
-
-    or
-
-    voice : 'af_ZA_google-nwu_1919'  # Non-Native English voices -> https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6
-    '''
+def tts_entry(text='A quick brown fox jumps over the lazy dog. Sweet dreams are made of this, I traveled the world and the seven seas.',
+              voice='en_US/m-ailabs_low#mary_ann',  # 'fr_FR_m-ailabs_bernard', 'deu', 'serbian', 'romanian', 'en_US/vctk_low#p326', 'en_US/vctk_low#p276', 'af_ZA_google-nwu_1919', 'isl',
+              soundscape='birds river'):
+    '''voice = 'en_US/vctk_low#p276'    # Native English voices > https://audeering.github.io/shift/
+             = 'af_ZA_google-nwu_1919'  # Non-native English voices > https://huggingface.co/dkounadis/artificial-styletts2/discussions/1#6783e3b00e7d90facec060c6
+             = 'deu'                    # Other languages > https://huggingface.co/dkounadis/artificial-styletts2/blob/main/Utils/all_langs.csv
+    '''
 
-    # StyleTTS2 - find voice from folder
     if ('en_US/' in voice) or ('en_UK/' in voice):
-        style_vector = msinference.compute_style('assets/wavs/style_vector' + a + '/' + voice.replace(
+        style_vector = msinference.compute_style('assets/wavs/style_vector/' + voice.replace(
             '/', '_').replace('#', '_').replace(
             'cmu-arctic', 'cmu_arctic').replace(
             '_low', '') + '.wav')
 
-        x = msinference.inference(text,
-                                  style_vector)
-
-    # find voice from mimic-3 folder with styles
-
+        x = msinference.inference(text, style_vector)
     elif '_' in voice:
         style_vector = msinference.compute_style('assets/wavs/mimic3_foreign_4x/' + voice.replace(
             '/', '_').replace('#', '_').replace(
             'cmu-arctic', 'cmu_arctic').replace(
             '_low', '') + '.wav')
 
-        x = msinference.inference(text,
-                                  style_vector)
-
-    # Fallback - MMS TTS - Non-English voice / langs
-
+        x = msinference.inference(text, style_vector)
     else:
-        x = msinference.foreign(text=text,
-                                speed=speed)  # volume normalis.
-
-    # volume
-
-    x /= np.abs(x).max() + 1e-7  # amplify speech to full [-1,1]
-
+        x = msinference.foreign(text=text, lang=voice)
+    x /= 1.02 * np.abs(x).max() + 1e-7  # volume amplify to full [-1, 1]
     if soundscape is not None:
         sound_gen = AudioGen().to('cuda:0').eval()
-        background = sound_gen.generate(soundscape,
-                                        duration=len(x)/16000 + .74,  # sound duration in seconds
-                                        ).detach().cpu().numpy()
-        x = .
+        background = sound_gen.generate(soundscape, duration=len(x)/16000 + .74,  # sound duration in seconds
+                                        ).detach().cpu().numpy()
+        x = .6 * x + .4 * background[:len(x)]
     return x
 
 
 soundfile.write(f'demo.wav', tts_entry(), 16000)
live_demo.py
CHANGED
@@ -16,7 +16,6 @@ def send_to_server(args):
 
     'affective': True,
     'image': None,
     'video': None,
-    'speed': 1.14,
     'native': None,
 }
 
@@ -24,16 +23,15 @@ def send_to_server(args):
 
 
 args = SimpleNamespace()
-args.voice = '
-args.speed = 1.14
+args.voice = 'en_US/m-ailabs_low#judy_bieber'
 os.system('cls' if os.name == 'nt' else 'clear')
 while True:
     _str = input("\n\n\n\nDescribe Any Sound: \n\n\n\n")
 
-    _str += 'Lorem ipsum dolor sit amet, consetetur elixir sed diam nonumy eirmod tempor invidunt labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Soutet clita kasd gubergren.'
-
     args.soundscape = _str
+
+    _str += 'A quick brown fox jumps over the lazy dog. Sweet dreams are made of this, I traveled the world and the seven seas.'
+
     args.text = '_tmp.txt'  # input -> .txt (implementation thought for audiobooks in API)
 
     with open(args.text, 'w') as f:
models.py
CHANGED
@@ -1,93 +1,98 @@
|
|
1 |
-
#coding:utf-8
|
2 |
|
3 |
import os
|
4 |
-
import math
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
-
from torch.nn.utils import
|
|
|
9 |
# from Utils.ASR.models import ASRCNN
|
10 |
-
from Utils.JDC.model import JDCNet
|
11 |
from Modules.hifigan import _tile, AdainResBlk1d
|
12 |
-
import
|
13 |
|
|
|
14 |
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
super().__init__()
|
18 |
-
self.
|
19 |
-
|
20 |
-
if self.
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def forward(self, x):
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
-
class
|
36 |
-
def __init__(self,
|
37 |
super().__init__()
|
38 |
-
self.
|
39 |
-
|
|
|
40 |
def forward(self, x):
|
41 |
-
|
42 |
-
return x
|
43 |
-
elif self.layer_type == 'timepreserve':
|
44 |
-
return F.avg_pool2d(x, (2, 1))
|
45 |
-
elif self.layer_type == 'half':
|
46 |
-
if x.shape[-1] % 2 != 0:
|
47 |
-
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
|
48 |
-
return F.avg_pool2d(x, 2)
|
49 |
-
else:
|
50 |
-
raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
|
55 |
|
56 |
class ResBlk(nn.Module):
|
57 |
-
def __init__(self,
|
58 |
-
|
59 |
super().__init__()
|
60 |
-
self.actv =
|
61 |
-
self.
|
62 |
-
self.downsample = DownSample(downsample)
|
63 |
-
self.downsample_res = LearnedDownSample(downsample, dim_in)
|
64 |
self.learned_sc = dim_in != dim_out
|
65 |
-
self._build_weights(dim_in, dim_out)
|
66 |
-
|
67 |
-
def _build_weights(self, dim_in, dim_out):
|
68 |
self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
|
69 |
self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
|
70 |
-
if self.normalize:
|
71 |
-
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
72 |
-
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
73 |
if self.learned_sc:
|
74 |
-
self.conv1x1 = spectral_norm(
|
|
|
75 |
|
76 |
def _shortcut(self, x):
|
77 |
if self.learned_sc:
|
78 |
x = self.conv1x1(x)
|
79 |
-
if
|
80 |
-
x =
|
81 |
-
return x
|
82 |
|
83 |
def _residual(self, x):
|
84 |
-
if self.normalize:
|
85 |
-
x = self.norm1(x)
|
86 |
x = self.actv(x)
|
87 |
x = self.conv1(x)
|
88 |
x = self.downsample_res(x)
|
89 |
-
if self.normalize:
|
90 |
-
x = self.norm2(x)
|
91 |
x = self.actv(x)
|
92 |
x = self.conv2(x)
|
93 |
return x
|
@@ -101,113 +106,41 @@ class StyleEncoder(nn.Module):

    # for both acoustic & prosodic ref_s/p

-    def __init__(self, dim_in=64, style_dim=128, max_conv_dim=512):
        super().__init__()
-        blocks = []
-        blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
-        for _ in range(4):
-            dim_out = min(dim_in*2, max_conv_dim)
-            blocks += [ResBlk(dim_in, dim_out, downsample='half')]
            dim_in = dim_out
-        blocks += [nn.LeakyReLU(0.2)]
-        blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
-        # blocks += [nn.AdaptiveAvgPool2d(1)] # THIS AVERAGES THE TIME-FRAMES OF SPEAKER STYLE
-
-        blocks += [nn.LeakyReLU(0.2)]
        self.shared = nn.Sequential(*blocks)
-
        self.unshared = nn.Linear(dim_out, style_dim)

    def forward(self, x):
-        h = self.shared(x)
-
-        h = h.transpose(1, 3)
-        s = self.unshared(h)
-
        return s


class LinearNorm(torch.nn.Module):
-    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
-        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

-        torch.nn.init.xavier_uniform_(
-            self.linear_layer.weight,
-            gain=torch.nn.init.calculate_gain(w_init_gain))
-
    def forward(self, x):
        return self.linear_layer(x)


-class ResBlk1d(nn.Module):
-    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
-                 normalize=False, downsample='none', dropout_p=0.2):
-        super().__init__()
-        self.actv = actv
-        self.normalize = normalize
-        self.downsample_type = downsample
-        self.learned_sc = dim_in != dim_out
-        self._build_weights(dim_in, dim_out)
-        self.dropout_p = dropout_p
-
-        if self.downsample_type == 'none':
-            self.pool = nn.Identity()
-        else:
-            self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
-
-    def _build_weights(self, dim_in, dim_out):
-        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
-        self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
-        if self.normalize:
-            self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
-            self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
-        if self.learned_sc:
-            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
-
-    def downsample(self, x):
-        if self.downsample_type == 'none':
-            return x
-        else:
-            if x.shape[-1] % 2 != 0:
-                x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
-            return F.avg_pool1d(x, 2)
-
-    def _shortcut(self, x):
-        if self.learned_sc:
-            x = self.conv1x1(x)
-        x = self.downsample(x)
-        return x
-
-    def _residual(self, x):
-        if self.normalize:
-            x = self.norm1(x)
-        x = self.actv(x)
-        x = F.dropout(x, p=self.dropout_p, training=self.training)
-
-        x = self.conv1(x)
-        x = self.pool(x)
-        if self.normalize:
-            x = self.norm2(x)
-
-        x = self.actv(x)
-        x = F.dropout(x, p=self.dropout_p, training=self.training)
-
-        x = self.conv2(x)
-        return x
-
-    def forward(self, x):
-        x = self._shortcut(x) + self._residual(x)
-        return x / math.sqrt(2)  # unit variance
-
-
class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
@@ -222,168 +155,151 @@ class LayerNorm(nn.Module):

        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class TextEncoder(nn.Module):
-    def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)
-
        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
-
-
-        self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)

-    def forward(self, x, input_lengths):
        x = self.embedding(x)  # [B, T, emb]
-        x = x.transpose(1, 2)
        for c in self.cnn:
-            x = c(x)
-        x = x.transpose(1, 2)
-        input_lengths = input_lengths.cpu().numpy()
-        x = nn.utils.rnn.pack_padded_sequence(
-            x, input_lengths,
-            batch_first=True,
-            enforce_sorted=False)
-        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
-        x, _ = nn.utils.rnn.pad_packed_sequence(
-            x, batch_first=True)
-        x = x.transpose(-1, -2)
        return x
-

class AdaLayerNorm(nn.Module):
-
-    # only instantiated in DurationPredictor()
-
    def __init__(self, style_dim, channels=None, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.fc = nn.Linear(style_dim, 1024)

    def forward(self, x, s):
-        h = self.fc(s
        gamma = h[:, :, :512]
        beta = h[:, :, 512:1024]
-
-        x = F.layer_norm(x.transpose(1, 2), (512, ), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x  # [1, 75, 512]


class ProsodyPredictor(nn.Module):

-    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
-        super().__init__()
-
-        self.text_encoder = DurationEncoder(sty_dim=style_dim,
-                                            d_model=d_hid,
-                                            nlayers=nlayers,
-                                            dropout=dropout)

-        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True, dropout=dropout)
        self.duration_proj = LinearNorm(d_hid, max_dur)
-        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True, dropout=dropout)
-
-        self.F0 = nn.ModuleList([
-            AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout),
-            AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout),
-            AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout),
-        ])
-        self.N = nn.ModuleList([
-            AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout),
-            AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout),
-            AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout),
-        ])
        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
-
    def F0Ntrain(self, x, s):

-        x, _ = self.shared(x.transpose(-1, -2))

        x = x.transpose(1, 2)  # [bs, ch, time]
-
        F0 = x
-
        for block in self.F0:
            # print(f'LOOP {F0.shape=} {s.shape=}\n')
            # )N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
-            F0 = block(F0, s)
        F0 = self.F0_proj(F0)
-
        N = x
-
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)
-
        return F0, N


class DurationEncoder(nn.Module):

-    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
-            self.lstms.append(nn.LSTM(d_model + sty_dim,
-                                      d_model // 2,
-                                      num_layers=1,
-                                      batch_first=True,
-                                      bidirectional=True,
-                                      dropout=dropout))
        self.lstms.append(AdaLayerNorm(sty_dim, d_model))
-
-        self.dropout = dropout
-        self.d_model = d_model
-        self.sty_dim = sty_dim

-    def forward(self, x, style, text_lengths):

-        style = _tile(style, length=x.shape[2])  # replicate style vector to duration of txt - F.interpolate or cyclic/tile

-        input_lengths = text_lengths.cpu().numpy()
-
        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
-
-                x = block(x, style)
-                x = torch.cat([x.transpose(1, 2), style], axis=1)  # [bs, 512, 75]

            else:
-
-                x = nn.utils.rnn.pack_padded_sequence(
-                    x, input_lengths, batch_first=True, enforce_sorted=False)
-                block.flatten_parameters()
-                x, _ = block(x)
-                x, _ = nn.utils.rnn.pad_packed_sequence(
-                    x, batch_first=True)
-                x = F.dropout(x, p=self.dropout, training=self.training)
-                x = x.transpose(-1, -2)
-        return x.transpose(-1, -2)

-
-def load_F0_models(path):
-    # load F0 model
-
-    F0_model = JDCNet(num_class=1, seq_len=192)
-    path = path.replace('.t7', '.pth')
-    params = torch.load(path, map_location='cpu')['net']
-    F0_model.load_state_dict(params)
-    _ = F0_model.train()
-
-    return F0_model
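With load_F0_models() and the JDC pitch extractor gone, F0 and energy now come out of the prosody predictor itself. The call path, as it appears on the new side of this diff (see msinference.py below), is:

# F0/N are predicted by ProsodyPredictor rather than extracted by JDCNet:
aln_trg, F0_pred, N_pred = predictor(d_en=d_en, s=ref_s[:, 128:, :])
x = decoder(asr=asr, F0_curve=F0_pred, N=N_pred, s=ref_s[:, :128, :])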
|
|
+# coding:utf-8

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch.nn.utils import spectral_norm
+from torch.nn.utils.parametrizations import weight_norm
# from Utils.ASR.models import ASRCNN
+# from Utils.JDC.model import JDCNet
from Modules.hifigan import _tile, AdainResBlk1d
+import math

+class MelSpec(torch.nn.Module):
+
+    def __init__(self,
+                 sample_rate=17402,  # https://github.com/fakerybakery/styletts2-cli/blob/main/msinference.py = Default 16000. However 17400 vocalises better also "en_US/vctk_p274"
+                 n_fft=2048,
+                 win_length=1200,
+                 hop_length=300,
+                 n_mels=80
+                 ):
+        '''avoids dependency on torchaudio'''
        super().__init__()
+        self.n_fft = n_fft
+        self.win_length = win_length if win_length is not None else n_fft
+        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+        # --
+        f_min = 0.0
+        f_max = float(sample_rate // 2)
+        all_freqs = torch.linspace(0, sample_rate // 2, n_fft // 2 + 1)
+        m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
+        m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
+        m_pts = torch.linspace(m_min, m_max, n_mels + 2)
+        f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
+        f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
+        slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)
+        zero = torch.zeros(1)
+        down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
+        up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
+        fb = torch.max(zero, torch.min(down_slopes, up_slopes))
+        # --
+        self.register_buffer('fb', fb)
+        window = torch.hann_window(self.win_length)
+        self.register_buffer('window', window)

    def forward(self, x):
+        spec_f = torch.stft(x,
+                            self.n_fft,
+                            self.hop_length,
+                            self.win_length,
+                            self.window,
+                            center=True,
+                            pad_mode="reflect",
+                            normalized=False,
+                            onesided=True,
+                            return_complex=True)  # [bs, 1025, 56]
+        mel_specgram = torch.matmul(spec_f.abs().pow(2).transpose(1, 2), self.fb).transpose(1, 2)
+        return mel_specgram[:, None, :, :]  # [bs, 1, 80, time]
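MelSpec rebuilds the HTK-style mel filterbank by hand so the model no longer imports torchaudio. As a sanity check, a sketch assuming torchaudio is installed in a dev environment (power=2, Hann window, reflect padding and HTK mel scale are torchaudio's defaults, so the outputs should coincide):

import torch
import torchaudio

# Dev-only check: MelSpec vs. torchaudio with identical parameters.
mel = MelSpec(sample_rate=17402, n_fft=2048, win_length=1200, hop_length=300, n_mels=80)
ref = torchaudio.transforms.MelSpectrogram(
    sample_rate=17402, n_fft=2048, win_length=1200, hop_length=300, n_mels=80)
x = torch.randn(1, 24000)
print(torch.allclose(mel(x)[:, 0], ref(x), rtol=1e-4, atol=1e-4))  # expect True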
+class LearnedDownSample(nn.Module):
+    def __init__(self, dim_in):
        super().__init__()
+        self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(
+            3, 3), stride=(2, 2), groups=dim_in, padding=1))
+
    def forward(self, x):
+        return self.conv(x)


class ResBlk(nn.Module):
+    def __init__(self,
+                 dim_in, dim_out):
        super().__init__()
+        self.actv = nn.LeakyReLU(0.2)  # .07 also nice
+        self.downsample_res = LearnedDownSample(dim_in)
        self.learned_sc = dim_in != dim_out
        self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
        self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
        if self.learned_sc:
+            self.conv1x1 = spectral_norm(
+                nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
+        if x.shape[3] % 2 != 0:  # [bs, 128, Freq, Time]
+            x = torch.cat([x, x[:, :, :, -1:]], dim=3)
+        return F.interpolate(x, scale_factor=.5, mode='nearest-exact')  # F.avg_pool2d(x, 2)

    def _residual(self, x):
        x = self.actv(x)
        x = self.conv1(x)
        x = self.downsample_res(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x


class StyleEncoder(nn.Module):

    # for both acoustic & prosodic ref_s/p

+    def __init__(self,
+                 dim_in=64,
+                 style_dim=128,
+                 max_conv_dim=512):
        super().__init__()
+        blocks = [spectral_norm(nn.Conv2d(1, dim_in, 3, stride=1, padding=1))]
+        for _ in range(4):
+            dim_out = min(dim_in * 2,
+                          max_conv_dim)
+            blocks += [ResBlk(dim_in, dim_out)]
            dim_in = dim_out
+        blocks += [nn.LeakyReLU(0.24),  # w/o this activation - produces no speech
+                   spectral_norm(nn.Conv2d(dim_out, dim_out, 5, stride=1, padding=0)),
+                   nn.LeakyReLU(0.2)  # 0.3 sounds nice
+                   ]
        self.shared = nn.Sequential(*blocks)
        self.unshared = nn.Linear(dim_out, style_dim)

    def forward(self, x):
+        x = self.shared(x)
+        x = x.mean(3, keepdims=True)  # comment this line for time varying style vector
+        x = x.transpose(1, 3)
+        s = self.unshared(x)
        return s
class LinearNorm(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, bias=True):
+        super().__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, x):
        return self.linear_layer(x)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()

        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)

+
class TextEncoder(nn.Module):
+    def __init__(self, channels, kernel_size, depth, n_symbols):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)
        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
+                nn.LeakyReLU(0.24))
+            )
+        self.lstm = nn.LSTM(channels, channels // 2, 1,
+                            batch_first=True, bidirectional=True)

+    def forward(self, x):
        x = self.embedding(x)  # [B, T, emb]
+        x = x.transpose(1, 2)
        for c in self.cnn:
+            x = c(x)
+        x = x.transpose(1, 2)
        x, _ = self.lstm(x)
        return x
+
+
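The packed-sequence plumbing is gone because inference now runs a single unpadded utterance. A quick shape walk-through (illustrative values only; text_encoder is the instance built in msinference.py below, with channels=512):

import torch

tokens = torch.randint(0, 100, (1, 30))   # [bs=1, 30 symbols]; vocab size arbitrary here
h = text_encoder(tokens)                  # conv stack + bi-LSTM -> [1, 30, 512]
print(h.shape)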
184 |
class AdaLayerNorm(nn.Module):
|
185 |
+
|
|
|
|
|
186 |
def __init__(self, style_dim, channels=None, eps=1e-5):
|
187 |
super().__init__()
|
188 |
self.eps = eps
|
189 |
self.fc = nn.Linear(style_dim, 1024)
|
190 |
|
191 |
def forward(self, x, s):
|
192 |
+
h = self.fc(s)
|
193 |
gamma = h[:, :, :512]
|
194 |
beta = h[:, :, 512:1024]
|
195 |
+
x = F.layer_norm(x, (512, ), eps=self.eps)
|
|
|
196 |
x = (1 + gamma) * x + beta
|
197 |
return x # [1, 75, 512]
|
198 |
|
199 |
+
|
200 |
class ProsodyPredictor(nn.Module):
|
201 |
|
202 |
+
def __init__(self, style_dim, d_hid, nlayers, max_dur=50):
|
203 |
+
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
+
self.text_encoder = DurationEncoder(sty_dim=style_dim,
|
206 |
+
d_model=d_hid,
|
207 |
+
nlayers=nlayers) # called outside forward
|
208 |
+
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2,
|
209 |
+
1, batch_first=True, bidirectional=True)
|
210 |
self.duration_proj = LinearNorm(d_hid, max_dur)
|
211 |
+
self.shared = nn.LSTM(d_hid + style_dim, d_hid //
|
212 |
+
2, 1, batch_first=True, bidirectional=True)
|
213 |
+
self.F0 = nn.ModuleList([
|
214 |
+
AdainResBlk1d(d_hid, d_hid, style_dim),
|
215 |
+
AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True),
|
216 |
+
AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim),
|
217 |
+
])
|
218 |
+
self.N = nn.ModuleList([
|
219 |
+
AdainResBlk1d(d_hid, d_hid, style_dim),
|
220 |
+
AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True),
|
221 |
+
AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim)
|
222 |
+
])
|
223 |
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
224 |
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
225 |
+
|
226 |
def F0Ntrain(self, x, s):
|
227 |
|
228 |
+
x, _ = self.shared(x) # [bs, time, ch] LSTM
|
229 |
|
230 |
x = x.transpose(1, 2) # [bs, ch, time]
|
231 |
+
|
|
|
232 |
F0 = x
|
233 |
+
|
234 |
for block in self.F0:
|
235 |
# print(f'LOOP {F0.shape=} {s.shape=}\n')
|
236 |
# )N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
|
237 |
+
# This is an AdainResBlk1d expects conv1d dimensions
|
238 |
+
F0 = block(F0, s)
|
239 |
F0 = self.F0_proj(F0)
|
240 |
+
|
241 |
N = x
|
242 |
+
|
243 |
for block in self.N:
|
244 |
N = block(N, s)
|
245 |
N = self.N_proj(N)
|
246 |
+
|
247 |
return F0, N
|
248 |
+
|
249 |
+
def forward(self, d_en=None, s=None):
|
250 |
+
blend = self.text_encoder(d_en, s)
|
251 |
+
x, _ = self.lstm(blend)
|
252 |
+
dur = self.duration_proj(x) # [bs, 150, 50]
|
253 |
+
|
254 |
+
_, input_length, classifier_50 = dur.shape
|
255 |
+
|
256 |
+
dur = dur[0, :, :]
|
257 |
+
dur = torch.sigmoid(dur).sum(1)
|
258 |
+
dur = dur.round().clamp(min=1).to(torch.int64)
|
259 |
+
aln_trg = torch.zeros(1,
|
260 |
+
dur.sum(),
|
261 |
+
input_length,
|
262 |
+
device=s.device)
|
263 |
+
c_frame = 0
|
264 |
+
for i in range(input_length):
|
265 |
+
aln_trg[:, c_frame:c_frame + dur[i], i] = 1
|
266 |
+
c_frame += dur[i]
|
267 |
+
en = torch.bmm(aln_trg, blend)
|
268 |
+
F0_pred, N_pred = self.F0Ntrain(en, s)
|
269 |
+
return aln_trg, F0_pred, N_pred
|
270 |
+
|
271 |
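forward() turns the 50-bin duration logits into integer frame counts (sigmoid, sum, round, clamp) and expands them into a hard 0/1 alignment. A toy version of that expansion, detached from the model:

import torch

# Three tokens lasting 2, 1 and 3 frames -> a [1, 6, 3] alignment matrix.
dur = torch.tensor([2, 1, 3])
aln = torch.zeros(1, int(dur.sum()), len(dur))
c_frame = 0
for i, d in enumerate(dur.tolist()):
    aln[:, c_frame:c_frame + d, i] = 1    # token i owns d consecutive frames
    c_frame += d
blend = torch.randn(1, 3, 8)              # [bs, tokens, channels]
en = torch.bmm(aln, blend)                # [1, 6, 8]: each token row repeated dur[i] times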
class DurationEncoder(nn.Module):

+    def __init__(self, sty_dim=128, d_model=512, nlayers=3):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
+            self.lstms.append(nn.LSTM(d_model + sty_dim,
+                                      d_model // 2,
+                                      num_layers=1,
+                                      batch_first=True,
+                                      bidirectional=True
+                                      ))
        self.lstms.append(AdaLayerNorm(sty_dim, d_model))

+    def forward(self, x, style):

+        _, _, input_lengths = x.shape  # [bs, 512, time]

+        style = _tile(style, length=x.shape[2]).transpose(1, 2)
+        x = x.transpose(1, 2)

        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
+                x = block(x, style)  # LSTM has transposed x
            else:
+                x = torch.cat([x, style], axis=2)
+                # LSTM
+                x, _ = block(x)  # expects [bs, time, chan] OUTPUTS [bs, time, 2*chan] 2x FROM BIDIRECTIONAL
+
+        return torch.cat([x, style], axis=2)  # predictor.lstm()
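DurationEncoder leans on _tile from Modules.hifigan to stretch the style code across the token axis. The real implementation lives in hifigan.py; a minimal stand-in with the same contract (an assumption, for illustration only) would be:

import torch

def _tile(style, length):
    # Hypothetical stand-in for Modules.hifigan._tile: cycle a [bs, ch, t]
    # style code along its last axis until it spans `length` frames.
    reps = -(-length // style.shape[2])             # ceil(length / t)
    return style.repeat(1, 1, reps)[:, :, :length]  # [bs, ch, length]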
msinference.py
CHANGED
@@ -3,25 +3,28 @@ import sys

import tempfile
import re
import os
-from num2words import num2words
from collections import OrderedDict
from Modules.hifigan import Decoder
from Utils.PLBERT.util import load_plbert
import phonemizer
import torch
from cached_path import cached_path
-
import audresample
-
import numpy as np
import yaml
-import torchaudio
import librosa
-from models import ProsodyPredictor, TextEncoder, StyleEncoder, load_F0_models
from nltk.tokenize import word_tokenize
from Utils.text_utils import transliterate_number
import textwrap
-

@@ -56,45 +59,33 @@ class TextCleaner:

textclenaer = TextCleaner()

-
-to_mel = torchaudio.transforms.MelSpectrogram(
-    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
-mean, std = -4, 4
-
-
def alpha_num(f):
    f = re.sub(' +', ' ', f)  # delete spaces
    f = re.sub(r'[^A-Z a-z0-9 ]+', '', f)  # del non alpha num
    return f

-
-def preprocess(wave):
-    wave_tensor = torch.from_numpy(wave).float()
-    mel_tensor = to_mel(wave_tensor)
-    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
-    return mel_tensor
-

def compute_style(path):
-
    if sr != 24000:
-
-
    with torch.no_grad():
-
-
-        s = s[:, :, 0, :].transpose(1, 2)  # [1, 128, 11]
-        return s  # [1, 128, 11]

-
-device = 'cuda'

global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us', preserve_punctuation=True, with_stress=True)

@@ -104,10 +95,6 @@ global_phonemizer = phonemizer.backend.EspeakBackend(

args = yaml.safe_load(open(str('Utils/config.yml')))
ASR_config = args['ASR_config']

-F0_path = args['F0_path']
-pitch_extractor = load_F0_models(F0_path).eval().to(device)
-
-
bert = load_plbert(args['PLBERT_dir']).eval().to(device)

decoder = Decoder(dim_in=512,

@@ -128,8 +115,7 @@ text_encoder = TextEncoder(channels=512,

predictor = ProsodyPredictor(style_dim=128,
                             d_hid=512,
                             nlayers=3,  # OFFICIAL config.nlayers=5;
-                             max_dur=50,
-                             dropout=.2).eval().to(device)

style_encoder = StyleEncoder(dim_in=64,
                             style_dim=128,

@@ -141,9 +127,10 @@ bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)

# params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
params_whole = torch.load(str(cached_path(
-    "hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
params = params_whole['net']
-

def _del_prefix(d):
    # del ".module"

@@ -163,95 +150,41 @@ predictor_encoder.load_state_dict(_del_prefix(

    params['predictor_encoder']), strict=True)
style_encoder.load_state_dict(_del_prefix(
    params['style_encoder']), strict=True)
-pitch_extractor.load_state_dict(_del_prefix(
-    params['pitch_extractor']), strict=True)
-
-# def _shift(x):
-#     # [bs, samples] shift circular each batch elem of sound
-#     n = x.shape[1]
-#     for i, batch_elem in enumerate(x):
-#         offset = np.random.randint(.24 * n, max(1, .74 * n))  # high should be above >= 0 TBD
-#         x[i, ...] = torch.roll(batch_elem, offset, dims=1)  # batch_elem = [400000, ]
-#     return x
-

def inference(text,
-              ref_s
-
-    text = transliterate_number(text, lang='en').strip()
-
    ps = global_phonemizer.phonemize([text])
-    # print(f'PHONEMIZER: {ps=}\n\n')  # PHONEMIZER: ps=['ɐbˈɛbæbləm ']
    ps = word_tokenize(ps[0])
-    # # print(f'TOKENIZER: {ps=}\n\n')  # TOKENIZER: ps=['ɐbˈɛbæbləm']
    ps = ' '.join(ps)
    tokens = textclenaer(ps)
-    # print(f'TEXTCLEAN: {ps=}\n\n')  # TEXTCLEAN: ps='ɐbˈɛbæbləm'
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
-    # print(f'TOKENSFINAL: {ps=}\n\n')
-
    with torch.no_grad():
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-
-        hidden_states = text_encoder(tokens, input_lengths)
-
-        bert_dur = bert(tokens, attention_mask=None)
        d_en = bert_encoder(bert_dur).transpose(-1, -2)
-        ref = ref_s[:, :128, :]  # [bs, 128, 11]
-        s = ref_s[:, 128:, :]
-        d = predictor.text_encoder(d_en, s, input_lengths)
-        d = d.transpose(1, 2)
-        # -------------------------------- pred_aln_trg = clones bert frames as duration
-
-        d = predictor.text_encoder(d_en,
-                                   s,
-                                   input_lengths)
-
-        x, _ = predictor.lstm(d)
-
-        duration = predictor.duration_proj(x)
-
-        duration = torch.sigmoid(duration).sum(axis=-1)
-        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

-        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
-            c_frame += int(pred_dur[i].data)

-
-        asr_new = torch.zeros_like(asr)
-        asr_new[:, :, 0] = asr[:, :, 0]
-        asr_new[:, :, 1:] = asr[:, :, 0:-1]
-        asr = asr_new
-        # -
-
-        x = decoder(asr=asr,
-                    F0_curve=F0_pred,
-                    N=N_pred,
-                    s=ref)
-
-    x = x.cpu().numpy()[0, 0, :-400]  # weird pulse at the end of sentences
-
-    # StyleTTS2 is 24kHz -> Resample to 16kHz of AudioGen / MMS

    if x.shape[0] > 10:
-
        x = audresample.resample(signal=x.astype(np.float32),
                                 original_rate=24000,
-                                 target_rate=16000)[0, :]  # reshapes (64,) -> (1,64)

    else:
        print('\n\n\n\n\nEMPTY TTS\n\n\n\n\n\n', x.shape)

@@ -346,17 +279,14 @@ def foreign(text=None,  # split sentences here so we can prepend a txt for germ

    elif 'rom' in lang:

        lang_code = 'ron'
-        speed = 1.24 if speed is None else speed

-    elif 'ger' in lang:

        lang_code = 'deu'
-        speed = 1.14 if speed is None else speed

    elif 'alban' in lang:

        lang_code = 'sqi'
-        speed = 1.04 if speed is None else speed

    else:

@@ -364,38 +294,38 @@ def foreign(text=None,  # split sentences here so we can prepend a txt for germ

    # load VITS

-    net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
-    tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')

    total_audio = []

    # Split long sentences if deu to control voice switch - for other languages let text no-split
    if not isinstance(text, list):
-
-        # However prosody is nicer on non-split for MMS TTS
-        # prepend txt snippet
-        text = [
-            sub_sent + ' ' for sub_sent in textwrap.wrap(text, 200, break_long_words=0)]
-        # assert that it chooses unique voice
-    else:
-        # allow longer non split text
-        text = [
-            sub_sent + ' ' for sub_sent in textwrap.wrap(text, 640, break_long_words=0)]
-        # for non deu MMS TTS lang.

    for _t in text:

        _t = _t.lower()

-        #
-        print('\n\n\n\nBEF Transliteration', _t, '\n\n\n\n\n')
-        _t = transliterate_number(_t, lang=lang_code)
-        print('AFT nums', _t, '\n____________________________________________')

    if lang_code == 'rmc-script_latin':

@@ -417,7 +347,7 @@ def foreign(text=None,  # split sentences here so we can prepend a txt for germ

    x = net_g(input_ids=inputs.input_ids.to(device),
              attention_mask=inputs.attention_mask.to(device),
-
              )[0, :]

    # crop the 1st audio - is PREFIX text 156000 samples to choose deu voice / VitsAttention()

@@ -428,8 +358,6 @@ def foreign(text=None,  # split sentences here so we can prepend a txt for germ

    x = torch.cat(total_audio).cpu().numpy()

-    x /= np.abs(x).max() + 1e-7
-
-    # print(x.shape, x.min(), x.max(), hps.data.sampling_rate)

    return x  # 16kHz - only resample StyleTTS2 from 24kHz -> 16kHz
|
|
import tempfile
import re
import os
from collections import OrderedDict
from Modules.hifigan import Decoder
from Utils.PLBERT.util import load_plbert
import phonemizer
import torch
from cached_path import cached_path
+import nltk
import audresample
+nltk.download('punkt', download_dir='./')  # comment if downloaded once
+nltk.download('punkt_tab', download_dir='./')
+nltk.data.path.append('.')
import numpy as np
import yaml
import librosa
+from models import ProsodyPredictor, TextEncoder, StyleEncoder, MelSpec
from nltk.tokenize import word_tokenize
from Utils.text_utils import transliterate_number
import textwrap
+
+device = 'cpu'
+if torch.cuda.is_available():
+    device = 'cuda'

_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '

textclenaer = TextCleaner()

def alpha_num(f):
    f = re.sub(' +', ' ', f)  # delete spaces
    f = re.sub(r'[^A-Z a-z0-9 ]+', '', f)  # del non alpha num
    return f

+mel_spec = MelSpec().to(device)

def compute_style(path):
+    x, sr = librosa.load(path, sr=24000)
+    x, _ = librosa.effects.trim(x, top_db=30)
    if sr != 24000:
+        x = librosa.resample(x, sr, 24000)
+
    with torch.no_grad():
+        x = torch.from_numpy(x[None, :]).to(device=device, dtype=torch.float)
+
+        mel_tensor = (torch.log(1e-5 + mel_spec(x)) + 4) / 4  # same as (log(mel) - mean) / std with mean=-4, std=4
+
+        # mel_tensor = preprocess(audio).to(device)

+        ref_s = style_encoder(mel_tensor)
+        ref_p = predictor_encoder(mel_tensor)  # [bs, 11, 1, 128]

+        s = torch.cat([ref_s, ref_p], dim=3)  # [bs, 11, 1, 256]

+        s = s[:, :, 0, :].transpose(1, 2)  # [1, 256, 11]
+    return s  # [1, 256, 11]
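Usage sketch for the rewritten compute_style (the wav path below is a placeholder): it loads, trims and log-mel-normalises a 24 kHz reference, then stacks the acoustic and prosodic encodings along the channel axis.

ref_s = compute_style('voices/some_reference.wav')  # placeholder path
print(ref_s.shape)  # [1, 256, T_ref]: rows 0-127 acoustic, rows 128-255 prosodic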
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us', preserve_punctuation=True, with_stress=True)

args = yaml.safe_load(open(str('Utils/config.yml')))
ASR_config = args['ASR_config']

bert = load_plbert(args['PLBERT_dir']).eval().to(device)

decoder = Decoder(dim_in=512,

predictor = ProsodyPredictor(style_dim=128,
                             d_hid=512,
                             nlayers=3,  # OFFICIAL config.nlayers=5;
+                             max_dur=50).eval().to(device)

style_encoder = StyleEncoder(dim_in=64,
                             style_dim=128,

# params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
params_whole = torch.load(str(cached_path(
+    "hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu', weights_only=True)
params = params_whole['net']
+# params['decoder'].pop('module.generator.m_source.l_linear.weight')
+# params['decoder'].pop('module.generator.m_source.l_linear.bias')  # SourceHNSF

def _del_prefix(d):
    # del ".module"

    params['predictor_encoder']), strict=True)
style_encoder.load_state_dict(_del_prefix(
    params['style_encoder']), strict=True)

def inference(text,
+              ref_s):
+    # text = transliterate_number(text, lang='en').strip()  # transliteration only used for foreign(); perhaps add extra . after ? ;
    ps = global_phonemizer.phonemize([text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)
    tokens = textclenaer(ps)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    with torch.no_grad():
+        hidden_states = text_encoder(tokens)
+        bert_dur = bert(tokens, attention_mask=torch.ones_like(tokens))
        d_en = bert_encoder(bert_dur).transpose(-1, -2)

+        aln_trg, F0_pred, N_pred = predictor(d_en=d_en, s=ref_s[:, 128:, :])

+        asr = torch.bmm(aln_trg, hidden_states)
+        asr = asr.transpose(1, 2)
+        asr = torch.cat([asr[:, :, 0:1], asr[:, :, 0:-1]], 2)
+        x = decoder(asr=asr,              # [1, 512, 201]
+                    F0_curve=F0_pred,     # [1, 1, 402] 2x time
+                    N=N_pred,             # [1, 1, 402] 2x time
+                    s=ref_s[:, :128, :])  # [1, 256, 1]

+    x = x.cpu().numpy()[0, 0, :]
+    x[-400:] = 0  # noisy pulse produced for unterminated sentences, in absence of punctuation (not sure if same behaviour for all voices)

+    # StyleTTS2 is 24kHz -> resample to 16kHz as AudioGen / MMS

    if x.shape[0] > 10:
+
        x = audresample.resample(signal=x.astype(np.float32),
                                 original_rate=24000,
+                                 target_rate=16000)[0, :]  # audresample reshapes (64,) -> (1,64) | volume normalisation applies in api.py:tts_multi_sentence()

    else:
        print('\n\n\n\n\nEMPTY TTS\n\n\n\n\n\n', x.shape)
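Putting the two entry points together (paths and text are placeholders; soundfile is one way to write the array, assuming it is available):

import soundfile as sf

ref_s = compute_style('voices/some_reference.wav')  # placeholder reference wav
wav = inference('A quick synthesis test.', ref_s)   # numpy array at 16 kHz
sf.write('out.wav', wav, 16000)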
|
|
    elif 'rom' in lang:

        lang_code = 'ron'

+    elif 'ger' in lang or 'deu' in lang or 'allem' in lang:

        lang_code = 'deu'

    elif 'alban' in lang:

        lang_code = 'sqi'

    else:

    # load VITS

+    # net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
+    # tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
+    global cached_lang_code, cached_net_g, cached_tokenizer

+    if 'cached_lang_code' not in globals() or cached_lang_code != lang_code:
+        cached_lang_code = lang_code
+        cached_net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
+        cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
+
+    net_g = cached_net_g
+    tokenizer = cached_tokenizer

    total_audio = []

    # Split long sentences if deu to control voice switch - for other languages let text no-split
    if not isinstance(text, list):
+        # split very long sentences
+        text = [sub_sent + ' ' for sub_sent in textwrap.wrap(text, 440, break_long_words=0)]

    for _t in text:

        _t = _t.lower()

+        # NUMBERS

+        try:
+            _t = transliterate_number(_t, lang=lang_code)
+        except NotImplementedError:
+            print(f'Transliterate numbers - NotImplemented for {lang_code=}', _t, '\n____________________________________________')
+
+        # PRONOUNC.

    if lang_code == 'rmc-script_latin':

    x = net_g(input_ids=inputs.input_ids.to(device),
              attention_mask=inputs.attention_mask.to(device),
+              lang_code=lang_code,
              )[0, :]

    # crop the 1st audio - is PREFIX text 156000 samples to choose deu voice / VitsAttention()

    x = torch.cat(total_audio).cpu().numpy()

+    # x /= np.abs(x).max() + 1e-7  ~ volume normalisation @ api.py:tts_multi_sentence() OR demo.py

    return x  # 16kHz - only resample StyleTTS2 from 24kHz -> 16kHz
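An illustrative call of the MMS/VITS path (assuming the lang keyword matched inside the function body; the language string only needs to contain one of the substrings tested above, so 'german' hits the deu branch and the model is cached across calls):

x = foreign(text='Guten Morgen, wie geht es dir?', lang='german')  # 16 kHz numpy array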
requirements.txt
CHANGED
@@ -18,4 +18,4 @@ srt

nltk
phonemizer
docx
-

nltk
phonemizer
docx
+uroman
tts.py
CHANGED
@@ -91,13 +91,13 @@ def command_line_args():

def send_to_server(args):
    url = "http://192.168.88.209:5000"
-
    # Args

    payload = {
        'affective': args.affective,
        'voice': args.voice,
-        'soundscape': args.soundscape,
        'native': args.native,
        'text': args.text,
        'image': args.image,

def send_to_server(args):
    url = "http://192.168.88.209:5000"
+
    # Args

    payload = {
        'affective': args.affective,
        'voice': args.voice,
+        'soundscape': args.soundscape if args.soundscape != '' else None,
        'native': args.native,
        'text': args.text,
        'image': args.image,