Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import argparse | |
| import numpy as np | |
| from scipy.io.wavfile import write | |
| import torchaudio | |
| import utils | |
| from speechsr24k.speechsr import SynthesizerTrn as SpeechSR24 | |
| from speechsr48k.speechsr import SynthesizerTrn as SpeechSR48 | |
# Seed every RNG the script may touch (torch CPU, torch CUDA, NumPy) so that
# repeated runs are reproducible.
seed = 1111
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(seed)
del _seed_fn
def get_param_num(model):
    """Return the total number of scalar elements across all of ``model``'s parameters."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def _resample_to_16k(audio, sample_rate):
    """Resample ``audio`` from ``sample_rate`` to 16 kHz with a Kaiser-windowed sinc filter."""
    # torchaudio renamed the "kaiser_window" method to "sinc_interp_kaiser" (the old
    # name was deprecated in 2.0 and is rejected by later releases). Try the new name
    # first; older releases that only know the old name reject it with ValueError.
    try:
        return torchaudio.functional.resample(
            audio, sample_rate, 16000, resampling_method="sinc_interp_kaiser")
    except ValueError:
        return torchaudio.functional.resample(
            audio, sample_rate, 16000, resampling_method="kaiser_window")


def _model_device(model):
    """Return the device holding ``model``'s weights (CUDA if available when it has none)."""
    first_param = next(model.parameters(), None)
    if first_param is not None:
        return first_param.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _to_int16_pcm(waveform):
    """Peak-normalise ``waveform`` to just below full scale and convert it to int16 samples."""
    peak = torch.abs(waveform).max()
    # An all-zero waveform is left as zeros: dividing by a zero peak would give NaN,
    # and casting NaN to int16 yields garbage samples.
    if peak > 0:
        waveform = waveform / peak * 0.999 * 32767.0
    return waveform.cpu().numpy().astype('int16')


def SuperResoltuion(a, hierspeech):
    """Upsample one speech file with a SpeechSR model and save the result as a WAV.

    The misspelled name is kept because callers use it.

    Args:
        a: Parsed command-line namespace. Reads ``input_speech`` (path of the audio
            to upsample), ``output_dir`` (created if missing) and ``output_sr``
            (48000 writes a 48 kHz file; any other value writes a 24 kHz file).
        hierspeech: The SpeechSR model (e.g. as returned by ``model_load``). It is
            called on a tensor of shape (1, 1, samples) holding 16 kHz audio.

    The output is written to ``<output_dir>/<input basename>.wav`` as 16-bit PCM.
    """
    speechsr = hierspeech
    os.makedirs(a.output_dir, exist_ok=True)

    audio, sample_rate = torchaudio.load(a.input_speech)
    # Only single-channel input is supported: keep the first channel -> (1, samples).
    audio = audio[:1, :]
    if sample_rate != 16000:
        audio = _resample_to_16k(audio, sample_rate)

    file_name = os.path.splitext(os.path.basename(a.input_speech))[0]

    # SpeechSR (16k audio -> 24k or 48k audio). Feed the input on whatever device
    # the model lives on instead of assuming CUDA, so CPU-only hosts also work.
    with torch.no_grad():
        converted_audio = speechsr(audio.unsqueeze(1).to(_model_device(speechsr)))
    converted_audio = _to_int16_pcm(converted_audio.squeeze())

    output_file = os.path.join(a.output_dir, "{}.wav".format(file_name))
    out_rate = 48000 if a.output_sr == 48000 else 24000
    write(output_file, out_rate, converted_audio)
def model_load(a):
    """Build the SpeechSR model selected by ``a.output_sr`` and load its checkpoint.

    ``a.output_sr == 48000`` selects the 48 kHz model (hyper-parameters ``h_sr48``,
    checkpoint ``a.ckpt_sr48``). Any other value selects the 24 kHz model
    (hyper-parameters ``h_sr``, checkpoint ``a.ckpt_sr``). ``h_sr`` and ``h_sr48``
    are module globals set by ``main``.

    Returns:
        The model in eval mode, placed on CUDA when available and on CPU otherwise.
    """
    if a.output_sr == 48000:
        model_cls, hps, ckpt_path = SpeechSR48, h_sr48, a.ckpt_sr48
    else:
        # 24000 Hz
        model_cls, hps, ckpt_path = SpeechSR24, h_sr, a.ckpt_sr

    # Fall back to CPU instead of crashing on GPU-less hosts with a hard-coded .cuda().
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    speechsr = model_cls(hps.data.n_mel_channels,
                         hps.train.segment_size // hps.data.hop_length,
                         **hps.model).to(target)
    # NOTE(review): if utils.load_checkpoint does not map tensors to CPU when loading,
    # a checkpoint saved from CUDA may still require a GPU here. Confirm in utils.
    utils.load_checkpoint(ckpt_path, speechsr, None)
    speechsr.eval()
    return speechsr
def inference(a):
    """Load the SpeechSR model chosen by ``a`` and use it to upsample ``a.input_speech``."""
    SuperResoltuion(a, model_load(a))
def main():
    """Command-line entry point: parse arguments, load hyper-parameters, run inference.

    Sets the module globals ``device``, ``h_sr`` and ``h_sr48`` that the model loader
    reads. Only the hyper-parameters of the model selected by ``--output_sr`` are
    loaded; the other global is set to ``None``.
    """
    print('Initializing Inference Process..')
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_speech', default='example/reference_4.wav',
                        help='audio file to upsample')
    parser.add_argument('--output_dir', default='SR_results',
                        help='directory for the upsampled WAV (created if missing)')
    parser.add_argument('--ckpt_sr', type=str, default='./speechsr24k/G_340000.pth',
                        help='24 kHz SpeechSR checkpoint; config.json must sit next to it')
    parser.add_argument('--ckpt_sr48', type=str, default='./speechsr48k/G_100000.pth',
                        help='48 kHz SpeechSR checkpoint; config.json must sit next to it')
    parser.add_argument('--output_sr', type=float, default=48000,
                        help='48000 selects the 48 kHz model; any other value the 24 kHz one')
    a = parser.parse_args()

    global device, h_sr, h_sr48
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _hparams_next_to(ckpt_path):
        # Each checkpoint's hyper-parameters live in a config.json in the same directory.
        return utils.get_hparams_from_file(
            os.path.join(os.path.split(ckpt_path)[0], 'config.json'))

    # Parse only the config of the model that will be used, so a missing directory
    # for the other sample rate does not abort the run.
    if a.output_sr == 48000:
        h_sr, h_sr48 = None, _hparams_next_to(a.ckpt_sr48)
    else:
        h_sr, h_sr48 = _hparams_next_to(a.ckpt_sr), None

    inference(a)
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()