from vocoder.models.fatchord_version import WaveRNN
from vocoder import hparams as hp
from scipy.fft import rfft, rfftfreq
from scipy import signal
from denoiser.pretrained import master64
import librosa
import numpy as np
import torch
import torchaudio
import noisereduce as nr

_model = None   # type: WaveRNN
_device = None  # set by load_model()

def load_model(weights_fpath, verbose=True):
    global _model, _device

    if verbose:
        print("Building Wave-RNN")
    _model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode,
    )

    if torch.cuda.is_available():
        _model = _model.cuda()
        _device = torch.device('cuda')
    else:
        _device = torch.device('cpu')

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    checkpoint = torch.load(weights_fpath, _device)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()

def is_loaded():
    """Return True once load_model() has been called."""
    return _model is not None
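
# A minimal usage sketch (illustrative only; "saved_models/vocoder.pt" is a
# placeholder path, not one shipped with this module):
def _demo_load():
    if not is_loaded():
        load_model("saved_models/vocoder.pt", verbose=True)
    assert is_loaded()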

def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800,
                   progress_callback=None, crossfade=True):
    """
    Infers the waveform of a mel spectrogram output by the synthesizer (the format must match
    that of the synthesizer!)

    :param mel: mel spectrogram as a numpy array of shape (num_mels, frames)
    :param normalize: whether to scale the spectrogram by hp.mel_max_abs_value first
    :param batched: whether to use the faster batched generation
    :param target: number of samples generated per batched segment
    :param overlap: number of overlapping samples between consecutive segments
    :param progress_callback: optional callable reporting generation progress
    :param crossfade: whether to crossfade the overlapping segments
    :return: the generated (and denoised) waveform as a numpy array
    """
    if _model is None:
        raise Exception("Please load Wave-RNN in memory before using it")

    if normalize:
        mel = mel / hp.mel_max_abs_value
    mel = torch.from_numpy(mel[None, ...])
    wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback,
                          crossfade=crossfade)
    wav = waveform_denoising(wav)
    return wav
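
# Illustrative sketch of a full mel-to-audio pass. Assumptions: `mel` is a numpy
# array of shape (num_mels, frames) in the synthesizer's output scaling, and
# `soundfile` is an extra dependency used here only to write the result.
def _demo_vocode(mel, out_path="vocoded.wav"):
    import soundfile as sf
    wav = infer_waveform(mel, batched=True, target=8000, overlap=800)
    sf.write(out_path, wav, hp.sample_rate)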

def waveform_denoising(wav):
    """Two-stage denoising: Demucs enhancement (facebook denoiser), then spectral gating."""
    # Noise-reduction strength depends on the configured voice (hp.sex)
    prop_decrease = hp.prop_decrease_low_freq if hp.sex else hp.prop_decrease_high_freq
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = master64().to(device)
    noisy = torch.from_numpy(np.array([wav])).to(device).float()
    with torch.no_grad():
        estimate = model(noisy)
    # Dry/wet mix: keep a fraction hp.dry of the original signal to avoid over-smoothing
    estimate = estimate * (1 - hp.dry) + noisy * hp.dry
    estimate = estimate[0].cpu().numpy()
    return nr.reduce_noise(np.squeeze(estimate), hp.sample_rate, prop_decrease=prop_decrease)
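
# Hedged sketch: the two-stage denoiser above can also be run on its own. The
# input path is a placeholder; librosa resamples to hp.sample_rate so the
# spectral-gating stage sees the rate it expects.
def _demo_denoise(in_path="noisy.wav"):
    wav, _ = librosa.load(in_path, sr=hp.sample_rate)
    return waveform_denoising(wav)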

def get_dominant_freq(wav, name="fft"):
    """Return the dominant frequency (at or above 60 Hz) of a waveform, in Hz."""
    N = len(wav)
    fft_mag = np.abs(rfft(wav))  # magnitude spectrum (rfft output is complex)
    fft_freq = rfftfreq(N, 1 / hp.sample_rate)
    fft_least_index = np.where(fft_freq >= 60)[0][0]  # skip bins below 60 Hz
    fft_max_index = fft_least_index + np.argmax(fft_mag[fft_least_index:])
    fft_max_freq = fft_freq[fft_max_index]
    # import matplotlib.pyplot as plt
    # plt.clf()
    # plt.plot(fft_freq, fft_mag)
    # plt.savefig(f"{name}.png", dpi=300)
    return fft_max_freq
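
# Quick sanity check for get_dominant_freq (assumption: a synthetic 440 Hz tone;
# with one second of audio the FFT bin spacing is exactly 1 Hz, so the result
# should be 440.0):
def _demo_dominant_freq():
    t = np.arange(hp.sample_rate) / hp.sample_rate  # one second of samples
    tone = np.sin(2 * np.pi * 440.0 * t)
    return get_dominant_freq(tone)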