import logging

import torch
from trainer.io import load_fsspec

from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

logger = logging.getLogger(__name__)

class HifiDecoder(torch.nn.Module):
    """HiFi-GAN decoder that converts GPT latents into a waveform, conditioned
    on a speaker d-vector computed by a ResNet speaker encoder."""

    def __init__(
        self,
        input_sample_rate=22050,
        output_sample_rate=24000,
        output_hop_length=256,
        ar_mel_length_compression=1024,
        decoder_input_dim=1024,
        resblock_type_decoder="1",
        resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_kernel_sizes_decoder=[3, 7, 11],
        upsample_rates_decoder=[8, 8, 2, 2],
        upsample_initial_channel_decoder=512,
        upsample_kernel_sizes_decoder=[16, 16, 4, 4],
        d_vector_dim=512,
        cond_d_vector_in_each_upsampling_layer=True,
        speaker_encoder_audio_config={
            "fft_size": 512,
            "win_length": 400,
            "hop_length": 160,
            "sample_rate": 16000,
            "preemphasis": 0.97,
            "num_mels": 64,
        },
    ):
        super().__init__()
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length
        self.ar_mel_length_compression = ar_mel_length_compression
        self.speaker_encoder_audio_config = speaker_encoder_audio_config
        # HiFi-GAN generator that decodes the (upsampled) GPT latents into a
        # waveform, conditioned on a speaker d-vector.
        self.waveform_decoder = HifiganGenerator(
            decoder_input_dim,
            1,
            resblock_type_decoder,
            resblock_dilation_sizes_decoder,
            resblock_kernel_sizes_decoder,
            upsample_kernel_sizes_decoder,
            upsample_initial_channel_decoder,
            upsample_rates_decoder,
            inference_padding=0,
            cond_channels=d_vector_dim,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
            cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
        )
        # ResNet speaker encoder that produces the conditioning d-vector from
        # 16 kHz reference audio.
        self.speaker_encoder = ResNetSpeakerEncoder(
            input_dim=64,
            proj_dim=512,
            log_input=True,
            use_torch_spec=True,
            audio_config=speaker_encoder_audio_config,
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, latents, g=None):
        """
        Args:
            latents (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            latents: [B, T, C]
            Tensor: [B, 1, T]
        """
        # Stretch the GPT latents from one vector per `ar_mel_length_compression`
        # audio samples to one vector per vocoder frame (`output_hop_length` samples).
        z = torch.nn.functional.interpolate(
            latents.transpose(1, 2),
            scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
            mode="linear",
        ).squeeze(1)
        # upsample to the right sample rate
        if self.output_sample_rate != self.input_sample_rate:
            z = torch.nn.functional.interpolate(
                z,
                scale_factor=[self.output_sample_rate / self.input_sample_rate],
                mode="linear",
            ).squeeze(0)
        o = self.waveform_decoder(z, g=g)
        return o

    def inference(self, c, g):
        """
        Args:
            c (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            c: [B, T, C]
            Tensor: [B, 1, T]
        """
        return self.forward(c, g=g)

    def load_checkpoint(self, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        # keep only the waveform decoder and speaker encoder weights
        state = state["model"]
        states_keys = list(state.keys())
        for key in states_keys:
            if "waveform_decoder." not in key and "speaker_encoder." not in key:
                del state[key]

        self.load_state_dict(state)
        if eval:
            self.eval()
            assert not self.training
            self.waveform_decoder.remove_weight_norm()
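

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): the batch size, latent
    # length, and the checkpoint path below are assumptions, not values
    # prescribed by this module, and the weights here are random rather than
    # a trained model.
    decoder = HifiDecoder()
    # decoder.load_checkpoint("checkpoints/hifidecoder.pth", eval=True)  # hypothetical path

    # GPT latents: [B, T, decoder_input_dim]; forward() transposes them to
    # [B, C, T] and interpolates to the vocoder frame rate before decoding.
    latents = torch.randn(2, 10, 1024)
    # Speaker conditioning d-vector: [B, d_vector_dim, 1]. In practice this
    # would be computed by `decoder.speaker_encoder` from 16 kHz reference
    # audio instead of sampled randomly.
    g = torch.randn(2, 512, 1)

    with torch.no_grad():
        wav = decoder.inference(latents, g)
    print(wav.shape)  # [B, 1, T_samples]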