Commit 5ec3488
Parent(s): 446d1e4
Upload 36 files
- main.py +37 -0
- packages.txt +2 -0
- requirements.txt +11 -0
- voicefixer/__init__.py +14 -0
- voicefixer/__main__.py +170 -0
- voicefixer/base.py +145 -0
- voicefixer/restorer/__init__.py +44 -0
- voicefixer/restorer/model.py +680 -0
- voicefixer/restorer/model_kqq_bn.py +186 -0
- voicefixer/restorer/modules.py +217 -0
- voicefixer/tools/__init__.py +11 -0
- voicefixer/tools/base.py +244 -0
- voicefixer/tools/io.py +44 -0
- voicefixer/tools/mel_scale.py +238 -0
- voicefixer/tools/modules/__init__.py +11 -0
- voicefixer/tools/modules/fDomainHelper.py +234 -0
- voicefixer/tools/modules/filters/f_2_64.mat +0 -0
- voicefixer/tools/modules/filters/f_4_64.mat +0 -0
- voicefixer/tools/modules/filters/f_8_64.mat +0 -0
- voicefixer/tools/modules/filters/h_2_64.mat +0 -0
- voicefixer/tools/modules/filters/h_4_64.mat +0 -0
- voicefixer/tools/modules/filters/h_8_64.mat +0 -0
- voicefixer/tools/modules/pqmf.py +116 -0
- voicefixer/tools/path.py +13 -0
- voicefixer/tools/pytorch_util.py +180 -0
- voicefixer/tools/random_.py +52 -0
- voicefixer/tools/wav.py +242 -0
- voicefixer/vocoder/__init__.py +30 -0
- voicefixer/vocoder/base.py +86 -0
- voicefixer/vocoder/config.py +316 -0
- voicefixer/vocoder/model/__init__.py +11 -0
- voicefixer/vocoder/model/generator.py +168 -0
- voicefixer/vocoder/model/modules.py +947 -0
- voicefixer/vocoder/model/pqmf.py +61 -0
- voicefixer/vocoder/model/res_msd.py +71 -0
- voicefixer/vocoder/model/util.py +135 -0
    	
main.py ADDED

@@ -0,0 +1,37 @@
+from voicefixer.base import VoiceFixer
+import streamlit as st
+from audio_recorder_streamlit import audio_recorder
+from io import BytesIO
+import soundfile as sf
+
+st.set_page_config(page_title="VoiceFixer app", page_icon=":notes:")
+st.title("Voice Fixer App :notes:")
+st.write(
+    """
+    This app combines the [VoiceFixer model](https://github.com/haoheliu/voicefixer) with a custom
+    Streamlit component that [records audio](https://github.com/Joooohan/audio-recorder-streamlit) online.
+    Currently the app shows great results when removing background noise, but
+    speech improvements aren't as obvious.
+    """)
+# Config files are in voicefixer/base and the voicefixer/vocoder/config import.
+# They were uploaded to the Hugging Face Hub.
+voicefixer = VoiceFixer()
+audio_bytes = audio_recorder(
+    pause_threshold=1.5
+)
+try:
+    data, samplerate = sf.read(BytesIO(audio_bytes))
+    print(samplerate)
+    sf.write("original.wav", data, samplerate)
+    st.audio(audio_bytes, format="audio/wav")
+    if data.shape[0] >= 10000:  # require a minimum number of samples
+        voicefixer.restore(input="original.wav",  # low quality .wav/.flac file
+                           output="enhanced_output.wav",
+                           cuda=False,  # GPU acceleration
+                           mode=0)
+        st.write("The audio without background noises and a little enhancement :ocean:")
+        st.audio("enhanced_output.wav")
+
+    else: st.warning("Recorded audio is too short, try again :relieved:")
+except Exception:
+    st.info("Try to record some audio :relieved:")
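main.py leans on the bare try/except above to cover the case where audio_recorder returns None before anything has been recorded. A minimal standalone sketch of a more explicit guard (not the app's actual code; the file names and restore call simply mirror main.py):

    import streamlit as st
    import soundfile as sf
    from io import BytesIO
    from audio_recorder_streamlit import audio_recorder
    from voicefixer.base import VoiceFixer

    voicefixer = VoiceFixer()
    audio_bytes = audio_recorder(pause_threshold=1.5)
    if audio_bytes is not None:  # audio_recorder yields None until a recording exists
        data, samplerate = sf.read(BytesIO(audio_bytes))
        sf.write("original.wav", data, samplerate)
        if data.shape[0] >= 10000:  # same minimum-length check as above
            voicefixer.restore(input="original.wav", output="enhanced_output.wav",
                               cuda=False, mode=0)
            st.audio("enhanced_output.wav")
        else:
            st.warning("Recorded audio is too short, try again :relieved:")
    else:
        st.info("Try to record some audio :relieved:")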
    	
packages.txt ADDED

@@ -0,0 +1,2 @@
+ffmpeg
+libsndfile1
    	
requirements.txt ADDED

@@ -0,0 +1,11 @@
+audio_recorder_streamlit>=0.0.7
+soundfile>=0.9.0
+huggingface-hub>=0.11.1
+librosa>=0.8.1,<0.9.0
+torch>=1.7.0
+matplotlib
+progressbar
+torchlibrosa==0.0.7
+GitPython
+streamlit>=1.12.0
+pyyaml
    	
voicefixer/__init__.py ADDED

@@ -0,0 +1,14 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+"""
+@File    :   __init__.py
+@Contact :   [email protected]
+@License :   (C)Copyright 2020-2100
+
+@Modify Time      @Author    @Version    @Description
+------------      -------    --------    -----------
+9/14/21 12:31 AM   Haohe Liu      1.0         None
+"""
+
+from voicefixer.vocoder.base import Vocoder
+from voicefixer.base import VoiceFixer
    	
voicefixer/__main__.py ADDED

@@ -0,0 +1,170 @@
+#!/usr/bin/python3
+import os.path
+import argparse
+from voicefixer import VoiceFixer
+import torch
+import os
+
+
+def writefile(infile, outfile, mode, append_mode, cuda, verbose=False):
+    if append_mode is True:
+        outbasename, outext = os.path.splitext(os.path.basename(outfile))
+        outfile = os.path.join(
+            os.path.dirname(outfile), "{}-mode{}{}".format(outbasename, mode, outext)
+        )
+    if verbose:
+        print("Processing {}, mode={}".format(infile, mode))
+    voicefixer.restore(input=infile, output=outfile, cuda=cuda, mode=int(mode))
+
+
+def check_arguments(args):
+    process_file, process_folder = len(args.infile) != 0, len(args.infolder) != 0
+    # assert len(args.infile) == 0 and len(args.outfile) == 0 or process_file, \
+    #         "Error: You should give the input and output file path at the same time. The input and output file path we received is %s and %s" % (args.infile, args.outfile)
+    # assert len(args.infolder) == 0 and len(args.outfolder) == 0 or process_folder, \
+    #         "Error: You should give the input and output folder path at the same time. The input and output folder path we received is %s and %s" % (args.infolder, args.outfolder)
+    assert (
+        process_file or process_folder
+    ), "Error: You need to specify an input file path (--infile) or an input folder path (--infolder) to proceed. For more information please run: voicefixer -h"
+
+    # if(args.cuda and not torch.cuda.is_available()):
+    #     print("Warning: You set --cuda while no cuda device was found on your machine. We will use the CPU instead.")
+    if process_file:
+        assert os.path.exists(args.infile), (
+            "Error: The input file %s is not found." % args.infile
+        )
+        output_dirname = os.path.dirname(args.outfile)
+        if len(output_dirname) > 1:
+            os.makedirs(output_dirname, exist_ok=True)
+    if process_folder:
+        assert os.path.exists(args.infolder), (
+            "Error: The input folder %s is not found." % args.infolder
+        )
+        output_dirname = args.outfolder
+        if len(output_dirname) > 1:
+            os.makedirs(args.outfolder, exist_ok=True)
+
+    return process_file, process_folder
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="VoiceFixer - restores degraded speech"
+    )
+    parser.add_argument(
+        "-i",
+        "--infile",
+        type=str,
+        default="",
+        help="An input file to be processed by VoiceFixer.",
+    )
+    parser.add_argument(
+        "-o",
+        "--outfile",
+        type=str,
+        default="outfile.wav",
+        help="An output file to store the result.",
+    )
+
+    parser.add_argument(
+        "-ifdr",
+        "--infolder",
+        type=str,
+        default="",
+        help="Input folder. Place all the .wav files that need processing in this folder.",
+    )
+    parser.add_argument(
+        "-ofdr",
+        "--outfolder",
+        type=str,
+        default="outfolder",
+        help="Output folder. The processed files will be stored in this folder.",
+    )
+
+    parser.add_argument(
+        "--mode", help="mode", choices=["0", "1", "2", "all"], default="0"
+    )
+    parser.add_argument(
+        "--disable-cuda",
+        help="Set this flag if you do not want to use your GPU.",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--silent",
+        help="Set this flag if you do not want to see any message.",
+        default=False,
+        action="store_true",
+    )
+
+    args = parser.parse_args()
+
+    if torch.cuda.is_available() and not args.disable_cuda:
+        cuda = True
+    else:
+        cuda = False
+
+    process_file, process_folder = check_arguments(args)
+
+    if not args.silent:
+        print("Initializing VoiceFixer")
+    voicefixer = VoiceFixer()
+
+    if not args.silent and process_file:
+        print("Start processing the input file %s." % args.infile)
+
+    if process_file:
+        audioext = os.path.splitext(os.path.basename(args.infile))[-1]
+        if audioext != ".wav":
+            raise ValueError(
+                "Error: We only support the .wav format currently. Please convert your %s file to .wav."
+                % audioext
+            )
+        if args.mode == "all":
+            for file_mode in range(3):
+                writefile(
+                    args.infile,
+                    args.outfile,
+                    file_mode,
+                    True,
+                    cuda,
+                    verbose=not args.silent,
+                )
+        else:
+            writefile(
+                args.infile,
+                args.outfile,
+                args.mode,
+                False,
+                cuda,
+                verbose=not args.silent,
+            )
+
+    if process_folder:
+        # Only process the .wav files found in the input folder.
+        files = [
+            file
+            for file in os.listdir(args.infolder)
+            if (os.path.splitext(os.path.basename(file))[-1] == ".wav")
+        ]
+        if not args.silent:
+            print(
+                "Found %s .wav files in the input folder %s. Start processing."
+                % (len(files), args.infolder)
+            )
+        for file in files:
+            in_file = os.path.join(args.infolder, file)
+            out_file = os.path.join(args.outfolder, file)
+
+            if args.mode == "all":
+                for file_mode in range(3):
+                    writefile(
+                        in_file,
+                        out_file,
+                        file_mode,
+                        True,
+                        cuda,
+                        verbose=not args.silent,
+                    )
+            else:
+                writefile(
+                    in_file, out_file, args.mode, False, cuda, verbose=not args.silent
+                )
+
+    if not args.silent:
+        print("Done")
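For reference, --mode all calls writefile once per mode with append_mode=True, which suffixes -mode{N} onto the output file name. A rough programmatic equivalent, sketched with hypothetical file names and using the same VoiceFixer.restore API the CLI wraps:

    from voicefixer import VoiceFixer

    vf = VoiceFixer()
    for mode in range(3):  # the modes that --mode all iterates over
        vf.restore(input="in.wav", output="in-mode{}.wav".format(mode),
                   cuda=False, mode=mode)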
    	
voicefixer/base.py ADDED

@@ -0,0 +1,145 @@
+import librosa.display
+from voicefixer.tools.pytorch_util import *
+from voicefixer.tools.wav import *
+from voicefixer.restorer.model import VoiceFixer as voicefixer_fe
+import os
+from huggingface_hub import hf_hub_download
+
+path_to_ckpt = hf_hub_download(repo_id="jlmarrugom/voice_fixer", filename="vf.ckpt")
+
+
+EPS = 1e-8
+
+
+class VoiceFixer(nn.Module):
+    def __init__(self):
+        super(VoiceFixer, self).__init__()
+        self._model = voicefixer_fe(channels=2, sample_rate=44100)
+        # print(os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/epoch=15_trimed_bn.ckpt"))
+        self.analysis_module_ckpt = path_to_ckpt  # "models/vf.ckpt"
+        if not os.path.exists(self.analysis_module_ckpt):
+            raise RuntimeError("Error 0: The checkpoint for the analysis module (vf.ckpt) was not found in ~/.cache/voicefixer/analysis_module/checkpoints. \
+                                By default the checkpoint should be downloaded automatically by this program, so something may have gone wrong. \
+                                But don't worry! Alternatively you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/vf.ckpt?download=1.")
+        self._model.load_state_dict(
+            torch.load(
+                self.analysis_module_ckpt
+            )
+        )
+        self._model.eval()
+
+    def _load_wav_energy(self, path, sample_rate, threshold=0.95):
+        wav_10k, _ = librosa.load(path, sr=sample_rate)
+        stft = np.log10(np.abs(librosa.stft(wav_10k)) + 1.0)
+        fbins = stft.shape[0]
+        e_stft = np.sum(stft, axis=1)
+        for i in range(e_stft.shape[0]):
+            e_stft[-i - 1] = np.sum(e_stft[: -i - 1])
+        total = e_stft[-1]
+        for i in range(e_stft.shape[0]):
+            if e_stft[i] < total * threshold:
+                continue
+            else:
+                break
+        return wav_10k, int((sample_rate // 2) * (i / fbins))
+
+    def _load_wav(self, path, sample_rate, threshold=0.95):
+        wav_10k, _ = librosa.load(path, sr=sample_rate)
+        return wav_10k
+
+    def _amp_to_original_f(self, mel_sp_est, mel_sp_target, cutoff=0.2):
+        freq_dim = mel_sp_target.size()[-1]
+        mel_sp_est_low, mel_sp_target_low = (
+            mel_sp_est[..., 5 : int(freq_dim * cutoff)],
+            mel_sp_target[..., 5 : int(freq_dim * cutoff)],
+        )
+        energy_est, energy_target = torch.mean(mel_sp_est_low, dim=(2, 3)), torch.mean(
+            mel_sp_target_low, dim=(2, 3)
+        )
+        amp_ratio = energy_target / energy_est
+        return mel_sp_est * amp_ratio[..., None, None], mel_sp_target
+
+    def _trim_center(self, est, ref):
+        diff = np.abs(est.shape[-1] - ref.shape[-1])
+        if est.shape[-1] == ref.shape[-1]:
+            return est, ref
+        elif est.shape[-1] > ref.shape[-1]:
+            min_len = min(est.shape[-1], ref.shape[-1])
+            est, ref = est[..., int(diff // 2) : -int(diff // 2)], ref
+            est, ref = est[..., :min_len], ref[..., :min_len]
+            return est, ref
+        else:
+            min_len = min(est.shape[-1], ref.shape[-1])
+            est, ref = est, ref[..., int(diff // 2) : -int(diff // 2)]
+            est, ref = est[..., :min_len], ref[..., :min_len]
+            return est, ref
+
+    def _pre(self, model, input, cuda):
+        input = input[None, None, ...]
+        input = torch.tensor(input)
+        input = try_tensor_cuda(input, cuda=cuda)
+        sp, _, _ = model.f_helper.wav_to_spectrogram_phase(input)
+        mel_orig = model.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
+        # return models.to_log(sp), models.to_log(mel_orig)
+        return sp, mel_orig
+
+    def remove_higher_frequency(self, wav, ratio=0.95):
+        stft = librosa.stft(wav)
+        real, img = np.real(stft), np.imag(stft)
+        mag = (real**2 + img**2) ** 0.5
+        cos, sin = real / (mag + EPS), img / (mag + EPS)
+        spec = np.abs(stft)  # [1025, T]
+        feature = spec.copy()
+        feature = np.log10(feature + EPS)
+        feature[feature < 0] = 0
+        energy_level = np.sum(feature, axis=1)
+        threshold = np.sum(energy_level) * ratio
+        current_level, i = energy_level[0], 0
+        while i < energy_level.shape[0] and current_level < threshold:
+            current_level += energy_level[i + 1, ...]
+            i += 1
+        spec[i:, ...] = np.zeros_like(spec[i:, ...])
+        stft = spec * cos + 1j * spec * sin
+        return librosa.istft(stft)
+
+    @torch.no_grad()
+    def restore_inmem(self, wav_10k, cuda=False, mode=0, your_vocoder_func=None):
+        check_cuda_availability(cuda=cuda)
+        self._model = try_tensor_cuda(self._model, cuda=cuda)
+        if mode == 0:
+            self._model.eval()
+        elif mode == 1:
+            self._model.eval()
+        elif mode == 2:
+            self._model.train()  # More effective on seriously damaged speech
+        res = []
+        seg_length = 44100 * 30
+        break_point = seg_length
+        while break_point < wav_10k.shape[0] + seg_length:
+            segment = wav_10k[break_point - seg_length : break_point]
+            if mode == 1:
+                segment = self.remove_higher_frequency(segment)
+            sp, mel_noisy = self._pre(self._model, segment, cuda)
+            out_model = self._model(sp, mel_noisy)
+            denoised_mel = from_log(out_model["mel"])
+            if your_vocoder_func is None:
+                out = self._model.vocoder(denoised_mel, cuda=cuda)
+            else:
+                out = your_vocoder_func(denoised_mel)
+            # unify energy
+            if torch.max(torch.abs(out)) > 1.0:
+                out = out / torch.max(torch.abs(out))
+                print("Warning: energy limit exceeded; normalizing the segment.")
+            # frame alignment
+            out, _ = self._trim_center(out, segment)
+            res.append(out)
+            break_point += seg_length
+        out = torch.cat(res, -1)
+        return tensor2numpy(out.squeeze(0))
+
+    def restore(self, input, output, cuda=False, mode=0, your_vocoder_func=None):
+        wav_10k = self._load_wav(input, sample_rate=44100)
+        out_np_wav = self.restore_inmem(
+            wav_10k, cuda=cuda, mode=mode, your_vocoder_func=your_vocoder_func
+        )
+        save_wave(out_np_wav, fname=output, sample_rate=44100)
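restore_inmem is the in-memory counterpart of restore: it expects a mono waveform already sampled at 44.1 kHz, processes it in 30-second segments, and returns a numpy array. A minimal sketch (the silent input here is purely an illustrative stand-in):

    import numpy as np
    from voicefixer.base import VoiceFixer

    vf = VoiceFixer()
    wav = np.zeros(44100 * 2, dtype=np.float32)  # 2 s at 44.1 kHz as a stand-in signal
    restored = vf.restore_inmem(wav, cuda=False, mode=0)  # numpy waveform out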
    	
voicefixer/restorer/__init__.py ADDED

@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+"""
+@File    :   __init__.py
+@Contact :   [email protected]
+@License :   (C)Copyright 2020-2100
+
+@Modify Time      @Author    @Version    @Description
+------------      -------    --------    -----------
+9/14/21 12:31 AM   Haohe Liu      1.0         None
+"""
+
+import os
+import torch
+import urllib.request
+
+meta = {
+    "voicefixer_fe": {
+        "path": os.path.join(
+            os.path.expanduser("~"),
+            ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt",
+        ),
+        "url": "https://zenodo.org/record/5600188/files/vf.ckpt?download=1",
+    },
+}
+
+if not os.path.exists(meta["voicefixer_fe"]["path"]):
+    os.makedirs(os.path.dirname(meta["voicefixer_fe"]["path"]), exist_ok=True)
+    print("Downloading the main structure of voicefixer")
+
+    urllib.request.urlretrieve(
+        meta["voicefixer_fe"]["url"], meta["voicefixer_fe"]["path"]
+    )
+    print(
+        "Weights downloaded in: {} Size: {}".format(
+            meta["voicefixer_fe"]["path"],
+            os.path.getsize(meta["voicefixer_fe"]["path"]),
+        )
+    )
+
+    # cmd = "wget " + meta["voicefixer_fe"]['url'] + " -O " + meta["voicefixer_fe"]['path']
+    # os.system(cmd)
+    # temp = torch.load(meta["voicefixer_fe"]['path'])
+    # torch.save(temp['state_dict'], os.path.join(os.path.expanduser('~'), ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt"))
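Note that this module fetches vf.ckpt from Zenodo into a local cache at import time, while voicefixer/base.py pulls the same checkpoint from the Hugging Face Hub. A small sketch for inspecting the cached copy (the path simply mirrors the meta dict above):

    import os

    ckpt = os.path.join(os.path.expanduser("~"),
                        ".cache/voicefixer/analysis_module/checkpoints/vf.ckpt")
    if os.path.exists(ckpt):
        print("cached:", ckpt, os.path.getsize(ckpt), "bytes")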
    	
        voicefixer/restorer/model.py
    ADDED
    
    | @@ -0,0 +1,680 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # import pytorch_lightning as pl
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch.utils
         | 
| 4 | 
            +
            from voicefixer.tools.mel_scale import MelScale
         | 
| 5 | 
            +
            import torch.utils.data
         | 
| 6 | 
            +
            import matplotlib.pyplot as plt
         | 
| 7 | 
            +
            import librosa.display
         | 
| 8 | 
            +
            from voicefixer.vocoder.base import Vocoder
         | 
| 9 | 
            +
            from voicefixer.tools.pytorch_util import *
         | 
| 10 | 
            +
            from voicefixer.restorer.model_kqq_bn import UNetResComplex_100Mb
         | 
| 11 | 
            +
            from voicefixer.tools.random_ import *
         | 
| 12 | 
            +
            from voicefixer.tools.wav import *
         | 
| 13 | 
            +
            from voicefixer.tools.modules.fDomainHelper import FDomainHelper
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from voicefixer.tools.io import load_json, write_json
         | 
| 16 | 
            +
            from matplotlib import cm
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
         | 
| 19 | 
            +
            EPS = 1e-8
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class BN_GRU(torch.nn.Module):
         | 
| 23 | 
            +
                def __init__(
         | 
| 24 | 
            +
                    self,
         | 
| 25 | 
            +
                    input_dim,
         | 
| 26 | 
            +
                    hidden_dim,
         | 
| 27 | 
            +
                    layer=1,
         | 
| 28 | 
            +
                    bidirectional=False,
         | 
| 29 | 
            +
                    batchnorm=True,
         | 
| 30 | 
            +
                    dropout=0.0,
         | 
| 31 | 
            +
                ):
         | 
| 32 | 
            +
                    super(BN_GRU, self).__init__()
         | 
| 33 | 
            +
                    self.batchnorm = batchnorm
         | 
| 34 | 
            +
                    if batchnorm:
         | 
| 35 | 
            +
                        self.bn = nn.BatchNorm2d(1)
         | 
| 36 | 
            +
                    self.gru = torch.nn.GRU(
         | 
| 37 | 
            +
                        input_size=input_dim,
         | 
| 38 | 
            +
                        hidden_size=hidden_dim,
         | 
| 39 | 
            +
                        num_layers=layer,
         | 
| 40 | 
            +
                        bidirectional=bidirectional,
         | 
| 41 | 
            +
                        dropout=dropout,
         | 
| 42 | 
            +
                        batch_first=True,
         | 
| 43 | 
            +
                    )
         | 
| 44 | 
            +
                    self.init_weights()
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def init_weights(self):
         | 
| 47 | 
            +
                    for m in self.modules():
         | 
| 48 | 
            +
                        if type(m) in [nn.GRU, nn.LSTM, nn.RNN]:
         | 
| 49 | 
            +
                            for name, param in m.named_parameters():
         | 
| 50 | 
            +
                                if "weight_ih" in name:
         | 
| 51 | 
            +
                                    torch.nn.init.xavier_uniform_(param.data)
         | 
| 52 | 
            +
                                elif "weight_hh" in name:
         | 
| 53 | 
            +
                                    torch.nn.init.orthogonal_(param.data)
         | 
| 54 | 
            +
                                elif "bias" in name:
         | 
| 55 | 
            +
                                    param.data.fill_(0)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def forward(self, inputs):
         | 
| 58 | 
            +
                    # (batch, 1, seq, feature)
         | 
| 59 | 
            +
                    if self.batchnorm:
         | 
| 60 | 
            +
                        inputs = self.bn(inputs)
         | 
| 61 | 
            +
                    out, _ = self.gru(inputs.squeeze(1))
         | 
| 62 | 
            +
                    return out.unsqueeze(1)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            class Generator(nn.Module):
         | 
| 66 | 
            +
                def __init__(self, n_mel, hidden, channels):
         | 
| 67 | 
            +
                    super(Generator, self).__init__()
         | 
| 68 | 
            +
                    # todo the currently running trail don't have dropout
         | 
| 69 | 
            +
                    self.denoiser = nn.Sequential(
         | 
| 70 | 
            +
                        nn.BatchNorm2d(1),
         | 
| 71 | 
            +
                        nn.Linear(n_mel, n_mel * 2),
         | 
| 72 | 
            +
                        nn.ReLU(inplace=True),
         | 
| 73 | 
            +
                        nn.BatchNorm2d(1),
         | 
| 74 | 
            +
                        nn.Linear(n_mel * 2, n_mel * 4),
         | 
| 75 | 
            +
                        nn.Dropout(0.5),
         | 
| 76 | 
            +
                        nn.ReLU(inplace=True),
         | 
| 77 | 
            +
                        BN_GRU(
         | 
| 78 | 
            +
                            input_dim=n_mel * 4,
         | 
| 79 | 
            +
                            hidden_dim=n_mel * 2,
         | 
| 80 | 
            +
                            bidirectional=True,
         | 
| 81 | 
            +
                            layer=2,
         | 
| 82 | 
            +
                            batchnorm=True,
         | 
| 83 | 
            +
                        ),
         | 
| 84 | 
            +
                        BN_GRU(
         | 
| 85 | 
            +
                            input_dim=n_mel * 4,
         | 
| 86 | 
            +
                            hidden_dim=n_mel * 2,
         | 
| 87 | 
            +
                            bidirectional=True,
         | 
| 88 | 
            +
                            layer=2,
         | 
| 89 | 
            +
                            batchnorm=True,
         | 
| 90 | 
            +
                        ),
         | 
| 91 | 
            +
                        nn.BatchNorm2d(1),
         | 
| 92 | 
            +
                        nn.ReLU(inplace=True),
         | 
| 93 | 
            +
                        nn.Linear(n_mel * 4, n_mel * 4),
         | 
| 94 | 
            +
                        nn.Dropout(0.5),
         | 
| 95 | 
            +
                        nn.BatchNorm2d(1),
         | 
| 96 | 
            +
                        nn.ReLU(inplace=True),
         | 
| 97 | 
            +
                        nn.Linear(n_mel * 4, n_mel),
         | 
| 98 | 
            +
                        nn.Sigmoid(),
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    self.unet = UNetResComplex_100Mb(channels=channels)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def forward(self, sp, mel_orig):
         | 
| 104 | 
            +
                    # Denoising
         | 
| 105 | 
            +
                    noisy = mel_orig.clone()
         | 
| 106 | 
            +
                    clean = self.denoiser(noisy) * noisy
         | 
| 107 | 
            +
                    x = to_log(clean.detach())
         | 
| 108 | 
            +
                    unet_in = torch.cat([to_log(mel_orig), x], dim=1)
         | 
| 109 | 
            +
                    # unet_in = lstm_out
         | 
| 110 | 
            +
                    unet_out = self.unet(unet_in)["mel"]
         | 
| 111 | 
            +
                    # masks
         | 
| 112 | 
            +
                    mel = unet_out + x
         | 
| 113 | 
            +
                    # todo mel and addition here are in log scales
         | 
| 114 | 
            +
                    return {
         | 
| 115 | 
            +
                        "mel": mel,
         | 
| 116 | 
            +
                        "lstm_out": unet_out,
         | 
| 117 | 
            +
                        "unet_out": unet_out,
         | 
| 118 | 
            +
                        "noisy": noisy,
         | 
| 119 | 
            +
                        "clean": clean,
         | 
| 120 | 
            +
                    }
         | 
| 121 | 
            +
             | 
| 122 | 
            +
             | 
| 123 | 
            +
class VoiceFixer(nn.Module):
    def __init__(
        self,
        channels,
        type_target="vocals",
        nsrc=1,
        loss="l1",
        lr=0.002,
        gamma=0.9,
        batchsize=None,
        frame_length=None,
        sample_rate=None,
        warm_up_steps=1000,
        reduce_lr_steps=15000,
        # data
        check_val_every_n_epoch=5,
    ):
        super(VoiceFixer, self).__init__()

        if sample_rate == 44100:
            window_size = 2048
            hop_size = 441
            n_mel = 128
        elif sample_rate == 24000:
            window_size = 768
            hop_size = 240
            n_mel = 80
        elif sample_rate == 16000:
            window_size = 512
            hop_size = 160
            n_mel = 80
        else:
            raise ValueError(
                "Error: Sample rate " + str(sample_rate) + " not supported"
            )

        center = True
        pad_mode = "reflect"
        window = "hann"
        freeze_parameters = True

        # self.save_hyperparameters()
        self.nsrc = nsrc
        self.type_target = type_target
        self.channels = channels
        self.lr = lr
        self.generated = None
        self.gamma = gamma
        self.sample_rate = sample_rate
        self.batchsize = batchsize
        self.frame_length = frame_length
        # self.hparams['channels'] = 2

        # self.am = AudioMetrics()
        # self.im = ImgMetrics()

        self.vocoder = Vocoder(sample_rate=44100)

        self.valid = None
        self.fake = None

        self.train_step = 0
        self.val_step = 0
        self.val_result_save_dir = None
        self.val_result_save_dir_step = None
        self.downsample_ratio = 2**6  # This number equals 2^{#encoder_blocks}
        self.check_val_every_n_epoch = check_val_every_n_epoch

        self.f_helper = FDomainHelper(
            window_size=window_size,
            hop_size=hop_size,
            center=center,
            pad_mode=pad_mode,
            window=window,
            freeze_parameters=freeze_parameters,
        )

        hidden = window_size // 2 + 1

        self.mel = MelScale(n_mels=n_mel, sample_rate=sample_rate, n_stft=hidden)

        # masking
        self.generator = Generator(n_mel, hidden, channels)

        self.lr_lambda = lambda step: self.get_lr_lambda(
            step,
            gamma=self.gamma,
            warm_up_steps=warm_up_steps,
            reduce_lr_steps=reduce_lr_steps,
        )

        self.lr_lambda_2 = lambda step: self.get_lr_lambda(
            step, gamma=self.gamma, warm_up_steps=10, reduce_lr_steps=reduce_lr_steps
        )

        self.mel_weight_44k_128 = (
            torch.tensor(
                [
                    19.40951426, 19.94047336, 20.4859038, 21.04629067,
                    21.62194148, 22.21335214, 22.8210215, 23.44529231,
                    24.08660962, 24.74541882, 25.42234287, 26.11770576,
                    26.83212784, 27.56615283, 28.32007747, 29.0947679,
                    29.89060111, 30.70832636, 31.54828121, 32.41121487,
                    33.29780773, 34.20865341, 35.14437675, 36.1056621,
                    37.09332763, 38.10795802, 39.15039691, 40.22119881,
                    41.32154931, 42.45172373, 43.61293329, 44.80609379,
                    46.031602, 47.29070223, 48.58427549, 49.91327905,
                    51.27863232, 52.68119708, 54.1222372, 55.60274206,
                    57.12364703, 58.68617876, 60.29148652, 61.94081306,
                    63.63501986, 65.37562658, 67.16408954, 69.00109084,
                    70.88850318, 72.82736101, 74.81985537, 76.86654792,
                    78.96885475, 81.12900906, 83.34840929, 85.62810662,
                    87.97005418, 90.37689804, 92.84887686, 95.38872881,
                    97.99777002, 100.67862715, 103.43232942, 106.26140638,
                    109.16827015, 112.15470471, 115.22184756, 118.37439245,
                    121.6122689, 124.93877158, 128.35661454, 131.86761321,
                    135.47417938, 139.18059494, 142.98713744, 146.89771854,
                    150.91684347, 155.0446638, 159.28614648, 163.64270198,
                    168.12035831, 172.71749158, 177.44220154, 182.29556933,
                    187.28286676, 192.40502126, 197.6682721, 203.07516896,
                    208.63088733, 214.33770931, 220.19910108, 226.22363072,
                    232.41087124, 238.76803591, 245.30079083, 252.01064464,
                    258.90261676, 265.98474, 273.26010248, 280.73496362,
                    288.41440094, 296.30489752, 304.41180337, 312.7377183,
                    321.28877878, 330.07870237, 339.10812951, 348.38276173,
                    357.91393924, 367.70513992, 377.76413924, 388.09467408,
                    398.70920178, 409.61813793, 420.81980127, 432.33215467,
                    444.16083117, 456.30919947, 468.78589276, 481.61325588,
                    494.78824596, 508.31969844, 522.2238331, 536.51163441,
                    551.18859414, 566.26142988, 581.75006061, 597.66210737,
                ]
            )
            / 19.40951426
        )
        self.mel_weight_44k_128 = self.mel_weight_44k_128[None, None, None, ...]

        self.g_loss_weight = 0.01
        self.d_loss_weight = 1

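    # NOTE: for every supported rate the hop corresponds to 10 ms of audio
    # (441/44100, 240/24000, 160/16000 s), and hidden = window_size // 2 + 1
    # is the number of STFT bins (1025 at 44.1 kHz). mel_weight_44k_128 holds
    # 128 per-band weights normalized by the first band, so its first entry
    # is 1.0.
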
    def get_vocoder(self):
        return self.vocoder

    def get_f_helper(self):
        return self.f_helper

    def get_lr_lambda(self, step, gamma, warm_up_steps, reduce_lr_steps):
        r"""Get lr_lambda for LambdaLR. E.g.,

        .. code-block:: python

            lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)

            from torch.optim.lr_scheduler import LambdaLR
            LambdaLR(optimizer, lr_lambda)
        """
        if step <= warm_up_steps:
            return step / warm_up_steps
        else:
            return gamma ** (step // reduce_lr_steps)

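    # NOTE: with the defaults (warm_up_steps=1000, gamma=0.9,
    # reduce_lr_steps=15000) the multiplier ramps linearly from 0 to 1 over
    # the first 1000 steps, stays at 1.0 until step 14999, then decays
    # stepwise: 0.9 ** 1 = 0.9 from step 15000, 0.9 ** 2 = 0.81 from 30000.
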
    def init_weights(self, module: nn.Module):
        for m in module.modules():
            if type(m) in [nn.GRU, nn.LSTM, nn.RNN]:
                for name, param in m.named_parameters():
                    if "weight_ih" in name:
                        torch.nn.init.xavier_uniform_(param.data)
                    elif "weight_hh" in name:
                        torch.nn.init.orthogonal_(param.data)
                    elif "bias" in name:
                        param.data.fill_(0)

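    # NOTE: a standard recurrent-layer initialization scheme: Xavier-uniform
    # for input-to-hidden weights, orthogonal for hidden-to-hidden weights
    # (which keeps recurrent gradients better conditioned), and zero biases.
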
    def pre(self, input):
        sp, _, _ = self.f_helper.wav_to_spectrogram_phase(input)
        mel_orig = self.mel(sp.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
        return sp, mel_orig

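    # NOTE: wav_to_spectrogram_phase returns the magnitude spectrogram sp as
    # (batch, channels, time_steps, freq_bins); the permutes put freq_bins on
    # the axis MelScale projects over and back, so mel_orig comes out as
    # (batch, channels, time_steps, n_mel).
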
    def forward(self, sp, mel_orig):
        """
        Args:
          sp: magnitude spectrogram, (batch_size, channels_num, time_steps, freq_bins)
          mel_orig: mel spectrogram, (batch_size, channels_num, time_steps, n_mel)

        Outputs:
          output_dict with keys 'mel', 'lstm_out', 'unet_out', 'noisy' and
          'clean' (log-mel estimates from the generator)
        """
        return self.generator(sp, mel_orig)

    def configure_optimizers(self):
        optimizer_g = torch.optim.Adam(
            [{"params": self.generator.parameters()}],
            lr=self.lr,
            amsgrad=True,
            betas=(0.5, 0.999),
        )
        optimizer_d = torch.optim.Adam(
            [{"params": self.discriminator.parameters()}],
            lr=self.lr,
            amsgrad=True,
            betas=(0.5, 0.999),
        )

        scheduler_g = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer_g, self.lr_lambda),
            "interval": "step",
            "frequency": 1,
        }
        scheduler_d = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer_d, self.lr_lambda),
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer_g, optimizer_d], [scheduler_g, scheduler_d]

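    # NOTE: returning two (optimizer, scheduler) pairs follows the PyTorch
    # Lightning GAN convention: training_step is then called with
    # optimizer_idx 0 for the generator and 1 for the discriminator.
    # betas=(0.5, 0.999) is the usual Adam choice for GAN training;
    # self.discriminator is assumed to be defined elsewhere in this module.
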
    def preprocess(self, batch, train=False, cutoff=None):
        if train:
            vocal = batch[self.type_target]  # final target
            noise = batch["noise_LR"]  # augmented low-resolution audio with noise
            augLR = batch[
                self.type_target + "_aug_LR"
            ]  # augmented low-resolution audio
            LR = batch[self.type_target + "_LR"]
            # embed()
            vocal, LR, augLR, noise = (
                vocal.float().permute(0, 2, 1),
                LR.float().permute(0, 2, 1),
                augLR.float().permute(0, 2, 1),
                noise.float().permute(0, 2, 1),
            )
            # LR, noise = self.add_random_noise(LR, noise)
            snr, scale = [], []
            for i in range(vocal.size()[0]):
                (
                    vocal[i, ...],
                    LR[i, ...],
                    augLR[i, ...],
                    noise[i, ...],
                    _snr,
                    _scale,
                ) = add_noise_and_scale_with_HQ_with_Aug(
                    vocal[i, ...],
                    LR[i, ...],
                    augLR[i, ...],
                    noise[i, ...],
                    snr_l=-5,
                    snr_h=45,
                    scale_lower=0.6,
                    scale_upper=1.0,
                )
                snr.append(_snr), scale.append(_scale)
            # vocal, LR = self.amp_to_original_f(vocal, LR)
            # noise = (noise * 0.0) + 1e-8 # todo
            return vocal, augLR, LR, noise + augLR
        else:
            if cutoff is None:
                LR_noisy = batch["noisy"]
                LR = batch["vocals"]
            else:
                LR_noisy = batch["noisyLR_" + str(cutoff)]
                LR = batch["vocalsLR_" + str(cutoff)]
            vocals = batch["vocals"]
            vocals, LR, LR_noisy = (
                vocals.float().permute(0, 2, 1),
                LR.float().permute(0, 2, 1),
                LR_noisy.float().permute(0, 2, 1),
            )
            return vocals, LR, LR_noisy, batch["fname"][0]

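    # NOTE: in training mode each example is mixed at a random SNR drawn from
    # [-5, 45] dB and rescaled by a random factor in [0.6, 1.0] via
    # add_noise_and_scale_with_HQ_with_Aug; the last returned value,
    # noise + augLR, is the degraded low-resolution input plus noise.
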
    def training_step(self, batch, batch_nb, optimizer_idx):
        # dict_keys(['vocals', 'vocals_aug', 'vocals_augLR', 'noise'])
        config = load_json("temp_path.json")
        if "g_loss_weight" not in config.keys():
            config["g_loss_weight"] = self.g_loss_weight
            config["d_loss_weight"] = self.d_loss_weight
            write_json(config, "temp_path.json")
        elif (
            config["g_loss_weight"] != self.g_loss_weight
            or config["d_loss_weight"] != self.d_loss_weight
        ):
            print(
                "Update d_loss weight, from",
                self.d_loss_weight,
                "to",
                config["d_loss_weight"],
            )
            print(
                "Update g_loss weight, from",
                self.g_loss_weight,
                "to",
                config["g_loss_weight"],
            )
            self.g_loss_weight = config["g_loss_weight"]
            self.d_loss_weight = config["d_loss_weight"]

        if optimizer_idx == 0:
            self.vocal, self.augLR, _, self.LR_noisy = self.preprocess(
                batch, train=True
            )

            for i in range(self.vocal.size()[0]):
                save_wave(
                    tensor2numpy(self.vocal[i, ...]),
                    str(i) + "vocal" + ".wav",
                    sample_rate=44100,
                )
                save_wave(
                    tensor2numpy(self.LR_noisy[i, ...]),
                    str(i) + "LR_noisy" + ".wav",
                    sample_rate=44100,
                )

            # all_mel_e2e in non-log scale
            _, self.mel_target = self.pre(self.vocal)
            self.sp_LR_target, self.mel_LR_target = self.pre(self.augLR)
            self.sp_LR_target_noisy, self.mel_LR_target_noisy = self.pre(self.LR_noisy)

            if self.valid is None or self.valid.size()[0] != self.mel_target.size()[0]:
                self.valid = torch.ones(
                    self.mel_target.size()[0], 1, self.mel_target.size()[2], 1
                )
                self.valid = self.valid.type_as(self.mel_target)
            if self.fake is None or self.fake.size()[0] != self.mel_target.size()[0]:
                self.fake = torch.zeros(
                    self.mel_target.size()[0], 1, self.mel_target.size()[2], 1
                )
                self.fake = self.fake.type_as(self.mel_target)

            self.generated = self(self.sp_LR_target_noisy, self.mel_LR_target_noisy)

            denoise_loss = self.l1loss(self.generated["clean"], self.mel_LR_target)
            targ_loss = self.l1loss(self.generated["mel"], to_log(self.mel_target))

            self.log(
                "targ-l",
                targ_loss,
                on_step=True,
                on_epoch=False,
                logger=True,
                sync_dist=True,
                prog_bar=True,
            )
            self.log(
                "noise-l",
                denoise_loss,
                on_step=True,
                on_epoch=False,
                logger=True,
                sync_dist=True,
                prog_bar=True,
            )

            loss = targ_loss + denoise_loss

            if self.train_step >= 18000:
                g_loss = self.bce_loss(
                    self.discriminator(self.generated["mel"]), self.valid
                )
                self.log(
                    "g_l",
                    g_loss,
                    on_step=True,
                    on_epoch=False,
                    logger=True,
                    sync_dist=True,
                    prog_bar=True,
                )
                # print("g_loss", g_loss)
                all_loss = loss + self.g_loss_weight * g_loss
                self.log(
                    "all_loss",
                    all_loss,
                    on_step=True,
                    on_epoch=True,
                    logger=True,
                    sync_dist=True,
                )
            else:
                all_loss = loss
            self.train_step += 0.5
            return {"loss": all_loss}

        elif optimizer_idx == 1:
            if self.train_step >= 16000:
                self.generated = self(self.sp_LR_target_noisy, self.mel_LR_target_noisy)
                self.train_step += 0.5
                real_loss = self.bce_loss(
                    self.discriminator(to_log(self.mel_target)), self.valid
                )
                self.log(
                    "r_l",
                    real_loss,
                    on_step=True,
                    on_epoch=False,
                    logger=True,
                    sync_dist=True,
                    prog_bar=True,
                )
                fake_loss = self.bce_loss(
                    self.discriminator(self.generated["mel"].detach()), self.fake
                )
                self.log(
                    "d_l",
                    fake_loss,
                    on_step=True,
                    on_epoch=False,
                    logger=True,
                    sync_dist=True,
                    prog_bar=True,
                )
                d_loss = self.d_loss_weight * (real_loss + fake_loss) / 2
                self.log(
                    "discriminator_loss",
                    d_loss,
                    on_step=True,
                    on_epoch=True,
                    logger=True,
                    sync_dist=True,
                )
                return {"loss": d_loss}

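    # NOTE: training_step runs once per optimizer per batch, and each branch
    # adds 0.5 to self.train_step (the discriminator branch only once it is
    # active). The discriminator starts training at train_step 16000, while
    # the adversarial term only enters the generator loss from 18000, giving
    # the discriminator a head start.
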
    def draw_and_save(
        self, mel: torch.Tensor, path, clip_max=None, clip_min=None, needlog=True
    ):
        plt.figure(figsize=(15, 5))
        if clip_min is None:
            clip_max, clip_min = self.clip(mel)
        mel = np.transpose(tensor2numpy(mel)[0, 0, ...], (1, 0))
        # assert np.sum(mel < 0) == 0, str(np.sum(mel < 0)) + str(np.sum(mel < 0))

        if needlog:
            assert np.sum(mel < 0) == 0, str(np.sum(mel < 0)) + "-" + path
            mel_log = np.log10(mel + EPS)
        else:
            mel_log = mel

        # plt.imshow(mel)
        librosa.display.specshow(
            mel_log,
            sr=44100,
            x_axis="frames",
            y_axis="mel",
            cmap=cm.jet,
            vmax=clip_max,
            vmin=clip_min,
        )
        plt.colorbar()
        plt.savefig(path)
        plt.close()

         | 
| 676 | 
            +
                    val_max, val_min = [], []
         | 
| 677 | 
            +
                    for each in args:
         | 
| 678 | 
            +
                        val_max.append(torch.max(each))
         | 
| 679 | 
            +
                        val_min.append(torch.min(each))
         | 
| 680 | 
            +
                    return max(val_max), min(val_min)
         | 
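# NOTE: combining the pieces above, the generator objective is
#     L_G = L1(clean, mel_LR_target) + L1(mel, log(mel_target))
#           + g_loss_weight * BCE(D(mel), 1)            # after step 18000
# with g_loss_weight = 0.01, and the discriminator objective is
#     L_D = d_loss_weight * (BCE(D(log(mel_target)), 1) + BCE(D(mel), 0)) / 2
# with d_loss_weight = 1 (both weights can be hot-reloaded from
# temp_path.json during training).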
    	
voicefixer/restorer/model_kqq_bn.py
ADDED
@@ -0,0 +1,186 @@
|  | |
| 1 | 
            +
            from voicefixer.restorer.modules import *
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from voicefixer.tools.pytorch_util import *
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class UNetResComplex_100Mb(nn.Module):
         | 
| 7 | 
            +
                def __init__(self, channels, nsrc=1):
         | 
| 8 | 
            +
                    super(UNetResComplex_100Mb, self).__init__()
         | 
| 9 | 
            +
                    activation = "relu"
         | 
| 10 | 
            +
                    momentum = 0.01
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                    self.nsrc = nsrc
         | 
| 13 | 
            +
                    self.channels = channels
         | 
| 14 | 
            +
                    self.downsample_ratio = 2**6  # This number equals 2^{#encoder_blcoks}
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                    self.encoder_block1 = EncoderBlockRes(
         | 
| 17 | 
            +
                        in_channels=channels * nsrc,
         | 
| 18 | 
            +
                        out_channels=32,
         | 
| 19 | 
            +
                        downsample=(2, 2),
         | 
| 20 | 
            +
                        activation=activation,
         | 
| 21 | 
            +
                        momentum=momentum,
         | 
| 22 | 
            +
                    )
         | 
| 23 | 
            +
                    self.encoder_block2 = EncoderBlockRes(
         | 
| 24 | 
            +
                        in_channels=32,
         | 
| 25 | 
            +
                        out_channels=64,
         | 
| 26 | 
            +
                        downsample=(2, 2),
         | 
| 27 | 
            +
                        activation=activation,
         | 
| 28 | 
            +
                        momentum=momentum,
         | 
| 29 | 
            +
                    )
         | 
| 30 | 
            +
                    self.encoder_block3 = EncoderBlockRes(
         | 
| 31 | 
            +
                        in_channels=64,
         | 
| 32 | 
            +
                        out_channels=128,
         | 
| 33 | 
            +
                        downsample=(2, 2),
         | 
| 34 | 
            +
                        activation=activation,
         | 
| 35 | 
            +
                        momentum=momentum,
         | 
| 36 | 
            +
                    )
         | 
| 37 | 
            +
                    self.encoder_block4 = EncoderBlockRes(
         | 
| 38 | 
            +
                        in_channels=128,
         | 
| 39 | 
            +
                        out_channels=256,
         | 
| 40 | 
            +
                        downsample=(2, 2),
         | 
| 41 | 
            +
                        activation=activation,
         | 
| 42 | 
            +
                        momentum=momentum,
         | 
| 43 | 
            +
                    )
         | 
| 44 | 
            +
                    self.encoder_block5 = EncoderBlockRes(
         | 
| 45 | 
            +
                        in_channels=256,
         | 
| 46 | 
            +
                        out_channels=384,
         | 
| 47 | 
            +
                        downsample=(2, 2),
         | 
| 48 | 
            +
                        activation=activation,
         | 
| 49 | 
            +
                        momentum=momentum,
         | 
| 50 | 
            +
                    )
         | 
| 51 | 
            +
                    self.encoder_block6 = EncoderBlockRes(
         | 
| 52 | 
            +
                        in_channels=384,
         | 
| 53 | 
            +
                        out_channels=384,
         | 
| 54 | 
            +
                        downsample=(2, 2),
         | 
| 55 | 
            +
                        activation=activation,
         | 
| 56 | 
            +
                        momentum=momentum,
         | 
| 57 | 
            +
                    )
         | 
| 58 | 
            +
                    self.conv_block7 = ConvBlockRes(
         | 
| 59 | 
            +
                        in_channels=384,
         | 
| 60 | 
            +
                        out_channels=384,
         | 
| 61 | 
            +
                        size=3,
         | 
| 62 | 
            +
                        activation=activation,
         | 
| 63 | 
            +
                        momentum=momentum,
         | 
| 64 | 
            +
                    )
         | 
| 65 | 
            +
                    self.decoder_block1 = DecoderBlockRes(
         | 
| 66 | 
            +
                        in_channels=384,
         | 
| 67 | 
            +
                        out_channels=384,
         | 
| 68 | 
            +
                        stride=(2, 2),
         | 
| 69 | 
            +
                        activation=activation,
         | 
| 70 | 
            +
                        momentum=momentum,
         | 
| 71 | 
            +
                    )
         | 
| 72 | 
            +
                    self.decoder_block2 = DecoderBlockRes(
         | 
| 73 | 
            +
                        in_channels=384,
         | 
| 74 | 
            +
                        out_channels=384,
         | 
| 75 | 
            +
                        stride=(2, 2),
         | 
| 76 | 
            +
                        activation=activation,
         | 
| 77 | 
            +
                        momentum=momentum,
         | 
| 78 | 
            +
                    )
         | 
| 79 | 
            +
                    self.decoder_block3 = DecoderBlockRes(
         | 
| 80 | 
            +
                        in_channels=384,
         | 
| 81 | 
            +
                        out_channels=256,
         | 
| 82 | 
            +
                        stride=(2, 2),
         | 
| 83 | 
            +
                        activation=activation,
         | 
| 84 | 
            +
                        momentum=momentum,
         | 
| 85 | 
            +
                    )
         | 
| 86 | 
            +
                    self.decoder_block4 = DecoderBlockRes(
         | 
| 87 | 
            +
                        in_channels=256,
         | 
| 88 | 
            +
                        out_channels=128,
         | 
| 89 | 
            +
                        stride=(2, 2),
         | 
| 90 | 
            +
                        activation=activation,
         | 
| 91 | 
            +
                        momentum=momentum,
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    self.decoder_block5 = DecoderBlockRes(
         | 
| 94 | 
            +
                        in_channels=128,
         | 
| 95 | 
            +
                        out_channels=64,
         | 
| 96 | 
            +
                        stride=(2, 2),
         | 
| 97 | 
            +
                        activation=activation,
         | 
| 98 | 
            +
                        momentum=momentum,
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
                    self.decoder_block6 = DecoderBlockRes(
         | 
| 101 | 
            +
                        in_channels=64,
         | 
| 102 | 
            +
                        out_channels=32,
         | 
| 103 | 
            +
                        stride=(2, 2),
         | 
| 104 | 
            +
                        activation=activation,
         | 
| 105 | 
            +
                        momentum=momentum,
         | 
| 106 | 
            +
                    )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    self.after_conv_block1 = ConvBlockRes(
         | 
| 109 | 
            +
                        in_channels=32,
         | 
| 110 | 
            +
                        out_channels=32,
         | 
| 111 | 
            +
                        size=3,
         | 
| 112 | 
            +
                        activation=activation,
         | 
| 113 | 
            +
                        momentum=momentum,
         | 
| 114 | 
            +
                    )
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    self.after_conv2 = nn.Conv2d(
         | 
| 117 | 
            +
                        in_channels=32,
         | 
| 118 | 
            +
                        out_channels=1,
         | 
| 119 | 
            +
                        kernel_size=(1, 1),
         | 
| 120 | 
            +
                        stride=(1, 1),
         | 
| 121 | 
            +
                        padding=(0, 0),
         | 
| 122 | 
            +
                        bias=True,
         | 
| 123 | 
            +
                    )
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    self.init_weights()
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                def init_weights(self):
         | 
| 128 | 
            +
                    init_layer(self.after_conv2)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                def forward(self, sp):
         | 
| 131 | 
            +
                    """
         | 
| 132 | 
            +
                    Args:
         | 
| 133 | 
            +
                      input: (batch_size, channels_num, segment_samples)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    Outputs:
         | 
| 136 | 
            +
            output_dict: {
                'wav': (batch_size, channels_num, segment_samples),
                'sp': (batch_size, channels_num, time_steps, freq_bins)}
        """

        # Input spectrogram; normalization is handled by the BatchNorm layers
        # inside the residual blocks.
        x = sp

        # Pad the spectrogram so the time axis is evenly divisible by the downsample ratio.
        origin_len = x.shape[2]  # time_steps
        pad_len = (
            int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        x = x[..., 0 : x.shape[-1] - 1]  # (bs, channels, T, F)

        # UNet
        (x1_pool, x1) = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F / 2)
        (x2_pool, x2) = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F / 4)
        (x3_pool, x3) = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F / 8)
        (x4_pool, x4) = self.encoder_block4(x3_pool)  # x4_pool: (bs, 256, T / 16, F / 16)
        (x5_pool, x5) = self.encoder_block5(x4_pool)  # x5_pool: (bs, 512, T / 32, F / 32)
        (x6_pool, x6) = self.encoder_block6(x5_pool)  # x6_pool: (bs, 1024, T / 64, F / 64)
        x_center = self.conv_block7(x6_pool)  # (bs, 2048, T / 64, F / 64)
        x7 = self.decoder_block1(x_center, x6)  # (bs, 1024, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5)  # (bs, 512, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3)  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F)
        x = self.after_conv_block1(x12)  # (bs, 32, T, F)
        x = self.after_conv2(x)  # (bs, channels, T, F)

        # Recover shape: restore the trimmed frequency bin and remove the time padding.
        x = F.pad(x, pad=(0, 1))
        x = x[:, :, 0:origin_len, :]

        output_dict = {"mel": x}
        return output_dict


if __name__ == "__main__":
    model = UNetResComplex_100Mb(channels=1)
    print(model(torch.randn((1, 1, 101, 128)))["mel"].size())
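A quick sanity check of the pad-then-trim arithmetic in the forward pass above (a standalone sketch, not part of the commit; the downsample ratio of 2**6 = 64 matches the six pooling stages):

    import numpy as np
    import torch
    import torch.nn.functional as F

    downsample_ratio = 2 ** 6                 # six encoder blocks, each halving T
    x = torch.randn(1, 1, 101, 128)           # (bs, channels, time_steps, freq_bins)

    origin_len = x.shape[2]
    pad_len = int(np.ceil(origin_len / downsample_ratio)) * downsample_ratio - origin_len
    x = F.pad(x, pad=(0, 0, 0, pad_len))      # the second-to-last dim (time) gets the padding
    x = x[..., : x.shape[-1] - 1]             # drop one frequency bin
    print(x.shape)                            # torch.Size([1, 1, 128, 127])

    # ... the UNet runs here ...

    x = F.pad(x, pad=(0, 1))                  # restore the trimmed frequency bin
    x = x[:, :, :origin_len, :]               # remove the time padding
    print(x.shape)                            # torch.Size([1, 1, 101, 128])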
    	
        voicefixer/restorer/modules.py
    ADDED
    
import torch.nn as nn
import torch
import torch.nn.functional as F
import math


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, size, activation, momentum):
        super(ConvBlockRes, self).__init__()

        self.activation = activation
        # `size` may be an int or a (kH, kW) tuple; only the first element is used.
        if isinstance(size, tuple):
            size = size[0]
        pad = size // 2

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(size, size),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(pad, pad),
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        # self.abn1 = InPlaceABN(num_features=in_channels, momentum=momentum, activation='leaky_relu')

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(size, size),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(pad, pad),
            bias=False,
        )

        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        # self.abn2 = InPlaceABN(num_features=out_channels, momentum=momentum, activation='leaky_relu')

        # A 1x1 convolution matches the channel count on the residual path
        # when in_channels != out_channels.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        # Pre-activation residual block: BN -> LeakyReLU -> Conv, twice.
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x

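A minimal shape check for the residual block (a sketch, not part of the commit):

    import torch
    from voicefixer.restorer.modules import ConvBlockRes

    block = ConvBlockRes(in_channels=32, out_channels=64, size=3,
                         activation="leaky_relu", momentum=0.01)
    x = torch.randn(2, 32, 16, 16)
    print(block(x).shape)  # torch.Size([2, 64, 16, 16]); the 1x1 shortcut matched the channels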
class EncoderBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super(EncoderBlockRes, self).__init__()
        size = 3

        self.conv_block1 = ConvBlockRes(
            in_channels, out_channels, size, activation, momentum
        )
        self.conv_block2 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super(DecoderBlockRes, self).__init__()
        size = 3
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(size, size),
            stride=stride,
            padding=(0, 0),
            output_padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockRes(
            out_channels * 2, out_channels, size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block5 = ConvBlockRes(
            out_channels, out_channels, size, activation, momentum
        )

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        # Skip connection: concatenate the encoder feature map along channels.
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias"):
        if layer.bias is not None:
            layer.bias.data.fill_(0.0)


def init_bn(bn):
    """Initialize a Batchnorm layer."""
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)


def init_gru(rnn):
    """Initialize a GRU layer."""

    def _concat_init(tensor, init_funcs):
        (length, fan_out) = tensor.shape
        fan_in = length // len(init_funcs)

        for (i, init_func) in enumerate(init_funcs):
            init_func(tensor[i * fan_in : (i + 1) * fan_in, :])

    def _inner_uniform(tensor):
        fan_in = nn.init._calculate_correct_fan(tensor, "fan_in")
        nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))

    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, "weight_ih_l{}".format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform],
        )
        torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0)

        _concat_init(
            getattr(rnn, "weight_hh_l{}".format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_],
        )
        torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0)


def act(x, activation):
    if activation == "relu":
        return F.relu_(x)

    elif activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.2)

    elif activation == "swish":
        return x * torch.sigmoid(x)

    else:
        raise Exception("Incorrect activation!")
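A sketch (not part of the commit) of how an encoder/decoder pair round-trips a feature map. The restorer UNet keeps the time axis even and the frequency axis odd, which is exactly what prune(both=False) expects after the transposed convolution:

    import torch
    from voicefixer.restorer.modules import EncoderBlockRes, DecoderBlockRes

    enc = EncoderBlockRes(in_channels=1, out_channels=32, downsample=(2, 2),
                          activation="leaky_relu", momentum=0.01)
    dec = DecoderBlockRes(in_channels=32, out_channels=32, stride=(2, 2),
                          activation="leaky_relu", momentum=0.01)

    x = torch.randn(1, 1, 64, 65)    # even time (64), odd frequency (65)
    pooled, skip = enc(x)
    print(pooled.shape, skip.shape)  # [1, 32, 32, 32] and [1, 32, 64, 65]
    print(dec(pooled, skip).shape)   # [1, 32, 64, 65]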
    	
        voicefixer/tools/__init__.py
    ADDED
    
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    :   __init__.py
@Contact :   [email protected]
@License :   (C)Copyright 2020-2100

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
9/14/21 12:28 AM   Haohe Liu      1.0         None
"""
    	
        voicefixer/tools/base.py
    ADDED
    
import math

import numpy as np
import torch
import os
import torch.fft

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


def get_window(window_size, window_type, square_root_window=True):
    """Return the analysis window."""
    window = {
        "hamming": torch.hamming_window(window_size),
        "hanning": torch.hann_window(window_size),
    }[window_type]
    if square_root_window:
        window = torch.sqrt(window)
    return window


def fft_point(dim):
    """Return the smallest power of two that is >= dim."""
    assert dim > 0
    num = math.log(dim, 2)
    num_point = 2 ** (math.ceil(num))
    return num_point


def pre_emphasis(signal, coefficient=0.97):
    """Pre-emphasize the original signal:
    y(n) = x(n) - a*x(n-1)
    """
    return np.append(signal[0], signal[1:] - coefficient * signal[:-1])


def de_emphasis(signal, coefficient=0.97):
    """De-emphasize the signal in place (the inverse of pre_emphasis):
    y(n) = x(n) + a*y(n-1)
    """
    length = signal.shape[0]
    for i in range(1, length):
        signal[i] = signal[i] + coefficient * signal[i - 1]
    return signal

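A quick check (illustrative, not part of the commit) that de_emphasis exactly inverts pre_emphasis:

    import numpy as np
    from voicefixer.tools.base import pre_emphasis, de_emphasis

    x = np.random.randn(16000)
    y = pre_emphasis(x, coefficient=0.97)
    x_rec = de_emphasis(y.copy(), coefficient=0.97)  # copy: de_emphasis mutates its input
    print(np.allclose(x, x_rec))                     # True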
def seperate_magnitude(magnitude, phase):
    """Combine a magnitude spectrogram and a phase spectrogram into a
    stacked (real, imaginary) representation."""
    real = torch.cos(phase) * magnitude
    imagine = torch.sin(phase) * magnitude
    expand_dim = len(list(real.size()))
    return torch.stack((real, imagine), expand_dim)


def stft_single(
    signal,
    sample_rate=44100,
    frame_length=46,
    frame_shift=10,
    window_type="hanning",
    device=torch.device("cuda"),
    square_root_window=True,
):
    """Compute the Short Time Fourier Transform.

    Args:
        signal: input speech signal
        sample_rate: sample rate of the waveform data (Hz)
        frame_length: frame length in milliseconds
        frame_shift: frame shift in milliseconds
        window_type: type of window
        square_root_window: whether to take the square root of the window
    Return:
        real and imaginary parts of the (n/2)+1 dimensional complex STFT result
    """
    hop_length = int(
        sample_rate * frame_shift / 1000
    )  # The greater the sample_rate, the greater the hop_length
    win_length = int(sample_rate * frame_length / 1000)
    # num_point = fft_point(win_length)
    num_point = win_length
    window = get_window(num_point, window_type, square_root_window)
    if "cuda" in str(device):
        window = window.cuda(device)
    feat = torch.stft(
        signal,
        n_fft=num_point,
        hop_length=hop_length,
        win_length=window.shape[0],
        window=window,
    )
    real, imag = feat[..., 0], feat[..., 1]
    return real.permute(0, 2, 1).unsqueeze(1), imag.permute(0, 2, 1).unsqueeze(1)

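For orientation, the frame arithmetic at the defaults above works out as follows (a sketch, not part of the commit). Note that torch.stft is called without return_complex, so this file assumes a pre-1.8 PyTorch in which the result is a real tensor with a trailing (real, imag) dimension:

    sample_rate, frame_length, frame_shift = 44100, 46, 10
    win_length = int(sample_rate * frame_length / 1000)  # 2028 samples (~46 ms)
    hop_length = int(sample_rate * frame_shift / 1000)   # 441 samples (10 ms)
    n_freq_bins = win_length // 2 + 1                    # 1015 one-sided bins
    print(win_length, hop_length, n_freq_bins)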
def istft(
    real,
    imag,
    length,
    sample_rate=44100,
    frame_length=46,
    frame_shift=10,
    window_type="hanning",
    preemphasis=0.0,
    device=torch.device("cuda"),
    square_root_window=True,
):
    """Convert frames to a signal using overlap-and-add synthesis.
    Args:
        real: real part of the spectrogram, shape (batchsize, 1, time, freq)
        imag: imaginary part of the spectrogram, same shape as real
        length: number of samples in the output waveform
    Return:
        wav: synthesized output waveform
    """
    real = real.permute(0, 3, 2, 1)
    imag = imag.permute(0, 3, 2, 1)
    spectrum = torch.cat([real, imag], dim=-1)

    hop_length = int(sample_rate * frame_shift / 1000)
    win_length = int(sample_rate * frame_length / 1000)

    # num_point = fft_point(win_length)
    num_point = win_length
    if "cuda" in str(device):
        window = get_window(num_point, window_type, square_root_window).cuda(device)
    else:
        window = get_window(num_point, window_type, square_root_window)

    wav = torch_istft(
        spectrum,
        num_point,
        hop_length=hop_length,
        win_length=window.shape[0],
        window=window,
    )
    return wav[..., :length]


def torch_istft(
    stft_matrix,  # type: Tensor
    n_fft,  # type: int
    hop_length=None,  # type: Optional[int]
    win_length=None,  # type: Optional[int]
    window=None,  # type: Optional[Tensor]
    center=True,  # type: bool
    pad_mode="reflect",  # type: str
    normalized=False,  # type: bool
    onesided=True,  # type: bool
    length=None,  # type: Optional[int]
):
    # type: (...) -> Tensor

    stft_matrix_dim = stft_matrix.dim()
    assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim)

    if stft_matrix_dim == 3:
        # add a channel dimension
        stft_matrix = stft_matrix.unsqueeze(0)

    dtype = stft_matrix.dtype
    device = stft_matrix.device
    fft_size = stft_matrix.size(1)
    assert (onesided and n_fft // 2 + 1 == fft_size) or (
        not onesided and n_fft == fft_size
    ), (
        "onesided implies that n_fft // 2 + 1 == fft_size and not onesided implies n_fft == fft_size. "
        + "Given values were onesided: %s, n_fft: %d, fft_size: %d"
        % ("True" if onesided else "False", n_fft, fft_size)
    )

    # use stft defaults for Optionals
    if win_length is None:
        win_length = n_fft

    if hop_length is None:
        hop_length = int(win_length // 4)

    # There must be overlap
    assert 0 < hop_length <= win_length
    assert 0 < win_length <= n_fft

    if window is None:
        window = torch.ones(win_length, requires_grad=False, device=device, dtype=dtype)

    assert window.dim() == 1 and window.size(0) == win_length

    if win_length != n_fft:
        # center window with pad left and right zeros
        left = (n_fft - win_length) // 2
        window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
        assert window.size(0) == n_fft
    # win_length and n_fft are synonymous from here on

    stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frames, fft_size, 2)
    # NOTE: torch.irfft was removed in PyTorch 1.8; this code assumes an older release.
    stft_matrix = torch.irfft(
        stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
    )  # size (channel, n_frames, n_fft)

    assert stft_matrix.size(2) == n_fft
    n_frames = stft_matrix.size(1)

    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (channel, n_frames, n_fft)
    # each column of a channel is a frame which needs to be overlap added at the right place
    ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frames)

    eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze(
        1
    )  # size (n_fft, 1, n_fft)

    # this does overlap add where the frames of ytmp are added such that the i'th frame of
    # ytmp is added starting at i*hop_length in the output
    y = torch.nn.functional.conv_transpose1d(
        ytmp, eye, stride=hop_length, padding=0
    )  # size (channel, 1, expected_signal_len)

    # do the same for the window function
    window_sq = (
        window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)
    )  # size (1, n_fft, n_frames)
    window_envelop = torch.nn.functional.conv_transpose1d(
        window_sq, eye, stride=hop_length, padding=0
    )  # size (1, 1, expected_signal_len)

    expected_signal_len = n_fft + hop_length * (n_frames - 1)
    assert y.size(2) == expected_signal_len
    assert window_envelop.size(2) == expected_signal_len

    half_n_fft = n_fft // 2
    # we need to trim the front padding away if center
    start = half_n_fft if center else 0
    end = -half_n_fft if length is None else start + length

    y = y[:, :, start:end]
    window_envelop = window_envelop[:, :, start:end]

    # check NOLA non-zero overlap condition
    window_envelop_lowest = window_envelop.abs().min()
    assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
        window_envelop_lowest
    )

    y = (y / window_envelop).squeeze(1)  # size (channel, expected_signal_len)

    if stft_matrix_dim == 3:  # remove the channel dimension
        y = y.squeeze(0)
    return y
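A hedged round-trip sketch (not part of the commit; it assumes the same pre-1.8 PyTorch noted above, where torch.irfft still exists, and runs on CPU):

    import torch
    from voicefixer.tools.base import stft_single, istft

    signal = torch.randn(1, 44100)  # one second at 44.1 kHz
    cpu = torch.device("cpu")
    real, imag = stft_single(signal, device=cpu)
    rec = istft(real, imag, length=signal.shape[-1], device=cpu)
    print(rec.shape)  # torch.Size([1, 44100]); rec should closely match signal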
    	
        voicefixer/tools/io.py
    ADDED
    
import json
import pickle


def read_list(fname):
    result = []
    with open(fname, "r") as f:
        for each in f.readlines():
            each = each.strip("\n")
            result.append(each)
    return result


def write_list(lst, fname):
    with open(fname, "w") as f:
        for word in lst:
            f.write(word)
            f.write("\n")


def write_json(my_dict, fname):
    # print("Save json file at "+fname)
    json_str = json.dumps(my_dict)
    with open(fname, "w") as json_file:
        json_file.write(json_str)


def load_json(fname):
    with open(fname, "r") as f:
        data = json.load(f)
        return data


def save_pickle(obj, fname):
    # print("Save pickle at "+fname)
    with open(fname, "wb") as f:
        pickle.dump(obj, f)


def load_pickle(fname):
    # print("Load pickle at "+fname)
    with open(fname, "rb") as f:
        res = pickle.load(f)
    return res
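A small usage sketch (not part of the commit; the file names are illustrative):

    from voicefixer.tools.io import write_json, load_json, save_pickle, load_pickle

    config = {"sample_rate": 44100, "n_mels": 128}
    write_json(config, "config.json")
    assert load_json("config.json") == config

    save_pickle(config, "config.pkl")
    assert load_pickle("config.pkl") == config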
    	
        voicefixer/tools/mel_scale.py
    ADDED
    
import torch
from torch import Tensor
from typing import Optional
import math

import warnings


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    The user can control which device the filter bank (`fb`) is on (e.g. fb.to(spec_f.device)).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    See also:
        :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
        generate the filter banks.
    """
    __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]

    def __init__(
        self,
        n_mels: int = 128,
        sample_rate: int = 16000,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_stft: int = 201,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
    ) -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        assert f_min <= self.f_max, "Require f_min: {} <= f_max: {}".format(
            f_min, self.f_max
        )
        fb = melscale_fbanks(
            n_stft,
            self.f_min,
            self.f_max,
            self.n_mels,
            self.sample_rate,
            self.norm,
            self.mel_scale,
        )
        self.register_buffer("fb", fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """

        # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
        mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(
            -1, -2
        )

        return mel_specgram

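A usage sketch for the class above (not part of the commit; parameter values are illustrative, and the helper functions it relies on follow below):

    import torch
    from voicefixer.tools.mel_scale import MelScale

    # A 2048-point FFT at 44.1 kHz yields 1025 one-sided bins.
    mel = MelScale(n_mels=128, sample_rate=44100, n_stft=1025)
    spec = torch.rand(1, 1025, 100)  # (batch, freq, time) magnitude spectrogram
    print(mel(spec).shape)           # torch.Size([1, 128, 100])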
def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    r"""Convert Hz to Mels.

    Args:
        freq (float): Frequencies in Hz
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        mels (float): Frequency in Mels
    """

    if mel_scale not in ["slaney", "htk"]:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3

    mels = (freq - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    if freq >= min_log_hz:
        mels = min_log_mel + math.log(freq / min_log_hz) / logstep

    return mels


def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
    """Convert mel bin numbers to frequencies.

    Args:
        mels (Tensor): Mel frequencies
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        freqs (Tensor): Mels converted in Hz
    """

    if mel_scale not in ["slaney", "htk"]:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels

    # And now the nonlinear scale
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    log_t = mels >= min_log_mel
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

    return freqs


def _create_triangular_filterbank(
    all_freqs: Tensor,
    f_pts: Tensor,
) -> Tensor:
    """Create a triangular filter bank.

    Args:
        all_freqs (Tensor): STFT freq points of size (`n_freqs`).
        f_pts (Tensor): Filter mid points of size (`n_filter`).

    Returns:
        fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
    """
    # Adapted from librosa
    # calculate the difference between each filter mid point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_filter + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_filter + 2)
    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_filter)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_filter)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    return fb


def melscale_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_mels: int,
    sample_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Note:
        For the sake of numerical compatibility with librosa, not all the coefficients
        in the resulting filter bank have a magnitude of 1.

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
           :alt: Visualization of generated filter bank

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
        meaning number of frequencies to highlight/apply to x the number of filterbanks.
         | 
| 204 | 
            +
                    Each column is a filterbank so that assuming there is a matrix A of
         | 
| 205 | 
            +
                    size (..., ``n_freqs``), the applied result would be
         | 
| 206 | 
            +
                    ``A * melscale_fbanks(A.size(-1), ...)``.
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                """
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                if norm is not None and norm != "slaney":
         | 
| 211 | 
            +
                    raise ValueError("norm must be one of None or 'slaney'")
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                # freq bins
         | 
| 214 | 
            +
                all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                # calculate mel freq bins
         | 
| 217 | 
            +
                m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
         | 
| 218 | 
            +
                m_max = _hz_to_mel(f_max, mel_scale=mel_scale)
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                m_pts = torch.linspace(m_min, m_max, n_mels + 2)
         | 
| 221 | 
            +
                f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                # create filterbank
         | 
| 224 | 
            +
                fb = _create_triangular_filterbank(all_freqs, f_pts)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                if norm is not None and norm == "slaney":
         | 
| 227 | 
            +
                    # Slaney-style mel is scaled to be approx constant energy per channel
         | 
| 228 | 
            +
                    enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
         | 
| 229 | 
            +
                    fb *= enorm.unsqueeze(0)
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                if (fb.max(dim=0).values == 0.0).any():
         | 
| 232 | 
            +
                    warnings.warn(
         | 
| 233 | 
            +
                        "At least one mel filterbank has all zero values. "
         | 
| 234 | 
            +
                        f"The value for `n_mels` ({n_mels}) may be set too high. "
         | 
| 235 | 
            +
                        f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
         | 
| 236 | 
            +
                    )
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                return fb
         | 
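As the docstring above notes, each column of the returned matrix is one triangular filter, so a magnitude spectrogram with ``n_freqs`` bins is mapped onto the mel scale by a single matrix product. A minimal sketch of that usage (the parameter values are illustrative, not taken from this repo):

import torch
from voicefixer.tools.mel_scale import melscale_fbanks

# A 2048-point STFT yields 1025 frequency bins.
fb = melscale_fbanks(
    n_freqs=1025, f_min=0.0, f_max=22050.0, n_mels=128,
    sample_rate=44100, norm="slaney", mel_scale="slaney",
)                                  # (1025, 128): one triangular filter per column

spec = torch.rand(1, 300, 1025)    # (batch, time_steps, n_freqs) magnitude spectrogram
mel_spec = torch.matmul(spec, fb)  # (1, 300, 128) mel-scale spectrogram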
    	
voicefixer/tools/modules/__init__.py
ADDED

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    :   __init__.py
@Contact :   [email protected]
@License :   (C)Copyright 2020-2100

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
9/14/21 12:29 AM   Haohe Liu      1.0         None
"""
    	
voicefixer/tools/modules/fDomainHelper.py
ADDED

from torchlibrosa.stft import STFT, ISTFT, magphase
import torch
import torch.nn as nn
import numpy as np
from voicefixer.tools.modules.pqmf import PQMF


class FDomainHelper(nn.Module):
    def __init__(
        self,
        window_size=2048,
        hop_size=441,
        center=True,
        pad_mode="reflect",
        window="hann",
        freeze_parameters=True,
        subband=None,
        root="/Users/admin/Documents/projects/",
    ):
        super(FDomainHelper, self).__init__()
        self.subband = subband
        # assert torchlibrosa.__version__ == "0.0.7", "Error: Found torchlibrosa version %s. Please install 0.0.7 version of torchlibrosa by: pip install torchlibrosa==0.0.7." % torchlibrosa.__version__
        if self.subband is None:
            self.stft = STFT(
                n_fft=window_size,
                hop_length=hop_size,
                win_length=window_size,
                window=window,
                center=center,
                pad_mode=pad_mode,
                freeze_parameters=freeze_parameters,
            )

            self.istft = ISTFT(
                n_fft=window_size,
                hop_length=hop_size,
                win_length=window_size,
                window=window,
                center=center,
                pad_mode=pad_mode,
                freeze_parameters=freeze_parameters,
            )
        else:
            self.stft = STFT(
                n_fft=window_size // self.subband,
                hop_length=hop_size // self.subband,
                win_length=window_size // self.subband,
                window=window,
                center=center,
                pad_mode=pad_mode,
                freeze_parameters=freeze_parameters,
            )

            self.istft = ISTFT(
                n_fft=window_size // self.subband,
                hop_length=hop_size // self.subband,
                win_length=window_size // self.subband,
                window=window,
                center=center,
                pad_mode=pad_mode,
                freeze_parameters=freeze_parameters,
            )

        if subband is not None and root is not None:
            self.qmf = PQMF(subband, 64, root)

    def complex_spectrogram(self, input, eps=0.0):
        # [batchsize, samples]
        # return [batchsize, 2, t-steps, f-bins]
        real, imag = self.stft(input)
        return torch.cat([real, imag], dim=1)

    def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
        # [batchsize, 2[real,imag], t-steps, f-bins]
        wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)
        return wav

    def spectrogram(self, input, eps=0.0):
        (real, imag) = self.stft(input.float())
        return torch.clamp(real**2 + imag**2, eps, np.inf) ** 0.5

    def spectrogram_phase(self, input, eps=0.0):
        (real, imag) = self.stft(input.float())
        mag = torch.clamp(real**2 + imag**2, eps, np.inf) ** 0.5
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(self, input, eps=1e-8):
        """Waveform to spectrogram.

        Args:
          input: (batch_size, channels_num, segment_samples)

        Outputs:
          output: (batch_size, channels_num, time_steps, freq_bins)
        """
        sp_list = []
        cos_list = []
        sin_list = []
        channels_num = input.shape[1]
        for channel in range(channels_num):
            mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
            sp_list.append(mag)
            cos_list.append(cos)
            sin_list.append(sin)

        sps = torch.cat(sp_list, dim=1)
        coss = torch.cat(cos_list, dim=1)
        sins = torch.cat(sin_list, dim=1)
        return sps, coss, sins

    def spectrogram_phase_to_wav(self, sps, coss, sins, length):
        channels_num = sps.size()[1]
        res = []
        for i in range(channels_num):
            res.append(
                self.istft(
                    sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
                    sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
                    length,
                )
            )
            res[-1] = res[-1].unsqueeze(1)
        return torch.cat(res, dim=1)

    def wav_to_spectrogram(self, input, eps=1e-8):
        """Waveform to spectrogram.

        Args:
          input: (batch_size, channels_num, segment_samples)

        Outputs:
          output: (batch_size, channels_num, time_steps, freq_bins)
        """
        sp_list = []
        channels_num = input.shape[1]
        for channel in range(channels_num):
            sp_list.append(self.spectrogram(input[:, channel, :], eps=eps))
        output = torch.cat(sp_list, dim=1)
        return output

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Spectrogram to waveform.

        Args:
          input: (batch_size, channels_num, segment_samples)
          spectrogram: (batch_size, channels_num, time_steps, freq_bins)

        Outputs:
          output: (batch_size, channels_num, segment_samples)
        """
        channels_num = input.shape[1]
        wav_list = []
        for channel in range(channels_num):
            (real, imag) = self.stft(input[:, channel, :])
            (_, cos, sin) = magphase(real, imag)
            wav_list.append(
                self.istft(
                    spectrogram[:, channel : channel + 1, :, :] * cos,
                    spectrogram[:, channel : channel + 1, :, :] * sin,
                    length,
                )
            )

        output = torch.stack(wav_list, dim=1)
        return output

    # TODO: the following code is not bug-free!
    def wav_to_complex_spectrogram(self, input, eps=0.0):
        # [batchsize, channels, samples]
        # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
        res = []
        channels_num = input.shape[1]
        for channel in range(channels_num):
            res.append(self.complex_spectrogram(input[:, channel, :], eps=eps))
        return torch.cat(res, dim=1)

    def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
        # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
        # return [batchsize, channels, samples]
        channels = input.size()[1] // 2
        wavs = []
        for i in range(channels):
            wavs.append(
                self.reverse_complex_spectrogram(
                    input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
                )
            )
            wavs[-1] = wavs[-1].unsqueeze(1)
        return torch.cat(wavs, dim=1)

    def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
        # [batchsize, channels, samples]
        # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
        subwav = self.qmf.analysis(input)  # [batchsize, subband*channels, samples]
        subspec = self.wav_to_complex_spectrogram(subwav)
        return subspec

    def complex_subband_spectrogram_to_wav(self, input, eps=0.0):
        # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
        # [batchsize, channels, samples]
        subwav = self.complex_spectrogram_to_wav(input)
        data = self.qmf.synthesis(subwav)
        return data

    def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
        """
        :param input:
        :param eps:
        :return:
            loss = torch.nn.L1Loss()
            model = FDomainHelper(subband=4)
            data = torch.randn((3, 1, 44100 * 3))

            sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
            wav = model.mag_phase_subband_spectrogram_to_wav(sps, coss, sins, 44100 * 3 // 4)

            print(loss(data, wav))
            print(torch.max(torch.abs(data - wav)))

        """
        # [batchsize, channels, samples]
        # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
        subwav = self.qmf.analysis(input)  # [batchsize, subband*channels, samples]
        sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps)
        return sps, coss, sins

    def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
        # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
        # [batchsize, channels, samples]
        subwav = self.spectrogram_phase_to_wav(
            sps, coss, sins, length + self.qmf.pad_samples // self.qmf.N
        )
        data = self.qmf.synthesis(subwav)
        return data

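A round trip through the magnitude/phase helpers should be close to lossless, since the phase is kept as cos/sin factors alongside the magnitude. A minimal sketch, assuming the default full-band configuration (subband=None, which avoids the PQMF filter files entirely):

import torch
from voicefixer.tools.modules.fDomainHelper import FDomainHelper

helper = FDomainHelper()        # defaults: window_size=2048, hop_size=441, subband=None
wav = torch.randn(2, 1, 44100)  # (batch_size, channels_num, segment_samples)

sps, coss, sins = helper.wav_to_spectrogram_phase(wav)  # magnitude, cos(phase), sin(phase)
rec = helper.spectrogram_phase_to_wav(sps, coss, sins, length=44100)

print(torch.max(torch.abs(wav - rec)))  # expected to be near zero, up to numerical error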
    	
voicefixer/tools/modules/filters/f_2_64.mat
ADDED
File without changes

voicefixer/tools/modules/filters/f_4_64.mat
ADDED
File without changes

voicefixer/tools/modules/filters/f_8_64.mat
ADDED
File without changes

voicefixer/tools/modules/filters/h_2_64.mat
ADDED
File without changes

voicefixer/tools/modules/filters/h_4_64.mat
ADDED
File without changes

voicefixer/tools/modules/filters/h_8_64.mat
ADDED
File without changes
    	
voicefixer/tools/modules/pqmf.py
ADDED

"""
@File    :   pqmf.py
@Contact :   [email protected]
@License :   (C)Copyright 2020-2021
@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
2020/4/3 4:54 PM   Haohe Liu      1.0         None
"""

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import os.path as op
from scipy.io import loadmat


def load_mat2numpy(fname=""):
    if len(fname) == 0:
        return None
    else:
        return loadmat(fname)


class PQMF(nn.Module):
    def __init__(self, N, M, project_root):
        super().__init__()
        self.N = N  # number of subbands
        self.M = M  # number of filter taps
        try:
            assert (N, M) in [(8, 64), (4, 64), (2, 64)]
        except AssertionError:
            print("Warning:", N, "subband and", M, "filter is not supported")
        self.pad_samples = 64
        self.name = str(N) + "_" + str(M) + ".mat"
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )
        data = load_mat2numpy(
            op.join(
                project_root,
                "arnold_workspace/restorer/tools/pytorch/modules/filters/f_"
                + self.name,
            )
        )
        data = data["f"].astype(np.float32) / N
        data = np.flipud(data.T).T
        data = np.reshape(data, (N, 1, M)).copy()
        dict_new = self.ana_conv_filter.state_dict().copy()
        dict_new["weight"] = torch.from_numpy(data)
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
        self.ana_conv_filter.load_state_dict(dict_new)

        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        gk = load_mat2numpy(
            op.join(
                project_root,
                "arnold_workspace/restorer/tools/pytorch/modules/filters/h_"
                + self.name,
            )
        )
        gk = gk["h"].astype(np.float32)
        gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
        gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
        dict_new = self.syn_conv_filter.state_dict().copy()
        dict_new["weight"] = torch.from_numpy(gk)
        self.syn_conv_filter.load_state_dict(dict_new)

        for param in self.parameters():
            param.requires_grad = False

    def __analysis_channel(self, inputs):
        return self.ana_conv_filter(self.ana_pad(inputs))

    def __systhesis_channel(self, inputs):
        ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
        return torch.reshape(ret, (ret.shape[0], 1, -1))

    def analysis(self, inputs):
        """
        :param inputs: [batchsize, channel, raw_wav], value: [0, 1]
        :return:
        """
        inputs = F.pad(inputs, (0, self.pad_samples))
        ret = None
        for i in range(inputs.size()[1]):  # channels
            if ret is None:
                ret = self.__analysis_channel(inputs[:, i : i + 1, :])
            else:
                ret = torch.cat(
                    (ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1
                )
        return ret

    def synthesis(self, data):
        """
        :param data: [batchsize, self.N*K, raw_wav_sub], value: [0, 1]
        :return:
        """
        ret = None
        # data = F.pad(data, (0, self.pad_samples // self.N))
        for i in range(data.size()[1]):  # channels
            if i % self.N == 0:
                if ret is None:
                    ret = self.__systhesis_channel(data[:, i : i + self.N, :])
                else:
                    new = self.__systhesis_channel(data[:, i : i + self.N, :])
                    ret = torch.cat((ret, new), dim=1)
        ret = ret[..., : -self.pad_samples]
        return ret

    def forward(self, inputs):
        return self.ana_conv_filter(self.ana_pad(inputs))

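The analysis/synthesis pair is designed for near-perfect reconstruction. A sketch of the round trip, assuming the f_4_64.mat / h_4_64.mat files are reachable under the "arnold_workspace/..." path hard-coded in __init__ (note that path does not match where the filters actually live in this repo, voicefixer/tools/modules/filters/):

import torch
from voicefixer.tools.modules.pqmf import PQMF

root = "/path/to/project/"      # must contain the hard-coded filter directory above
pqmf = PQMF(N=4, M=64, project_root=root)

wav = torch.randn(1, 1, 44100)  # [batchsize, channel, raw_wav]
sub = pqmf.analysis(wav)        # [1, 4, (44100 + 64) // 4] critically sampled subbands
rec = pqmf.synthesis(sub)       # [1, 1, 44100]

print(torch.max(torch.abs(wav - rec)))  # small reconstruction error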
    	
voicefixer/tools/path.py
ADDED

import os


def find_and_build(root, path):
    path = os.path.join(root, path)
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path


def root_path(repo_name="voicefixer"):
    path = os.path.abspath(__file__)
    return path.split(repo_name)[0]

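Typical usage is to build output directories relative to the checkout root; a two-line sketch (the directory name is illustrative, not taken from the repo):

from voicefixer.tools.path import find_and_build, root_path

ckpt_dir = find_and_build(root_path(), "voicefixer/checkpoints")  # created if missing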
    	
voicefixer/tools/pytorch_util.py
ADDED

import torch
import torch.nn as nn
import numpy as np


def check_cuda_availability(cuda):
    if cuda and not torch.cuda.is_available():
        raise RuntimeError("Error: You set cuda=True but no cuda device found.")


def try_tensor_cuda(tensor, cuda):
    if cuda and torch.cuda.is_available():
        return tensor.cuda()
    else:
        return tensor.cpu()


def to_log(input):
    assert torch.sum(input < 0) == 0, (
        str(input) + " has negative values counts " + str(torch.sum(input < 0))
    )
    return torch.log10(torch.clip(input, min=1e-8))


def from_log(input):
    input = torch.clip(input, min=-np.inf, max=5)
    return 10**input


def move_data_to_device(x, device):
    if "float" in str(x.dtype):
        x = torch.Tensor(x)
    elif "int" in str(x.dtype):
        x = torch.LongTensor(x)
    else:
        return x
    return x.to(device)


def tensor2numpy(tensor):
    if "cuda" in str(tensor.device):
        return tensor.detach().cpu().numpy()
    else:
        return tensor.detach().numpy()


def count_parameters(model):
    for p in model.parameters():
        if p.requires_grad:
            print(p.shape)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_flops(model, audio_length):
    multiply_adds = False
    list_conv2d = []

    def conv2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = (
            self.kernel_size[0]
            * self.kernel_size[1]
            * (self.in_channels / self.groups)
            * (2 if multiply_adds else 1)
        )
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width

        list_conv2d.append(flops)

    list_conv1d = []

    def conv1d_hook(self, input, output):
        batch_size, input_channels, input_length = input[0].size()
        output_channels, output_length = output[0].size()

        kernel_ops = (
            self.kernel_size[0]
            * (self.in_channels / self.groups)
            * (2 if multiply_adds else 1)
        )
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_length

        list_conv1d.append(flops)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1

        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        # guard against bias-free Linear layers (the original crashed on bias=None)
        bias_ops = self.bias.nelement() if self.bias is not None else 0

        flops = batch_size * (weight_ops + bias_ops)
        list_linear.append(flops)

    list_bn = []

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())

    list_relu = []

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling2d = []

    def pooling2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size * self.kernel_size
        bias_ops = 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width

        list_pooling2d.append(flops)

    list_pooling1d = []

    def pooling1d_hook(self, input, output):
        batch_size, input_channels, input_length = input[0].size()
        output_channels, output_length = output[0].size()

        kernel_ops = self.kernel_size
        bias_ops = 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_length

        list_pooling1d.append(flops)  # was list_pooling2d; 1-d pooling flops belong here

    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, nn.Conv2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.ConvTranspose2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.Conv1d):
                net.register_forward_hook(conv1d_hook)
            elif isinstance(net, nn.Linear):
                net.register_forward_hook(linear_hook)
            elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d):
                net.register_forward_hook(bn_hook)
            elif isinstance(net, nn.ReLU):
                net.register_forward_hook(relu_hook)
            elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d):
                net.register_forward_hook(pooling2d_hook)
            elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d):
                net.register_forward_hook(pooling1d_hook)
            else:
                print("Warning: flop of module {} is not counted!".format(net))
            return
        for c in childrens:
            foo(c)

    foo(model)

    input = torch.rand(1, audio_length, 2)
    out = model(input)

    total_flops = (
        sum(list_conv2d)
        + sum(list_conv1d)
        + sum(list_linear)
        + sum(list_bn)
        + sum(list_relu)
        + sum(list_pooling2d)
        + sum(list_pooling1d)
    )

    return total_flops

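to_log and from_log form a log-magnitude compression pair: magnitudes are floored at 1e-8 before the log10, and the exponent is capped at 5 on the way back. A quick round-trip sketch:

import torch
from voicefixer.tools.pytorch_util import to_log, from_log

mag = torch.rand(2, 1, 100, 1025)  # non-negative magnitude spectrogram
restored = from_log(to_log(mag))   # exact for values above the 1e-8 floor

print(torch.max(torch.abs(mag - restored)))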
    	
        voicefixer/tools/random_.py
    ADDED
    
    | @@ -0,0 +1,52 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
import random
import torch

RANDOM_RESOLUTION = 2**31


def random_torch(high, to_int=True):
    """Draw a random number in [0, high) using torch's RNG."""
    if to_int:
        return int((torch.rand(1)) * high)  # do not use numpy.random.random
    else:
        return (torch.rand(1)) * high  # do not use numpy.random.random


def shuffle_torch(list):  # note: the argument shadows the built-in list
    """Return a shuffled copy of the list, ordered by torch.randperm."""
    length = len(list)
    res = []
    order = torch.randperm(length)
    for each in order:
        res.append(list[each])
    assert len(list) == len(res)
    return res


def random_choose_list(list):
    """Pick one element of the list uniformly at random."""
    num = int(uniform_torch(0, len(list)))
    return list[num]


def normal_torch(mean=0, segma=1):
    # "segma" (sic) is the standard deviation; the name is kept as-is for
    # compatibility with existing keyword callers.
    return float(torch.normal(mean=mean, std=torch.Tensor([segma]))[0])


def uniform_torch(lower, upper):
    """Draw uniformly from [lower, upper); returns upper if the range is tiny."""
    if abs(lower - upper) < 1e-5:
        return upper
    return (upper - lower) * torch.rand(1) + lower


def random_key(keys: list, weights: list):
    """Weighted random choice of a single key."""
    return random.choices(keys, weights=weights)[0]


def random_select(probs):
    # Draw once, then test that same draw against each probability threshold.
    # The results are nested: a draw that passes a low threshold also passes
    # every higher one.
    res = []
    chance = random_torch(RANDOM_RESOLUTION)
    for prob in probs:
        res.append(chance < prob * RANDOM_RESOLUTION)
    return res, chance
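A short usage sketch for these helpers (values are illustrative; the torch-backed draws stay reproducible under torch.manual_seed):

import torch
from voicefixer.tools.random_ import (
    random_torch, shuffle_torch, uniform_torch, random_key, random_select,
)

torch.manual_seed(0)
print(random_torch(10))                         # int in [0, 10)
print(uniform_torch(0.5, 2.0))                  # tensor value in [0.5, 2.0)
print(shuffle_torch([1, 2, 3, 4]))              # permuted copy of the list
print(random_key(["low", "high"], [0.9, 0.1]))  # weighted choice
flags, chance = random_select([0.5, 0.9])       # one draw, two thresholds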
    	
        voicefixer/tools/wav.py
    ADDED
    
    | @@ -0,0 +1,242 @@ | |
import wave
import os
import numpy as np
import scipy.signal as signal
import soundfile as sf
import librosa


def save_wave(frames: np.ndarray, fname, sample_rate=44100):
    shape = list(frames.shape)
    if len(shape) == 1:
        frames = frames[..., None]
        shape = list(frames.shape)  # refresh the shape after adding a channel axis
    in_samples, in_channels = shape[-2], shape[-1]
    if in_channels >= 3:
        # Assume channel-first layout and permute to [..., samples, channels].
        if len(shape) == 2:
            frames = np.transpose(frames, (1, 0))
        elif len(shape) == 3:
            frames = np.transpose(frames, (0, 2, 1))
        msg = (
            "Warning: Save audio with "
            + str(in_channels)
            + " channels, save permute audio with shape "
            + str(list(frames.shape))
            + " please check if it's correct."
        )
        # print(msg)
    # Scale float audio in [-1, 1] up to the 16-bit integer range. (The original
    # condition mixed "and"/"or" without parentheses, so float16/float64 input
    # was always scaled regardless of amplitude.)
    if np.max(frames) <= 1 and frames.dtype in (np.float16, np.float32, np.float64):
        frames *= 2**15
    frames = frames.astype(np.short)
    if len(frames.shape) >= 3:
        frames = frames[0, ...]
    sf.write(fname, frames, samplerate=sample_rate)


def constrain_length(chunk, length):
    frames_length = chunk.shape[0]
    if frames_length == length:
        return chunk
    elif frames_length < length:
        return np.pad(chunk, ((0, int(length - frames_length)), (0, 0)), "constant")
    else:
        return chunk[:length, ...]


def random_chunk_wav_file(fname, chunk_length):
    """
    fname: path to wav file
    chunk_length: chunk length in seconds
    """
    with wave.open(fname) as f:
        params = f.getparams()
        duration = params[3] / params[2]
        sample_rate = params[2]
        sample_length = params[3]
        if duration < chunk_length or abs(duration - chunk_length) < 1e-4:
            frames = read_wave(fname, sample_rate)
            return frames, duration, sample_rate  # [-1,1]
        else:
            # Random chunk: pick a start sample, then convert the window into
            # fractions of the file duration for read_wave.
            random_starts = np.random.randint(
                0, sample_length - sample_rate * chunk_length
            )
            random_end = random_starts + sample_rate * chunk_length
            random_starts, random_end = (
                random_starts / sample_rate,
                random_end / sample_rate,
            )
            random_starts, random_end = random_starts / duration, random_end / duration
            frames = read_wave(
                fname, sample_rate, portion_start=random_starts, portion_end=random_end
            )
            frames = constrain_length(frames, length=int(chunk_length * sample_rate))
            return frames, chunk_length, sample_rate


def random_chunk_wav_file_v2(fname, chunk_length, random_starts=None, random_end=None):
    """
    Like random_chunk_wav_file, but the chunk window can be passed in (and is
    returned), so that paired files can be cropped identically.
    fname: path to wav file
    chunk_length: chunk length in seconds
    """
    with wave.open(fname) as f:
        params = f.getparams()
        duration = params[3] / params[2]
        sample_rate = params[2]
        sample_length = params[3]
        if duration < chunk_length or abs(duration - chunk_length) < 1e-4:
            frames = read_wave(fname, sample_rate)
            return frames, duration, sample_rate  # [-1,1]
        else:
            # Random chunk
            if random_starts is None and random_end is None:
                random_starts = np.random.randint(
                    0, sample_length - sample_rate * chunk_length
                )
                random_end = random_starts + sample_rate * chunk_length
                random_starts, random_end = (
                    random_starts / sample_rate,
                    random_end / sample_rate,
                )
                random_starts, random_end = (
                    random_starts / duration,
                    random_end / duration,
                )
            frames = read_wave(
                fname, sample_rate, portion_start=random_starts, portion_end=random_end
            )
            frames = constrain_length(frames, length=int(chunk_length * sample_rate))
            return frames, chunk_length, sample_rate, random_starts, random_end


def read_wave(
    fname,
    sample_rate,
    portion_start=0,
    portion_end=1,
):
    """
    :param fname: wav file path
    :param sample_rate: target sample rate (librosa resamples on load)
    :param portion_start: start of the window, as a fraction of the file duration
    :param portion_end: end of the window, as a fraction of the file duration
    :return: [samples, channels]
    """
    # sr = get_sample_rate(fname)
    # if(sr != sample_rate):
    #     print("Warning: Sample rate not match, may lead to unexpected behavior.")
    if portion_end > 1 and portion_end < 1.1:
        portion_end = 1
    if portion_end != 1:
        duration = get_duration(fname)
        wav, _ = librosa.load(
            fname,
            sr=sample_rate,
            offset=portion_start * duration,
            duration=(portion_end - portion_start) * duration,
            mono=False,
        )
    else:
        wav, _ = librosa.load(fname, sr=sample_rate, mono=False)
    if len(list(wav.shape)) == 1:
        wav = wav[..., None]
    else:
        wav = np.transpose(wav, (1, 0))
    return wav


def get_channels_sampwidth_and_sample_rate(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return (
        params[0],
        params[1],
        params[2],
    )  # e.g. (2, 2, 44100)


def get_channels(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return params[0]


def get_sample_rate(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return params[2]


def get_duration(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return params[3] / params[2]


def get_framesLength(fname):
    with wave.open(fname) as f:
        params = f.getparams()
    return params[3]


def restore_wave(zxx):
    _, w = signal.istft(zxx)
    return w


def calculate_total_times(dir):  # note: the argument shadows the built-in dir
    total = 0
    for each in os.listdir(dir):
        fname = os.path.join(dir, each)
        try:
            duration = get_duration(fname)
        except Exception:
            # Skip unreadable files instead of reusing a stale duration.
            print(fname)
            continue
        total += duration
    return total


dic = {}  # sample-rate -> file list; referenced via "global dic" below but never defined in the original


def filter(pth):  # note: shadows the built-in filter()
    global dic
    temp = []
    for each in os.listdir(pth):
        temp.append(os.path.join(pth, each))
    for each in temp:
        sr = get_sample_rate(each)
        if sr not in dic.keys():
            dic[sr] = []
        dic[sr].append(each)
    for each in dic[16000]:
        # print(each)
        pass
    print(dic.keys())
    for each in list(dic.keys()):
        print(each, len(dic[each]))


if __name__ == "__main__":
    path = "/Users/admin/Desktop/p376_025.wav"
    stereo = "/Users/admin/Desktop/vocals.wav"
    path_16 = "/Users/admin/Desktop/SI869.WAV.wav"
    import time

    start = time.time()
    for i in range(1000):
        frames, duration, sample_rate = random_chunk_wav_file(stereo, chunk_length=3.0)
        print(frames.shape, np.max(frames))
        save_wave(frames, "stereo.wav", sample_rate=44100)
        frames, duration, sample_rate = random_chunk_wav_file(path, chunk_length=3.0)
        print(frames.shape, np.max(frames))
        save_wave(frames, "mono.wav", sample_rate=44100)
        frames, duration, sample_rate = random_chunk_wav_file(path_16, chunk_length=3.0)
        print(frames.shape, np.max(frames))
        save_wave(frames, "16.wav", sample_rate=16000)
    print(time.time() - start)
    # frames = read_wave(stereo,sample_rate=44100)
    print(frames.shape)

    print(frames)
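A minimal read/save round trip with the helpers above (the input path is a placeholder; any mono or stereo wav file works):

from voicefixer.tools.wav import read_wave, save_wave, get_duration

in_path = "input.wav"  # hypothetical file
wav = read_wave(in_path, sample_rate=44100)    # [samples, channels], floats in [-1, 1]
print(wav.shape, get_duration(in_path))
save_wave(wav, "copy.wav", sample_rate=44100)  # rescaled to 16-bit PCM on write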
    	
        voicefixer/vocoder/__init__.py
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            # -*- encoding: utf-8 -*-
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
            @File    :   __init__.py.py    
         | 
| 5 | 
            +
            @Contact :   [email protected]
         | 
| 6 | 
            +
            @License :   (C)Copyright 2020-2100
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            @Modify Time      @Author    @Version    @Desciption
         | 
| 9 | 
            +
            ------------      -------    --------    -----------
         | 
| 10 | 
            +
            9/14/21 1:00 AM   Haohe Liu      1.0         None
         | 
| 11 | 
            +
            """
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import os
         | 
| 14 | 
            +
            from voicefixer.vocoder.config import Config
         | 
| 15 | 
            +
            import urllib.request
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            if not os.path.exists(Config.ckpt):
         | 
| 18 | 
            +
                os.makedirs(os.path.dirname(Config.ckpt), exist_ok=True)
         | 
| 19 | 
            +
                print("Downloading the weight of neural vocoder: TFGAN")
         | 
| 20 | 
            +
                urllib.request.urlretrieve(
         | 
| 21 | 
            +
                    "https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1",
         | 
| 22 | 
            +
                    Config.ckpt,
         | 
| 23 | 
            +
                )
         | 
| 24 | 
            +
                print(
         | 
| 25 | 
            +
                    "Weights downloaded in: {} Size: {}".format(
         | 
| 26 | 
            +
                        Config.ckpt, os.path.getsize(Config.ckpt)
         | 
| 27 | 
            +
                    )
         | 
| 28 | 
            +
                )
         | 
| 29 | 
            +
                # cmd = "wget https://zenodo.org/record/5469951/files/model.ckpt-1490000_trimed.pt?download=1 -O " + Config.ckpt
         | 
| 30 | 
            +
                # os.system(cmd)
         | 
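Because this download runs at import time and only checks that the file exists, an interrupted download is silently reused on later imports. A sketch for clearing a suspect checkpoint so that a fresh process re-downloads it (the 1 MB threshold is an arbitrary assumption, not part of the library):

import os
from voicefixer.vocoder.config import Config

if os.path.exists(Config.ckpt) and os.path.getsize(Config.ckpt) < 1024 * 1024:
    os.remove(Config.ckpt)  # suspiciously small: likely an interrupted download
# re-importing voicefixer.vocoder in a new process triggers the fetch again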
    	
        voicefixer/vocoder/base.py
    ADDED
    
    | @@ -0,0 +1,86 @@ | |
|  | |
| 1 | 
            +
            from voicefixer.vocoder.model.generator import Generator
         | 
| 2 | 
            +
            from voicefixer.tools.wav import read_wave, save_wave
         | 
| 3 | 
            +
            from voicefixer.tools.pytorch_util import *
         | 
| 4 | 
            +
            from voicefixer.vocoder.model.util import *
         | 
| 5 | 
            +
            from voicefixer.vocoder.config import Config
         | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class Vocoder(nn.Module):
         | 
| 11 | 
            +
                def __init__(self, sample_rate):
         | 
| 12 | 
            +
                    super(Vocoder, self).__init__()
         | 
| 13 | 
            +
                    Config.refresh(sample_rate)
         | 
| 14 | 
            +
                    self.rate = sample_rate
         | 
| 15 | 
            +
                    if(not os.path.exists(Config.ckpt)):
         | 
| 16 | 
            +
                        raise RuntimeError("Error 1: The checkpoint for synthesis module / vocoder (model.ckpt-1490000_trimed) is not found in ~/.cache/voicefixer/synthesis_module/44100. \
         | 
| 17 | 
            +
                                            By default the checkpoint should be download automatically by this program. Something bad may happened. Apologies for the inconvenience.\
         | 
| 18 | 
            +
                                            But don't worry! Alternatively you can download it directly from Zenodo: https://zenodo.org/record/5600188/files/model.ckpt-1490000_trimed.pt?download=1")
         | 
| 19 | 
            +
                    self._load_pretrain(Config.ckpt)
         | 
| 20 | 
            +
                    self.weight_torch = Config.get_mel_weight_torch(percent=1.0)[
         | 
| 21 | 
            +
                        None, None, None, ...
         | 
| 22 | 
            +
                    ]
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def _load_pretrain(self, pth):
         | 
| 25 | 
            +
                    self.model = Generator(Config.cin_channels)
         | 
| 26 | 
            +
                    checkpoint = load_checkpoint(pth, torch.device("cpu"))
         | 
| 27 | 
            +
                    load_try(checkpoint["generator"], self.model)
         | 
| 28 | 
            +
                    self.model.eval()
         | 
| 29 | 
            +
                    self.model.remove_weight_norm()
         | 
| 30 | 
            +
                    self.model.remove_weight_norm()
         | 
| 31 | 
            +
                    for p in self.model.parameters():
         | 
| 32 | 
            +
                        p.requires_grad = False
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                # def vocoder_mel_npy(self, mel, save_dir, sample_rate, gain):
         | 
| 35 | 
            +
                #     mel = mel / Config.get_mel_weight(percent=gain)[...,None]
         | 
| 36 | 
            +
                #     mel = normalize(amp_to_db(np.abs(mel)) - 20)
         | 
| 37 | 
            +
                #     mel = pre(np.transpose(mel, (1, 0)))
         | 
| 38 | 
            +
                #     with torch.no_grad():
         | 
| 39 | 
            +
                #         wav_re = self.model(mel) # torch.Size([1, 1, 104076])
         | 
| 40 | 
            +
                #         save_wave(tensor2numpy(wav_re)*2**15,save_dir,sample_rate=sample_rate)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def forward(self, mel, cuda=False):
         | 
| 43 | 
            +
                    """
         | 
| 44 | 
            +
                    :param non normalized mel spectrogram: [batchsize, 1, t-steps, n_mel]
         | 
| 45 | 
            +
                    :return: [batchsize, 1, samples]
         | 
| 46 | 
            +
                    """
         | 
| 47 | 
            +
                    assert mel.size()[-1] == 128
         | 
| 48 | 
            +
                    check_cuda_availability(cuda=cuda)
         | 
| 49 | 
            +
                    self.model = try_tensor_cuda(self.model, cuda=cuda)
         | 
| 50 | 
            +
                    mel = try_tensor_cuda(mel, cuda=cuda)
         | 
| 51 | 
            +
                    self.weight_torch = self.weight_torch.type_as(mel)
         | 
| 52 | 
            +
                    mel = mel / self.weight_torch
         | 
| 53 | 
            +
                    mel = tr_normalize(tr_amp_to_db(torch.abs(mel)) - 20.0)
         | 
| 54 | 
            +
                    mel = tr_pre(mel[:, 0, ...])
         | 
| 55 | 
            +
                    wav_re = self.model(mel)
         | 
| 56 | 
            +
                    return wav_re
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def oracle(self, fpath, out_path, cuda=False):
         | 
| 59 | 
            +
                    check_cuda_availability(cuda=cuda)
         | 
| 60 | 
            +
                    self.model = try_tensor_cuda(self.model, cuda=cuda)
         | 
| 61 | 
            +
                    wav = read_wave(fpath, sample_rate=self.rate)[..., 0]
         | 
| 62 | 
            +
                    wav = wav / np.max(np.abs(wav))
         | 
| 63 | 
            +
                    stft = np.abs(
         | 
| 64 | 
            +
                        librosa.stft(
         | 
| 65 | 
            +
                            wav,
         | 
| 66 | 
            +
                            hop_length=Config.hop_length,
         | 
| 67 | 
            +
                            win_length=Config.win_size,
         | 
| 68 | 
            +
                            n_fft=Config.n_fft,
         | 
| 69 | 
            +
                        )
         | 
| 70 | 
            +
                    )
         | 
| 71 | 
            +
                    mel = linear_to_mel(stft)
         | 
| 72 | 
            +
                    mel = normalize(amp_to_db(np.abs(mel)) - 20)
         | 
| 73 | 
            +
                    mel = pre(np.transpose(mel, (1, 0)))
         | 
| 74 | 
            +
                    mel = try_tensor_cuda(mel, cuda=cuda)
         | 
| 75 | 
            +
                    with torch.no_grad():
         | 
| 76 | 
            +
                        wav_re = self.model(mel)
         | 
| 77 | 
            +
                        save_wave(tensor2numpy(wav_re * 2**15), out_path, sample_rate=self.rate)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            if __name__ == "__main__":
         | 
| 81 | 
            +
                model = Vocoder(sample_rate=44100)
         | 
| 82 | 
            +
                print(model.device)
         | 
| 83 | 
            +
                # model.load_pretrain(Config.ckpt)
         | 
| 84 | 
            +
                # model.oracle(path="/Users/liuhaohe/Desktop/test.wav",
         | 
| 85 | 
            +
                #         sample_rate=44100,
         | 
| 86 | 
            +
                #         save_dir="/Users/liuhaohe/Desktop/test_vocoder.wav")
         | 
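A minimal end-to-end sketch for this class: re-synthesize a 44.1 kHz file through the oracle() analysis-synthesis path (the file paths are placeholders):

from voicefixer.vocoder.base import Vocoder

vocoder = Vocoder(sample_rate=44100)  # loads the cached TFGAN checkpoint
vocoder.oracle("input.wav", "resynth.wav", cuda=False)
# forward() can also be fed a non-normalized mel spectrogram directly:
# [batch, 1, t_steps, 128] in, [batch, 1, samples] out.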
    	
        voicefixer/vocoder/config.py
    ADDED
    
    | @@ -0,0 +1,316 @@ | |
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            from voicefixer.tools.path import root_path
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class Config:
         | 
| 8 | 
            +
                @classmethod
         | 
| 9 | 
            +
                def refresh(cls, sr):
         | 
| 10 | 
            +
                    if sr == 44100:
         | 
| 11 | 
            +
                        Config.ckpt = os.path.join(
         | 
| 12 | 
            +
                            os.path.expanduser("~"),
         | 
| 13 | 
            +
                            ".cache/voicefixer/synthesis_module/44100/model.ckpt-1490000_trimed.pt",
         | 
| 14 | 
            +
                        )
         | 
| 15 | 
            +
                        Config.cond_channels = 512
         | 
| 16 | 
            +
                        Config.m_channels = 768
         | 
| 17 | 
            +
                        Config.resstack_depth = [8, 8, 8, 8]
         | 
| 18 | 
            +
                        Config.channels = 1024
         | 
| 19 | 
            +
                        Config.cin_channels = 128
         | 
| 20 | 
            +
                        Config.upsample_scales = [7, 7, 3, 3]
         | 
| 21 | 
            +
                        Config.num_mels = 128
         | 
| 22 | 
            +
                        Config.n_fft = 2048
         | 
| 23 | 
            +
                        Config.hop_length = 441
         | 
| 24 | 
            +
                        Config.sample_rate = 44100
         | 
| 25 | 
            +
                        Config.fmax = 22000
         | 
| 26 | 
            +
                        Config.mel_win = 128
         | 
| 27 | 
            +
                        Config.local_condition_dim = 128
         | 
| 28 | 
            +
                    else:
         | 
| 29 | 
            +
                        raise RuntimeError(
         | 
| 30 | 
            +
                            "Error: Vocoder currently only support 44100 samplerate."
         | 
| 31 | 
            +
                        )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                ckpt = os.path.join(
         | 
| 34 | 
            +
                    os.path.expanduser("~"),
         | 
| 35 | 
            +
                    ".cache/voicefixer/synthesis_module/44100/model.ckpt-1490000_trimed.pt",
         | 
| 36 | 
            +
                )
         | 
| 37 | 
            +
                m_channels = 384
         | 
| 38 | 
            +
                bits = 10
         | 
| 39 | 
            +
                opt = "Ralamb"
         | 
| 40 | 
            +
                cond_channels = 256
         | 
| 41 | 
            +
                clip = 0.5
         | 
| 42 | 
            +
                num_bands = 1
         | 
| 43 | 
            +
                cin_channels = 128
         | 
| 44 | 
            +
                upsample_scales = [7, 7, 3, 3]
         | 
| 45 | 
            +
                filterbands = "test/filterbanks_4bands.dat"
         | 
| 46 | 
            +
                ##For inference
         | 
| 47 | 
            +
                tag = ""
         | 
| 48 | 
            +
                min_db = -115
         | 
| 49 | 
            +
                num_mels = 128
         | 
| 50 | 
            +
                n_fft = 2048
         | 
| 51 | 
            +
                hop_length = 441
         | 
| 52 | 
            +
                win_size = None
         | 
| 53 | 
            +
                sample_rate = 44100
         | 
| 54 | 
            +
                frame_shift_ms = None
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                trim_fft_size = 512
         | 
| 57 | 
            +
                trim_hop_size = 128
         | 
| 58 | 
            +
                trim_top_db = 23
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                signal_normalization = True
         | 
| 61 | 
            +
                allow_clipping_in_normalization = True
         | 
| 62 | 
            +
                symmetric_mels = True
         | 
| 63 | 
            +
                max_abs_value = 4.0
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                preemphasis = 0.85
         | 
| 66 | 
            +
                min_level_db = -100
         | 
| 67 | 
            +
                ref_level_db = 20
         | 
| 68 | 
            +
                fmin = 50
         | 
| 69 | 
            +
                fmax = 22000
         | 
| 70 | 
            +
                power = 1.5
         | 
| 71 | 
            +
                griffin_lim_iters = 60
         | 
| 72 | 
            +
                rescale = False
         | 
| 73 | 
            +
                rescaling_max = 0.95
         | 
| 74 | 
            +
                trim_silence = False
         | 
| 75 | 
            +
                clip_mels_length = True
         | 
| 76 | 
            +
                max_mel_frames = 2000
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                mel_win = 128
         | 
| 79 | 
            +
                batch_size = 24
         | 
| 80 | 
            +
                g_learning_rate = 0.001
         | 
| 81 | 
            +
                d_learning_rate = 0.001
         | 
| 82 | 
            +
                warmup_steps = 100000
         | 
| 83 | 
            +
                decay_learning_rate = 0.5
         | 
| 84 | 
            +
                exponential_moving_average = True
         | 
| 85 | 
            +
                ema_decay = 0.99
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                reset_opt = False
         | 
| 88 | 
            +
                reset_g_opt = False
         | 
| 89 | 
            +
                reset_d_opt = False
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                local_condition_dim = 128
         | 
| 92 | 
            +
                lambda_update_G = 1
         | 
| 93 | 
            +
                multiscale_D = 3
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                lambda_adv = 4.0
         | 
| 96 | 
            +
                lambda_fm_loss = 0.0
         | 
| 97 | 
            +
                lambda_sc_loss = 5.0
         | 
| 98 | 
            +
                lambda_mag_loss = 5.0
         | 
| 99 | 
            +
                lambda_mel_loss = 50.0
         | 
| 100 | 
            +
                use_mle_loss = False
         | 
| 101 | 
            +
                lambda_mle_loss = 5.0
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                lambda_freq_loss = 2.0
         | 
| 104 | 
            +
                lambda_energy_loss = 100.0
         | 
| 105 | 
            +
                lambda_t_loss = 200.0
         | 
| 106 | 
            +
                lambda_phase_loss = 100.0
         | 
| 107 | 
            +
                lambda_f0_loss = 1.0
         | 
| 108 | 
            +
                use_elu = False
         | 
| 109 | 
            +
                de_preem = False  # train
         | 
| 110 | 
            +
                up_org = False
         | 
| 111 | 
            +
                use_one = True
         | 
| 112 | 
            +
                use_small_D = False
         | 
| 113 | 
            +
                use_condnet = True
         | 
| 114 | 
            +
                use_depreem = False  # inference
         | 
| 115 | 
            +
                use_msd = False
         | 
| 116 | 
            +
                model_type = "tfgan"  # or bytewave, frame level vocoder using istft
         | 
| 117 | 
            +
                use_hjcud = False
         | 
| 118 | 
            +
                no_skip = False
         | 
| 119 | 
            +
                out_channels = 1
         | 
| 120 | 
            +
                use_postnet = False  # wn in postnet
         | 
| 121 | 
            +
                use_wn = False  # wn in resstack
         | 
| 122 | 
            +
                up_type = "transpose"
         | 
| 123 | 
            +
                use_smooth = False
         | 
| 124 | 
            +
                use_drop = False
         | 
| 125 | 
            +
                use_shift_scale = False
         | 
| 126 | 
            +
                use_gcnn = False
         | 
| 127 | 
            +
                resstack_depth = [6, 6, 6, 6]
         | 
| 128 | 
            +
                kernel_size = [3, 3, 3, 3]
         | 
| 129 | 
            +
                channels = 512
         | 
| 130 | 
            +
                use_f0_loss = False
         | 
| 131 | 
            +
                use_sine = False
         | 
| 132 | 
            +
                use_cond_rnn = False
         | 
| 133 | 
            +
                use_rnn = False
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                f0_step = 120
         | 
| 136 | 
            +
                use_lowfreq_loss = False
         | 
| 137 | 
            +
                lambda_lowfreq_loss = 1.0
         | 
| 138 | 
            +
                use_film = False
         | 
| 139 | 
            +
                use_mb_mr_gan = False
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                use_mssl = False
         | 
| 142 | 
            +
                use_ml_gan = False
         | 
| 143 | 
            +
                use_mb_gan = True
         | 
| 144 | 
            +
                use_mpd = False
         | 
| 145 | 
            +
                use_spec_gan = True
         | 
| 146 | 
            +
                use_rwd = False
         | 
| 147 | 
            +
                use_mr_gan = True
         | 
| 148 | 
            +
                use_pqmf_rwd = False
         | 
| 149 | 
            +
                no_sine = False
         | 
| 150 | 
            +
                use_frame_mask = False
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                lambda_var_loss = 0.0
         | 
| 153 | 
            +
                discriminator_train_start_steps = 40000  # 80k
         | 
| 154 | 
            +
                aux_d_train_start_steps = 40000  # 100k
         | 
| 155 | 
            +
                rescale_out = 0.40
         | 
| 156 | 
            +
                use_dist = True
         | 
| 157 | 
            +
                dist_backend = "nccl"
         | 
| 158 | 
            +
                dist_url = "tcp://localhost:12345"
         | 
| 159 | 
            +
                world_size = 1
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                mel_weight_torch = torch.tensor(
         | 
| 162 | 
            +
                    [
         | 
| 163 | 
            +
                        19.40951426,
         | 
| 164 | 
            +
                        19.94047336,
         | 
| 165 | 
            +
                        20.4859038,
         | 
| 166 | 
            +
                        21.04629067,
         | 
| 167 | 
            +
                        21.62194148,
         | 
| 168 | 
            +
                        22.21335214,
         | 
| 169 | 
            +
                        22.8210215,
         | 
| 170 | 
            +
                        23.44529231,
         | 
| 171 | 
            +
                        24.08660962,
         | 
| 172 | 
            +
                        24.74541882,
         | 
| 173 | 
            +
                        25.42234287,
         | 
| 174 | 
            +
                        26.11770576,
         | 
| 175 | 
            +
                        26.83212784,
         | 
| 176 | 
            +
                        27.56615283,
         | 
| 177 | 
            +
                        28.32007747,
         | 
| 178 | 
            +
                        29.0947679,
         | 
| 179 | 
            +
                        29.89060111,
         | 
| 180 | 
            +
                        30.70832636,
         | 
| 181 | 
            +
                        31.54828121,
         | 
| 182 | 
            +
                        32.41121487,
         | 
| 183 | 
            +
                        33.29780773,
         | 
| 184 | 
            +
                        34.20865341,
         | 
| 185 | 
            +
                        35.14437675,
         | 
| 186 | 
            +
                        36.1056621,
         | 
| 187 | 
            +
                        37.09332763,
         | 
| 188 | 
            +
                        38.10795802,
         | 
| 189 | 
            +
                        39.15039691,
         | 
| 190 | 
            +
                        40.22119881,
         | 
| 191 | 
            +
                        41.32154931,
         | 
| 192 | 
            +
                        42.45172373,
         | 
| 193 | 
            +
                        43.61293329,
         | 
| 194 | 
            +
                        44.80609379,
         | 
| 195 | 
            +
                        46.031602,
         | 
| 196 | 
            +
                        47.29070223,
         | 
| 197 | 
            +
                        48.58427549,
         | 
| 198 | 
            +
                        49.91327905,
         | 
| 199 | 
            +
                        51.27863232,
         | 
| 200 | 
            +
                        52.68119708,
         | 
| 201 | 
            +
                        54.1222372,
         | 
| 202 | 
            +
                        55.60274206,
         | 
| 203 | 
            +
                        57.12364703,
         | 
| 204 | 
            +
                        58.68617876,
         | 
| 205 | 
            +
                        60.29148652,
         | 
| 206 | 
            +
                        61.94081306,
         | 
| 207 | 
            +
                        63.63501986,
         | 
| 208 | 
            +
                        65.37562658,
         | 
| 209 | 
            +
                        67.16408954,
         | 
| 210 | 
            +
                        69.00109084,
         | 
| 211 | 
            +
                        70.88850318,
         | 
| 212 | 
            +
                        72.82736101,
         | 
| 213 | 
            +
                        74.81985537,
         | 
| 214 | 
            +
                        76.86654792,
         | 
| 215 | 
            +
                        78.96885475,
         | 
| 216 | 
            +
                        81.12900906,
         | 
| 217 | 
            +
                        83.34840929,
         | 
| 218 | 
            +
                        85.62810662,
         | 
| 219 | 
            +
                        87.97005418,
         | 
| 220 | 
            +
                        90.37689804,
         | 
| 221 | 
            +
                        92.84887686,
         | 
| 222 | 
            +
                        95.38872881,
         | 
| 223 | 
            +
                        97.99777002,
         | 
| 224 | 
            +
                        100.67862715,
         | 
| 225 | 
            +
                        103.43232942,
         | 
| 226 | 
            +
                        106.26140638,
         | 
| 227 | 
            +
                        109.16827015,
         | 
| 228 | 
            +
                        112.15470471,
         | 
| 229 | 
            +
                        115.22184756,
         | 
| 230 | 
            +
                        118.37439245,
         | 
| 231 | 
            +
                        121.6122689,
         | 
| 232 | 
            +
                        124.93877158,
         | 
| 233 | 
            +
                        128.35661454,
         | 
| 234 | 
            +
                        131.86761321,
         | 
| 235 | 
            +
                        135.47417938,
         | 
| 236 | 
            +
                        139.18059494,
         | 
| 237 | 
            +
                        142.98713744,
         | 
| 238 | 
            +
                        146.89771854,
         | 
| 239 | 
            +
                        150.91684347,
         | 
| 240 | 
            +
                        155.0446638,
         | 
| 241 | 
            +
                        159.28614648,
         | 
| 242 | 
            +
                        163.64270198,
         | 
| 243 | 
            +
                        168.12035831,
         | 
| 244 | 
            +
                        172.71749158,
         | 
| 245 | 
            +
                        177.44220154,
         | 
| 246 | 
            +
                        182.29556933,
         | 
| 247 | 
            +
                        187.28286676,
         | 
| 248 | 
            +
                        192.40502126,
         | 
| 249 | 
            +
                        197.6682721,
         | 
| 250 | 
            +
                        203.07516896,
         | 
| 251 | 
            +
                        208.63088733,
         | 
| 252 | 
            +
                        214.33770931,
         | 
| 253 | 
            +
                        220.19910108,
         | 
| 254 | 
            +
                        226.22363072,
         | 
| 255 | 
            +
                        232.41087124,
         | 
| 256 | 
            +
            238.76803591,
            245.30079083,
            252.01064464,
            258.90261676,
            265.98474,
            273.26010248,
            280.73496362,
            288.41440094,
            296.30489752,
            304.41180337,
            312.7377183,
            321.28877878,
            330.07870237,
            339.10812951,
            348.38276173,
            357.91393924,
            367.70513992,
            377.76413924,
            388.09467408,
            398.70920178,
            409.61813793,
            420.81980127,
            432.33215467,
            444.16083117,
            456.30919947,
            468.78589276,
            481.61325588,
            494.78824596,
            508.31969844,
            522.2238331,
            536.51163441,
            551.18859414,
            566.26142988,
            581.75006061,
            597.66210737,
        ]
    )

    x_orig = np.linspace(1, mel_weight_torch.shape[0], num=mel_weight_torch.shape[0])

    x_orig_torch = torch.linspace(
        1, mel_weight_torch.shape[0], steps=mel_weight_torch.shape[0]
    )

    @classmethod
    def get_mel_weight(cls, percent=1, a=18.8927416350036, b=0.0269863588184314):
        b = percent * b

        def func(a, b, x):
            return a * np.exp(b * x)

        return func(a, b, Config.x_orig)

    @classmethod
    def get_mel_weight_torch(cls, percent=1, a=18.8927416350036, b=0.0269863588184314):
        b = percent * b

        def func(a, b, x):
            return a * torch.exp(b * x)

        return func(a, b, Config.x_orig_torch)
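Both helpers trace the same exponential curve a * exp(b * x) over the mel-bin indices, on NumPy and torch tensors respectively; percent simply scales the exponent b. A minimal sketch of calling the torch variant (the dummy spectrogram below is an illustrative assumption, not part of the upload):

    import torch
    from voicefixer.vocoder.config import Config

    weights = Config.get_mel_weight_torch(percent=1.0)  # one weight per mel bin
    mel = torch.rand(weights.shape[0], 100)             # dummy (n_mel, T) spectrogram
    weighted = mel * weights[:, None]                   # bin-wise re-weighting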
    	
        voicefixer/vocoder/model/__init__.py
    ADDED
    
@@ -0,0 +1,11 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    :   __init__.py
@Contact :   [email protected]
@License :   (C)Copyright 2020-2100

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
9/14/21 1:00 AM   Haohe Liu      1.0         None
"""
    	
        voicefixer/vocoder/model/generator.py
    ADDED
    
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
import numpy as np
from voicefixer.vocoder.model.modules import UpsampleNet, ResStack
from voicefixer.vocoder.config import Config
from voicefixer.vocoder.model.pqmf import PQMF
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


class Generator(nn.Module):
    def __init__(
        self,
        in_channels=128,
        use_elu=False,
        use_gcnn=False,
        up_org=False,
        group=1,
        hp=None,
    ):
        super(Generator, self).__init__()
        self.hp = hp
        channels = Config.channels
        self.upsample_scales = Config.upsample_scales
        self.use_condnet = Config.use_condnet
        self.out_channels = Config.out_channels
        self.resstack_depth = Config.resstack_depth
        self.use_postnet = Config.use_postnet
        self.use_cond_rnn = Config.use_cond_rnn
        if self.use_condnet:
            cond_channels = Config.cond_channels
            self.condnet = nn.Sequential(
                nn.utils.weight_norm(
                    nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
                nn.utils.weight_norm(
                    nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
                nn.utils.weight_norm(
                    nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
                nn.utils.weight_norm(
                    nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
                nn.utils.weight_norm(
                    nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
                ),
                nn.ELU(),
            )
            in_channels = cond_channels
        if self.use_cond_rnn:
            self.rnn = nn.GRU(
                cond_channels,
                cond_channels // 2,
                num_layers=1,
                batch_first=True,
                bidirectional=True,
            )

        if use_elu:
            act = nn.ELU()
        else:
            act = nn.LeakyReLU(0.2, True)

        kernel_size = Config.kernel_size

        if self.out_channels == 1:
            self.generator = nn.Sequential(
                nn.ReflectionPad1d(3),
                nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)),
                act,
                UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp, 0),
                ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp),
                act,
                UpsampleNet(
                    channels // 2, channels // 4, self.upsample_scales[1], hp, 1
                ),
                ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp),
                act,
                UpsampleNet(
                    channels // 4, channels // 8, self.upsample_scales[2], hp, 2
                ),
                ResStack(channels // 8, kernel_size[2], self.resstack_depth[2], hp),
                act,
                UpsampleNet(
                    channels // 8, channels // 16, self.upsample_scales[3], hp, 3
                ),
                ResStack(channels // 16, kernel_size[3], self.resstack_depth[3], hp),
                act,
                nn.ReflectionPad1d(3),
                nn.utils.weight_norm(
                    nn.Conv1d(channels // 16, self.out_channels, kernel_size=7)
                ),
                nn.Tanh(),
            )
        else:
            channels = Config.m_channels
            self.generator = nn.Sequential(
                nn.ReflectionPad1d(3),
                nn.utils.weight_norm(nn.Conv1d(in_channels, channels, kernel_size=7)),
                act,
                UpsampleNet(channels, channels // 2, self.upsample_scales[0], hp),
                ResStack(channels // 2, kernel_size[0], self.resstack_depth[0], hp),
                act,
                UpsampleNet(channels // 2, channels // 4, self.upsample_scales[1], hp),
                ResStack(channels // 4, kernel_size[1], self.resstack_depth[1], hp),
                act,
                UpsampleNet(channels // 4, channels // 8, self.upsample_scales[3], hp),
                ResStack(channels // 8, kernel_size[3], self.resstack_depth[2], hp),
                act,
                nn.ReflectionPad1d(3),
                nn.utils.weight_norm(
                    nn.Conv1d(channels // 8, self.out_channels, kernel_size=7)
                ),
                nn.Tanh(),
            )
        if self.out_channels > 1:
            self.pqmf = PQMF(4, 64)

        self.num_params()

    def forward(self, conditions, use_res=False, f0=None):
        res = conditions
        if self.use_condnet:
            conditions = self.condnet(conditions)
        if self.use_cond_rnn:
            conditions, _ = self.rnn(conditions.transpose(1, 2))
            conditions = conditions.transpose(1, 2)

        wav = self.generator(conditions)
        if self.out_channels > 1:
            B = wav.size(0)
            f_wav = (
                self.pqmf.synthesis(wav)
                .transpose(1, 2)
                .reshape(B, 1, -1)
                .clamp(-0.99, 0.99)
            )
            return f_wav, wav
        return wav

    def num_params(self):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        return parameters
        # print('Trainable Parameters: %.3f million' % parameters)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)


if __name__ == "__main__":
    model = Generator(128)
    x = torch.randn(3, 128, 13)
    print(x.shape)
    y = model(x)
    print(y.shape)
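The __main__ block above doubles as a smoke test. A slightly fuller inference sketch (whether forward returns a tuple depends on Config.out_channels, so the unpacking below is a hedged assumption):

    import torch
    from voicefixer.vocoder.model.generator import Generator

    model = Generator(in_channels=128)
    model.remove_weight_norm()         # fold weight norm before inference
    model.eval()

    mel = torch.randn(1, 128, 13)      # (batch, mel bins, frames)
    with torch.no_grad():
        out = model(mel)
    # With Config.out_channels > 1, the PQMF path returns (full_band, sub_bands).
    wav = out[0] if isinstance(out, tuple) else out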
    	
        voicefixer/vocoder/model/modules.py
    ADDED
    
@@ -0,0 +1,947 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from voicefixer.vocoder.config import Config

# From Xin Wang of NII
class SineGen(torch.nn.Module):
    """Definition of the sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)

    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(
        self,
        samp_rate=24000,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse

    def _f02uv(self, f0):
        # generate the uv (voiced/unvoiced) signal
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def _f02sine(self, f0_values):
        """f0_values: (batchsize, length, dim)
        where dim indicates fundamental tone and overtones
        """
        # convert F0 to rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(
            f0_values.shape[0], f0_values.shape[2], device=f0_values.device
        )
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # the normal case

            # To prevent torch.cumsum numerical overflow,
            # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
            # Buffer tmp_over_one_idx indicates the time step to add -1.
            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
            tmp_over_one = torch.cumsum(rad_values, 1) % 1
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            sines = torch.sin(
                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
            )
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0).
            # This is used for pulse-train generation.

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # each item in the batch needs to be processed separately
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """

        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in np.arange(self.harmonic_num):
                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            # generate sine waveforms
            sine_waves = self._f02sine(f0_buf) * self.sine_amp

            # generate uv signal
            # uv = torch.ones(f0.shape)
            # uv = uv * (f0 > self.voiced_threshold)
            uv = self._f02uv(f0)

            # noise: for unvoiced parts it should be comparable to sine_amp
            #        (std = self.sine_amp/3 -> max value ~ self.sine_amp);
            #        for voiced regions it is self.noise_std
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)

            # first: set the unvoiced part to 0 by uv
            # then: additive noise
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise

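# Usage sketch (illustrative, not part of the original file): SineGen takes an
# F0 track at audio rate, with zeros marking unvoiced samples; harmonic_num=2
# yields the fundamental plus two overtones.
#
#   sine_gen = SineGen(samp_rate=24000, harmonic_num=2)
#   f0 = torch.full((1, 24000, 1), 220.0)    # constant 220 Hz, fully voiced
#   sine_waves, uv, noise = sine_gen(f0)     # (1, 24000, 3), (1, 24000, 1), (1, 24000, 3)
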
| 148 | 
            +
            class LowpassBlur(nn.Module):
         | 
| 149 | 
            +
                """perform low pass filter after upsampling for anti-aliasing"""
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def __init__(self, channels=128, filt_size=3, pad_type="reflect", pad_off=0):
         | 
| 152 | 
            +
                    super(LowpassBlur, self).__init__()
         | 
| 153 | 
            +
                    self.filt_size = filt_size
         | 
| 154 | 
            +
                    self.pad_off = pad_off
         | 
| 155 | 
            +
                    self.pad_sizes = [
         | 
| 156 | 
            +
                        int(1.0 * (filt_size - 1) / 2),
         | 
| 157 | 
            +
                        int(np.ceil(1.0 * (filt_size - 1) / 2)),
         | 
| 158 | 
            +
                    ]
         | 
| 159 | 
            +
                    self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
         | 
| 160 | 
            +
                    self.off = 0
         | 
| 161 | 
            +
                    self.channels = channels
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    if self.filt_size == 1:
         | 
| 164 | 
            +
                        a = np.array(
         | 
| 165 | 
            +
                            [
         | 
| 166 | 
            +
                                1.0,
         | 
| 167 | 
            +
                            ]
         | 
| 168 | 
            +
                        )
         | 
| 169 | 
            +
                    elif self.filt_size == 2:
         | 
| 170 | 
            +
                        a = np.array([1.0, 1.0])
         | 
| 171 | 
            +
                    elif self.filt_size == 3:
         | 
| 172 | 
            +
                        a = np.array([1.0, 2.0, 1.0])
         | 
| 173 | 
            +
                    elif self.filt_size == 4:
         | 
| 174 | 
            +
                        a = np.array([1.0, 3.0, 3.0, 1.0])
         | 
| 175 | 
            +
                    elif self.filt_size == 5:
         | 
| 176 | 
            +
                        a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
         | 
| 177 | 
            +
                    elif self.filt_size == 6:
         | 
| 178 | 
            +
                        a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
         | 
| 179 | 
            +
                    elif self.filt_size == 7:
         | 
| 180 | 
            +
                        a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    filt = torch.Tensor(a)
         | 
| 183 | 
            +
                    filt = filt / torch.sum(filt)
         | 
| 184 | 
            +
                    self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                def forward(self, inp):
         | 
| 189 | 
            +
                    if self.filt_size == 1:
         | 
| 190 | 
            +
                        return inp
         | 
| 191 | 
            +
                    return F.conv1d(self.pad(inp), self.filt, groups=inp.shape[1])
         | 
| 192 | 
            +
             | 
| 193 | 
            +
             | 
| 194 | 
            +
            def get_pad_layer_1d(pad_type):
         | 
| 195 | 
            +
                if pad_type in ["refl", "reflect"]:
         | 
| 196 | 
            +
                    PadLayer = nn.ReflectionPad1d
         | 
| 197 | 
            +
                elif pad_type in ["repl", "replicate"]:
         | 
| 198 | 
            +
                    PadLayer = nn.ReplicationPad1d
         | 
| 199 | 
            +
                elif pad_type == "zero":
         | 
| 200 | 
            +
                    PadLayer = nn.ZeroPad1d
         | 
| 201 | 
            +
                else:
         | 
| 202 | 
            +
                    print("Pad type [%s] not recognized" % pad_type)
         | 
| 203 | 
            +
                return PadLayer
         | 
| 204 | 
            +
             | 
| 205 | 
            +
             | 
| 206 | 
            +
            class MovingAverageSmooth(torch.nn.Conv1d):
         | 
| 207 | 
            +
                def __init__(self, channels, window_len=3):
         | 
| 208 | 
            +
                    """Initialize Conv1d module."""
         | 
| 209 | 
            +
                    super(MovingAverageSmooth, self).__init__(
         | 
| 210 | 
            +
                        in_channels=channels,
         | 
| 211 | 
            +
                        out_channels=channels,
         | 
| 212 | 
            +
                        kernel_size=1,
         | 
| 213 | 
            +
                        groups=channels,
         | 
| 214 | 
            +
                        bias=False,
         | 
| 215 | 
            +
                    )
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    torch.nn.init.constant_(self.weight, 1.0 / window_len)
         | 
| 218 | 
            +
                    for p in self.parameters():
         | 
| 219 | 
            +
                        p.requires_grad = False
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                def forward(self, data):
         | 
| 222 | 
            +
                    return super(MovingAverageSmooth, self).forward(data)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
             | 
| 225 | 
            +
            class Conv1d(torch.nn.Conv1d):
         | 
| 226 | 
            +
                """Conv1d module with customized initialization."""
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 229 | 
            +
                    """Initialize Conv1d module."""
         | 
| 230 | 
            +
                    super(Conv1d, self).__init__(*args, **kwargs)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                def reset_parameters(self):
         | 
| 233 | 
            +
                    """Reset parameters."""
         | 
| 234 | 
            +
                    torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
         | 
| 235 | 
            +
                    if self.bias is not None:
         | 
| 236 | 
            +
                        torch.nn.init.constant_(self.bias, 0.0)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
             | 
| 239 | 
            +
            class Stretch2d(torch.nn.Module):
         | 
| 240 | 
            +
                """Stretch2d module."""
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                def __init__(self, x_scale, y_scale, mode="nearest"):
         | 
| 243 | 
            +
                    """Initialize Stretch2d module.
         | 
| 244 | 
            +
                    Args:
         | 
| 245 | 
            +
                        x_scale (int): X scaling factor (Time axis in spectrogram).
         | 
| 246 | 
            +
                        y_scale (int): Y scaling factor (Frequency axis in spectrogram).
         | 
| 247 | 
            +
                        mode (str): Interpolation mode.
         | 
| 248 | 
            +
                    """
         | 
| 249 | 
            +
                    super(Stretch2d, self).__init__()
         | 
| 250 | 
            +
                    self.x_scale = x_scale
         | 
| 251 | 
            +
                    self.y_scale = y_scale
         | 
| 252 | 
            +
                    self.mode = mode
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                def forward(self, x):
         | 
| 255 | 
            +
                    """Calculate forward propagation.
         | 
| 256 | 
            +
                    Args:
         | 
| 257 | 
            +
                        x (Tensor): Input tensor (B, C, F, T).
         | 
| 258 | 
            +
                    Returns:
         | 
| 259 | 
            +
                        Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale),
         | 
| 260 | 
            +
                    """
         | 
| 261 | 
            +
                    return F.interpolate(
         | 
| 262 | 
            +
                        x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
         | 
| 263 | 
            +
                    )
         | 
| 264 | 
            +
             | 
| 265 | 
            +
             | 
| 266 | 
            +
            class Conv2d(torch.nn.Conv2d):
         | 
| 267 | 
            +
                """Conv2d module with customized initialization."""
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 270 | 
            +
                    """Initialize Conv2d module."""
         | 
| 271 | 
            +
                    super(Conv2d, self).__init__(*args, **kwargs)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                def reset_parameters(self):
         | 
| 274 | 
            +
                    """Reset parameters."""
         | 
| 275 | 
            +
                    self.weight.data.fill_(1.0 / np.prod(self.kernel_size))
         | 
| 276 | 
            +
                    if self.bias is not None:
         | 
| 277 | 
            +
                        torch.nn.init.constant_(self.bias, 0.0)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            class UpsampleNetwork(torch.nn.Module):
         | 
| 281 | 
            +
                """Upsampling network module."""
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                def __init__(
         | 
| 284 | 
            +
                    self,
         | 
| 285 | 
            +
                    upsample_scales,
         | 
| 286 | 
            +
                    nonlinear_activation=None,
         | 
| 287 | 
            +
                    nonlinear_activation_params={},
         | 
| 288 | 
            +
                    interpolate_mode="nearest",
         | 
| 289 | 
            +
                    freq_axis_kernel_size=1,
         | 
| 290 | 
            +
                    use_causal_conv=False,
         | 
| 291 | 
            +
                ):
         | 
| 292 | 
            +
                    """Initialize upsampling network module.
         | 
| 293 | 
            +
                    Args:
         | 
| 294 | 
            +
                        upsample_scales (list): List of upsampling scales.
         | 
| 295 | 
            +
                        nonlinear_activation (str): Activation function name.
         | 
| 296 | 
            +
                        nonlinear_activation_params (dict): Arguments for specified activation function.
         | 
| 297 | 
            +
                        interpolate_mode (str): Interpolation mode.
         | 
| 298 | 
            +
                        freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
         | 
| 299 | 
            +
                    """
         | 
| 300 | 
            +
                    super(UpsampleNetwork, self).__init__()
         | 
| 301 | 
            +
                    self.use_causal_conv = use_causal_conv
         | 
| 302 | 
            +
                    self.up_layers = torch.nn.ModuleList()
         | 
| 303 | 
            +
                    for scale in upsample_scales:
         | 
| 304 | 
            +
                        # interpolation layer
         | 
| 305 | 
            +
                        stretch = Stretch2d(scale, 1, interpolate_mode)
         | 
| 306 | 
            +
                        self.up_layers += [stretch]
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                        # conv layer
         | 
| 309 | 
            +
                        assert (
         | 
| 310 | 
            +
                            freq_axis_kernel_size - 1
         | 
| 311 | 
            +
                        ) % 2 == 0, "Not support even number freq axis kernel size."
         | 
| 312 | 
            +
                        freq_axis_padding = (freq_axis_kernel_size - 1) // 2
         | 
| 313 | 
            +
                        kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
         | 
| 314 | 
            +
                        if use_causal_conv:
         | 
| 315 | 
            +
                            padding = (freq_axis_padding, scale * 2)
         | 
| 316 | 
            +
                        else:
         | 
| 317 | 
            +
                            padding = (freq_axis_padding, scale)
         | 
| 318 | 
            +
                        conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
         | 
| 319 | 
            +
                        self.up_layers += [conv]
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                        # nonlinear
         | 
| 322 | 
            +
                        if nonlinear_activation is not None:
         | 
| 323 | 
            +
                            nonlinear = getattr(torch.nn, nonlinear_activation)(
         | 
| 324 | 
            +
                                **nonlinear_activation_params
         | 
| 325 | 
            +
                            )
         | 
| 326 | 
            +
                            self.up_layers += [nonlinear]
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                def forward(self, c):
         | 
| 329 | 
            +
                    """Calculate forward propagation.
         | 
| 330 | 
            +
                    Args:
         | 
| 331 | 
            +
                        c : Input tensor (B, C, T).
         | 
| 332 | 
            +
                    Returns:
         | 
| 333 | 
            +
                        Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).
         | 
| 334 | 
            +
                    """
         | 
| 335 | 
            +
                    c = c.unsqueeze(1)  # (B, 1, C, T)
         | 
| 336 | 
            +
                    for f in self.up_layers:
         | 
| 337 | 
            +
                        if self.use_causal_conv and isinstance(f, Conv2d):
         | 
| 338 | 
            +
                            c = f(c)[..., : c.size(-1)]
         | 
| 339 | 
            +
                        else:
         | 
| 340 | 
            +
                            c = f(c)
         | 
| 341 | 
            +
                    return c.squeeze(1)  # (B, C, T')
         | 
| 342 | 
            +
             | 
| 343 | 
            +
             | 
| 344 | 
            +
            class ConvInUpsampleNetwork(torch.nn.Module):
         | 
| 345 | 
            +
                """Convolution + upsampling network module."""
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                def __init__(
         | 
| 348 | 
            +
                    self,
         | 
| 349 | 
            +
                    upsample_scales=[3, 4, 5, 5],
         | 
| 350 | 
            +
                    nonlinear_activation="ReLU",
         | 
| 351 | 
            +
                    nonlinear_activation_params={},
         | 
| 352 | 
            +
                    interpolate_mode="nearest",
         | 
| 353 | 
            +
                    freq_axis_kernel_size=1,
         | 
| 354 | 
            +
                    aux_channels=80,
         | 
| 355 | 
            +
                    aux_context_window=0,
         | 
| 356 | 
            +
                    use_causal_conv=False,
         | 
| 357 | 
            +
                ):
         | 
| 358 | 
            +
                    """Initialize convolution + upsampling network module.
         | 
| 359 | 
            +
                    Args:
         | 
| 360 | 
            +
                        upsample_scales (list): List of upsampling scales.
         | 
| 361 | 
            +
                        nonlinear_activation (str): Activation function name.
         | 
| 362 | 
            +
                        nonlinear_activation_params (dict): Arguments for specified activation function.
         | 
| 363 | 
            +
                        mode (str): Interpolation mode.
         | 
| 364 | 
            +
                        freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
         | 
| 365 | 
            +
                        aux_channels (int): Number of channels of pre-convolutional layer.
         | 
| 366 | 
            +
                        aux_context_window (int): Context window size of the pre-convolutional layer.
         | 
| 367 | 
            +
                        use_causal_conv (bool): Whether to use causal structure.
         | 
| 368 | 
            +
                    """
         | 
| 369 | 
            +
                    super(ConvInUpsampleNetwork, self).__init__()
         | 
| 370 | 
            +
                    self.aux_context_window = aux_context_window
         | 
| 371 | 
            +
                    self.use_causal_conv = use_causal_conv and aux_context_window > 0
         | 
| 372 | 
            +
                    # To capture wide-context information in conditional features
         | 
| 373 | 
            +
                    kernel_size = (
         | 
| 374 | 
            +
                        aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
         | 
| 375 | 
            +
                    )
         | 
| 376 | 
            +
                    # NOTE(kan-bayashi): Here do not use padding because the input is already padded
         | 
| 377 | 
            +
                    self.conv_in = Conv1d(
         | 
| 378 | 
            +
                        aux_channels, aux_channels, kernel_size=kernel_size, bias=False
         | 
| 379 | 
            +
                    )
         | 
| 380 | 
            +
                    self.upsample = UpsampleNetwork(
         | 
| 381 | 
            +
                        upsample_scales=upsample_scales,
         | 
| 382 | 
            +
                        nonlinear_activation=nonlinear_activation,
         | 
| 383 | 
            +
                        nonlinear_activation_params=nonlinear_activation_params,
         | 
| 384 | 
            +
                        interpolate_mode=interpolate_mode,
         | 
| 385 | 
            +
                        freq_axis_kernel_size=freq_axis_kernel_size,
         | 
| 386 | 
            +
                        use_causal_conv=use_causal_conv,
         | 
| 387 | 
            +
                    )
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                def forward(self, c):
         | 
| 390 | 
            +
                    """Calculate forward propagation.
         | 
| 391 | 
            +
                    Args:
         | 
| 392 | 
            +
                        c : Input tensor (B, C, T').
         | 
| 393 | 
            +
                    Returns:
         | 
| 394 | 
            +
                        Tensor: Upsampled tensor (B, C, T),
         | 
| 395 | 
            +
                            where T = (T' - aux_context_window * 2) * prod(upsample_scales).
         | 
| 396 | 
            +
                    Note:
         | 
| 397 | 
            +
                        The length of inputs considers the context window size.
         | 
| 398 | 
            +
                    """
         | 
| 399 | 
            +
                    c_ = self.conv_in(c)
         | 
| 400 | 
            +
                    c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
         | 
| 401 | 
            +
                    return self.upsample(c)
         | 
| 402 | 
            +
             | 
| 403 | 
            +
             | 
| 404 | 
            +
class DownsampleNet(nn.Module):
    def __init__(self, input_size, output_size, upsample_factor, hp=None, index=0):
        super(DownsampleNet, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.upsample_factor = upsample_factor
        self.skip_conv = nn.Conv1d(input_size, output_size, kernel_size=1)
        self.index = index
        layer = nn.Conv1d(
            input_size,
            output_size,
            kernel_size=upsample_factor * 2,
            stride=upsample_factor,
            padding=upsample_factor // 2 + upsample_factor % 2,
        )

        self.layer = nn.utils.weight_norm(layer)

    def forward(self, inputs):
        B, C, T = inputs.size()
        # Strided subsampling of the input forms the skip/residual path.
        res = inputs[:, :, :: self.upsample_factor]
        skip = self.skip_conv(res)

        outputs = self.layer(inputs)
        outputs = outputs + skip

        return outputs

class UpsampleNet(nn.Module):
    def __init__(self, input_size, output_size, upsample_factor, hp=None, index=0):

        super(UpsampleNet, self).__init__()
        self.up_type = Config.up_type
        self.use_smooth = Config.use_smooth
        self.use_drop = Config.use_drop
        self.input_size = input_size
        self.output_size = output_size
        self.upsample_factor = upsample_factor
        self.skip_conv = nn.Conv1d(input_size, output_size, kernel_size=1)
        self.index = index
        if self.use_smooth:
            window_lens = [5, 5, 4, 3]
            self.window_len = window_lens[index]

        if self.up_type != "pn" or self.index < 3:
            # if self.up_type != "pn":
            layer = nn.ConvTranspose1d(
                input_size,
                output_size,
                upsample_factor * 2,
                upsample_factor,
                padding=upsample_factor // 2 + upsample_factor % 2,
                output_padding=upsample_factor % 2,
            )
            self.layer = nn.utils.weight_norm(layer)
        else:
            # "pn" path: widen the channels, then fold them into the time axis.
            self.layer = nn.Sequential(
                nn.ReflectionPad1d(1),
                nn.utils.weight_norm(
                    nn.Conv1d(input_size, output_size * upsample_factor, kernel_size=3)
                ),
                nn.LeakyReLU(),
                nn.ReflectionPad1d(1),
                nn.utils.weight_norm(
                    nn.Conv1d(
                        output_size * upsample_factor,
                        output_size * upsample_factor,
                        kernel_size=3,
                    )
                ),
                nn.LeakyReLU(),
                nn.ReflectionPad1d(1),
                nn.utils.weight_norm(
                    nn.Conv1d(
                        output_size * upsample_factor,
                        output_size * upsample_factor,
                        kernel_size=3,
                    )
                ),
                nn.LeakyReLU(),
            )

        if hp is not None:
            self.org = Config.up_org
            self.no_skip = Config.no_skip
        else:
            self.org = False
            self.no_skip = True

        if self.use_smooth:
            self.mas = nn.Sequential(
                # LowpassBlur(output_size, self.window_len),
                MovingAverageSmooth(output_size, self.window_len),
                # MovingAverageSmooth(output_size, self.window_len),
            )

    def forward(self, inputs):

        if not self.org:
            inputs = inputs + torch.sin(inputs)
            B, C, T = inputs.size()
            # Repeat-based residual path (returned directly for up_type == "repeat").
            res = inputs.repeat(1, self.upsample_factor, 1).view(B, C, -1)
            skip = self.skip_conv(res)
            if self.up_type == "repeat":
                return skip

        outputs = self.layer(inputs)
        if self.up_type == "pn" and self.index > 2:
            B, c, l = outputs.size()
            outputs = outputs.view(B, -1, l * self.upsample_factor)

        if self.no_skip:
            return outputs

        if not self.org:
            outputs = outputs + skip

        if self.use_smooth:
            outputs = self.mas(outputs)

        if self.use_drop:
            outputs = F.dropout(outputs, p=0.05)

        return outputs

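A quick shape sketch for the two resampling blocks above (hypothetical, not part of the upload): with hp=None the constructor sets no_skip=True, and assuming Config.up_type leaves the transposed-convolution path active with use_smooth and use_drop disabled, the blocks invert each other's shapes.

# Hypothetical shape check for UpsampleNet / DownsampleNet.
up = UpsampleNet(input_size=64, output_size=32, upsample_factor=4)
down = DownsampleNet(input_size=32, output_size=64, upsample_factor=4)
x = torch.randn(2, 64, 100)   # (B, C, T)
y = up(x)                     # (2, 32, 400): T is multiplied by upsample_factor
z = down(y)                   # (2, 64, 100): T is divided by upsample_factor
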
class ResStack(nn.Module):
    def __init__(self, channel, kernel_size=3, resstack_depth=4, hp=None):
        super(ResStack, self).__init__()

        self.use_wn = Config.use_wn
        self.use_shift_scale = Config.use_shift_scale
        self.channel = channel

        def get_padding(kernel_size, dilation=1):
            return int((kernel_size * dilation - dilation) / 2)

        if self.use_shift_scale:
            self.scale_conv = nn.utils.weight_norm(
                nn.Conv1d(
                    channel, 2 * channel, kernel_size=kernel_size, dilation=1, padding=1
                )
            )

        if not self.use_wn:
            self.layers = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.LeakyReLU(),
                        nn.utils.weight_norm(
                            nn.Conv1d(
                                channel,
                                channel,
                                kernel_size=kernel_size,
                                dilation=3 ** (i % 10),
                                padding=get_padding(kernel_size, 3 ** (i % 10)),
                            )
                        ),
                        nn.LeakyReLU(),
                        nn.utils.weight_norm(
                            nn.Conv1d(
                                channel,
                                channel,
                                kernel_size=kernel_size,
                                dilation=1,
                                padding=get_padding(kernel_size, 1),
                            )
                        ),
                    )
                    for i in range(resstack_depth)
                ]
            )
        else:
            self.wn = WaveNet(
                in_channels=channel,
                out_channels=channel,
                cin_channels=-1,
                num_layers=resstack_depth,
                residual_channels=channel,
                gate_channels=channel,
                skip_channels=channel,
                # kernel_size=5,
                # dilation_rate=3,
                causal=False,
                use_downup=False,
            )

    def forward(self, x):
        if not self.use_wn:
            for layer in self.layers:
                x = x + layer(x)
        else:
            x = self.wn(x)

        if self.use_shift_scale:
            m_s = self.scale_conv(x)
            m_s = m_s[:, :, :-1]

            m, s = torch.split(m_s, self.channel, dim=1)
            s = F.softplus(s)

            x = m + s * x[:, :, 1:]  # key!!!
            x = F.pad(x, pad=(1, 0), mode="constant", value=0)

        return x

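A minimal sketch of the plain path through ResStack (hypothetical, assuming Config.use_wn and Config.use_shift_scale are both disabled): each layer is a residual pair of a dilated and an undilated convolution whose padding cancels the dilation, so channels and length are preserved.

# Hypothetical: dilations grow as 3**i; padding keeps (B, channel, T) fixed.
stack = ResStack(channel=32, kernel_size=3, resstack_depth=4)
x = torch.randn(2, 32, 256)
y = stack(x)                  # (2, 32, 256)
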
class WaveNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        num_layers=10,
        residual_channels=64,
        gate_channels=64,
        skip_channels=64,
        kernel_size=3,
        dilation_rate=2,
        cin_channels=80,
        hp=None,
        causal=False,
        use_downup=False,
    ):
        super(WaveNet, self).__init__()

        self.in_channels = in_channels
        self.causal = causal
        self.num_layers = num_layers
        self.out_channels = out_channels
        self.gate_channels = gate_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.cin_channels = cin_channels
        self.kernel_size = kernel_size
        self.use_downup = use_downup

        self.front_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=self.in_channels,
                out_channels=self.residual_channels,
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(),
        )
        if self.use_downup:
            self.downup_conv = nn.Sequential(
                nn.Conv1d(
                    in_channels=self.residual_channels,
                    out_channels=self.residual_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
                nn.ReLU(),
                nn.Conv1d(
                    in_channels=self.residual_channels,
                    out_channels=self.residual_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
                nn.ReLU(),
                UpsampleNet(self.residual_channels, self.residual_channels, 4, hp),
            )

        self.res_blocks = nn.ModuleList()
        for n in range(self.num_layers):
            self.res_blocks.append(
                ResBlock(
                    self.residual_channels,
                    self.gate_channels,
                    self.skip_channels,
                    self.kernel_size,
                    dilation=dilation_rate**n,
                    cin_channels=self.cin_channels,
                    local_conditioning=(self.cin_channels > 0),
                    causal=self.causal,
                    mode="SAME",
                )
            )
        self.final_conv = nn.Sequential(
            nn.ReLU(),
            Conv(self.skip_channels, self.skip_channels, 1, causal=self.causal),
            nn.ReLU(),
            Conv(self.skip_channels, self.out_channels, 1, causal=self.causal),
        )

    def forward(self, x, c=None):
        return self.wavenet(x, c)

    def wavenet(self, tensor, c=None):
        h = self.front_conv(tensor)
        if self.use_downup:
            h = self.downup_conv(h)
        skip = 0
        for i, f in enumerate(self.res_blocks):
            h, s = f(h, c)
            skip += s
        out = self.final_conv(skip)
        return out

    def receptive_field_size(self):
        num_dir = 1 if self.causal else 2
        dilations = [2 ** (i % self.num_layers) for i in range(self.num_layers)]
        # front_conv has kernel_size=3, contributing (3 - 1) samples of context.
        return num_dir * (self.kernel_size - 1) * sum(dilations) + 1 + (3 - 1)

    def remove_weight_norm(self):
        for f in self.res_blocks:
            f.remove_weight_norm()

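This is how ResStack drives the WaveNet above: cin_channels=-1 turns local conditioning off, so forward runs unconditioned. A hypothetical smoke test:

# Hypothetical: an unconditioned, non-causal WaveNet preserves (B, C, T);
# each ResBlock feeds the shared skip path that final_conv projects out.
wn = WaveNet(in_channels=64, out_channels=64, num_layers=4,
             residual_channels=64, gate_channels=64, skip_channels=64,
             cin_channels=-1, causal=False, use_downup=False)
x = torch.randn(2, 64, 128)
y = wn(x)                     # (2, 64, 128)
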
class Conv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        causal=False,
        mode="SAME",
    ):
        super(Conv, self).__init__()

        self.causal = causal
        self.mode = mode
        if self.causal and self.mode == "SAME":
            self.padding = dilation * (kernel_size - 1)
        elif self.mode == "SAME":
            self.padding = dilation * (kernel_size - 1) // 2
        else:
            self.padding = 0
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=self.padding,
        )
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, tensor):
        out = self.conv(tensor)
        if self.causal and self.padding != 0:
            # Trim the right side so no output frame depends on future samples.
            out = out[:, :, : -self.padding]
        return out

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.conv)

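In causal SAME mode the layer pads by dilation * (kernel_size - 1) on both sides and trims the same amount from the right, so an output frame never sees future input; a hypothetical check:

# Hypothetical causality check: zeroing the future leaves past outputs unchanged.
conv = Conv(1, 1, kernel_size=3, dilation=2, causal=True)
x = torch.randn(1, 1, 32)
x2 = x.clone()
x2[:, :, 16:] = 0.0
assert torch.allclose(conv(x)[:, :, :16], conv(x2)[:, :, :16])
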
class ResBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels,
        kernel_size,
        dilation,
        cin_channels=None,
        local_conditioning=True,
        causal=False,
        mode="SAME",
    ):
        super(ResBlock, self).__init__()
        self.causal = causal
        self.local_conditioning = local_conditioning
        self.cin_channels = cin_channels
        self.mode = mode

        self.filter_conv = Conv(
            in_channels, out_channels, kernel_size, dilation, causal, mode
        )
        self.gate_conv = Conv(
            in_channels, out_channels, kernel_size, dilation, causal, mode
        )
        self.res_conv = nn.Conv1d(out_channels, in_channels, kernel_size=1)
        self.skip_conv = nn.Conv1d(out_channels, skip_channels, kernel_size=1)
        self.res_conv = nn.utils.weight_norm(self.res_conv)
        self.skip_conv = nn.utils.weight_norm(self.skip_conv)

        if self.local_conditioning:
            self.filter_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
            self.gate_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
            self.filter_conv_c = nn.utils.weight_norm(self.filter_conv_c)
            self.gate_conv_c = nn.utils.weight_norm(self.gate_conv_c)

    def forward(self, tensor, c=None):
        h_filter = self.filter_conv(tensor)
        h_gate = self.gate_conv(tensor)

        if self.local_conditioning:
            h_filter += self.filter_conv_c(c)
            h_gate += self.gate_conv_c(c)

        out = torch.tanh(h_filter) * torch.sigmoid(h_gate)

        res = self.res_conv(out)
        skip = self.skip_conv(out)
        if self.mode == "SAME":
            return (tensor + res) * math.sqrt(0.5), skip
        else:
            return (tensor[:, :, 1:] + res) * math.sqrt(0.5), skip

    def remove_weight_norm(self):
        self.filter_conv.remove_weight_norm()
        self.gate_conv.remove_weight_norm()
        nn.utils.remove_weight_norm(self.res_conv)
        nn.utils.remove_weight_norm(self.skip_conv)
        if self.local_conditioning:
            # The conditioning convolutions only exist when conditioning is on.
            nn.utils.remove_weight_norm(self.filter_conv_c)
            nn.utils.remove_weight_norm(self.gate_conv_c)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int])
    s_act = torch.sigmoid(in_act[:, n_channels_int:])
    acts = t_act * s_act
    return acts


@torch.jit.script
def fused_res_skip(tensor, res_skip, n_channels):
    n_channels_int = n_channels[0]
    res = res_skip[:, :n_channels_int]
    skip = res_skip[:, n_channels_int:]
    return (tensor + res), skip

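The two TorchScript helpers above fuse WaveNet's gated activation; the first is equivalent to splitting the summed pre-activation into tanh and sigmoid halves. A hypothetical equivalence check (n_channels is passed as a tensor because of @torch.jit.script):

# Hypothetical: fused op == tanh(first half) * sigmoid(second half) of (a + b).
C = 8
a, b = torch.randn(2, 2 * C, 50), torch.randn(2, 2 * C, 50)
fused = fused_add_tanh_sigmoid_multiply(a, b, torch.IntTensor([C]))
manual = torch.tanh((a + b)[:, :C]) * torch.sigmoid((a + b)[:, C:])
assert torch.allclose(fused, manual)
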
class ResStack2D(nn.Module):
    def __init__(self, channels=16, kernel_size=3, resstack_depth=4, hp=None):
        super(ResStack2D, self).__init__()
        # NOTE: the constructor arguments are shadowed by these fixed values.
        channels = 16
        kernel_size = 3
        resstack_depth = 2
        self.channels = channels

        def get_padding(kernel_size, dilation=1):
            return int((kernel_size * dilation - dilation) / 2)

        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LeakyReLU(),
                    nn.utils.weight_norm(
                        nn.Conv2d(
                            1,
                            self.channels,
                            kernel_size,
                            dilation=(1, 3 ** (i)),
                            padding=(1, get_padding(kernel_size, 3 ** (i))),
                        )
                    ),
                    nn.LeakyReLU(),
                    nn.utils.weight_norm(
                        nn.Conv2d(
                            self.channels,
                            self.channels,
                            kernel_size,
                            dilation=(1, 3 ** (i)),
                            padding=(1, get_padding(kernel_size, 3 ** (i))),
                        )
                    ),
                    nn.LeakyReLU(),
                    nn.utils.weight_norm(nn.Conv2d(self.channels, 1, kernel_size=1)),
                )
                for i in range(resstack_depth)
            ]
        )

    def forward(self, tensor):
        # Treat the (B, F, T) input as a one-channel image: (B, 1, F, T).
        x = tensor.unsqueeze(1)
        for layer in self.layers:
            x = x + layer(x)
        x = x.squeeze(1)

        return x

class FiLM(nn.Module):
    """
    Feature-wise linear modulation.
    """

    def __init__(self, input_dim, attribute_dim):
        super().__init__()
        self.input_dim = input_dim
        self.generator = nn.Conv1d(
            attribute_dim, input_dim * 2, kernel_size=3, padding=1
        )

    def forward(self, x, c):
        """
        x: (B, input_dim, seq)
        c: (B, attribute_dim, seq)
        """
        c = self.generator(c)
        m, s = torch.split(c, self.input_dim, dim=1)

        return x * s + m

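A worked instance of the modulation above: c is projected to 2 * input_dim channels and split into a shift m and a scale s that are applied per channel and time step (hypothetical shapes):

# Hypothetical usage: modulate 32 feature channels with an 8-channel attribute.
film = FiLM(input_dim=32, attribute_dim=8)
x = torch.randn(4, 32, 100)   # (B, input_dim, seq)
c = torch.randn(4, 8, 100)    # (B, attribute_dim, seq)
y = film(x, c)                # (4, 32, 100), elementwise x * s + m
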
class FiLMConv1d(nn.Module):
    """
    Conv1d with FiLMs in between
    """

    def __init__(self, in_size, out_size, attribute_dim, ins_norm=True, loop=1):
        super().__init__()
        self.loop = loop
        self.mlps = nn.ModuleList(
            [nn.Conv1d(in_size, out_size, kernel_size=3, padding=1)]
            + [
                nn.Conv1d(out_size, out_size, kernel_size=3, padding=1)
                for i in range(loop - 1)
            ]
        )
        self.films = nn.ModuleList([FiLM(out_size, attribute_dim) for i in range(loop)])
        self.ins_norm = ins_norm
        if self.ins_norm:
            self.norm = nn.InstanceNorm1d(attribute_dim)

    def forward(self, x, c):
        """
        x: (B, input_dim, seq)
        c: (B, attribute_dim, seq)
        """
        if self.ins_norm:
            c = self.norm(c)
        for i in range(self.loop):
            x = self.mlps[i](x)
            x = F.relu(x)
            x = self.films[i](x, c)

        return x
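And a hypothetical call through the looped variant, which interleaves loop Conv1d+ReLU stages with one FiLM each, instance-normalizing c first when ins_norm=True:

# Hypothetical usage: two conv+FiLM stages, 24 -> 32 channels, length kept.
net = FiLMConv1d(in_size=24, out_size=32, attribute_dim=8, ins_norm=True, loop=2)
x = torch.randn(4, 24, 100)
c = torch.randn(4, 8, 100)
y = net(x, c)                 # (4, 32, 100)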
        voicefixer/vocoder/model/pqmf.py
    ADDED
    
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import scipy.io.wavfile


class PQMF(nn.Module):
    def __init__(self, N, M, file_path="utils/pqmf_hk_4_64.dat"):
        super().__init__()
        self.N = N  # number of subbands
        self.M = M  # number of filter taps
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )
        data = np.reshape(np.fromfile(file_path, dtype=np.float32), (N, M))
        data = np.flipud(data.T).T
        gk = data.copy()
        data = np.reshape(data, (N, 1, M)).copy()
        dict_new = self.ana_conv_filter.state_dict().copy()
        dict_new["weight"] = torch.from_numpy(data)
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
        self.ana_conv_filter.load_state_dict(dict_new)

        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        # NOTE: the (4, 16, 4) reshape hard-codes N=4, M=64 (16 taps per subband).
        gk = np.transpose(np.reshape(gk, (4, 16, 4)), (1, 0, 2)) * N
        gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
        dict_new = self.syn_conv_filter.state_dict().copy()
        dict_new["weight"] = torch.from_numpy(gk)
        self.syn_conv_filter.load_state_dict(dict_new)

        # The filterbank is fixed; freeze both convolutions.
        for param in self.parameters():
            param.requires_grad = False

    def analysis(self, inputs):
        return self.ana_conv_filter(self.ana_pad(inputs))

    def synthesis(self, inputs):
        return self.syn_conv_filter(self.syn_pad(inputs))

    def forward(self, inputs):
        return self.ana_conv_filter(self.ana_pad(inputs))


if __name__ == "__main__":
    a = PQMF(4, 64)
    # x = np.load('data/train/audio/010000.npy')
    x = np.zeros([8, 24000], np.float32)
    x = np.reshape(x, (8, 1, -1))
    x = torch.from_numpy(x)
    b = a.analysis(x)
    c = a.synthesis(b)
    print(x.shape, b.shape, c.shape)
    b = (b * 32768).numpy()
    b = np.reshape(np.transpose(b, (0, 2, 1)), (-1, 1)).astype(np.int16)
    # b.tofile('1.pcm')
    # np.reshape(np.transpose(c.numpy()*32768, (0, 2, 1)), (-1,1)).astype(np.int16).tofile('2.pcm')
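A round-trip sketch for the filterbank above (hypothetical: it assumes the coefficient file exists and that N=4, M=64, the only shape the hard-coded reshape supports); the final interleaving mirrors the commented-out PCM dump:

# Hypothetical round trip: analysis yields N subbands of length T // N;
# synthesis returns N polyphase components that interleave back to a waveform.
pqmf = PQMF(4, 64, file_path="utils/pqmf_hk_4_64.dat")
wav = torch.randn(1, 1, 24000)
sub = pqmf.analysis(wav)                          # (1, 4, 6000)
rec = pqmf.synthesis(sub)                         # (1, 4, 6000)
rec_wav = rec.transpose(1, 2).reshape(1, 1, -1)   # (1, 1, 24000)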
        voicefixer/vocoder/model/res_msd.py
    ADDED
    
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResStack(nn.Module):
    def __init__(self, channels=384, kernel_size=3, resstack_depth=3, hp=None):
        super(ResStack, self).__init__()
        dilation = [2 * i + 1 for i in range(resstack_depth)]  # [1, 3, 5]
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[i],
                        padding=get_padding(kernel_size, dilation[i]),
                    )
                )
                for i in range(resstack_depth)
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for i in range(resstack_depth)
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)
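A hypothetical shape check for this HiFi-GAN-style stack: the dilations [1, 3, 5] are fully compensated by get_padding, so (B, channels, T) is preserved through every residual pair.

rs = ResStack(channels=384, kernel_size=3, resstack_depth=3)
x = torch.randn(1, 384, 120)
y = rs(x)                     # (1, 384, 120)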
        voicefixer/vocoder/model/util.py
    ADDED
    
| 1 | 
            +
            from voicefixer.vocoder.config import Config
         | 
| 2 | 
            +
            from voicefixer.tools.pytorch_util import try_tensor_cuda, check_cuda_availability
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import librosa
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def tr_normalize(S):
         | 
| 9 | 
            +
                if Config.allow_clipping_in_normalization:
         | 
| 10 | 
            +
                    if Config.symmetric_mels:
         | 
| 11 | 
            +
                        return torch.clip(
         | 
| 12 | 
            +
                            (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db))
         | 
| 13 | 
            +
                            - Config.max_abs_value,
         | 
| 14 | 
            +
                            -Config.max_abs_value,
         | 
| 15 | 
            +
                            Config.max_abs_value,
         | 
| 16 | 
            +
                        )
         | 
| 17 | 
            +
                    else:
         | 
| 18 | 
            +
                        return torch.clip(
         | 
| 19 | 
            +
                            Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)),
         | 
| 20 | 
            +
                            0,
         | 
| 21 | 
            +
                            Config.max_abs_value,
         | 
| 22 | 
            +
                        )
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                assert S.max() <= 0 and S.min() - Config.min_db >= 0
         | 
| 25 | 
            +
                if Config.symmetric_mels:
         | 
| 26 | 
            +
                    return (2 * Config.max_abs_value) * (
         | 
| 27 | 
            +
                        (S - Config.min_db) / (-Config.min_db)
         | 
| 28 | 
            +
                    ) - Config.max_abs_value
         | 
| 29 | 
            +
                else:
         | 
| 30 | 
            +
                    return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db))
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
def tr_amp_to_db(x):
    # 20 * log10(x), floored at Config.min_level_db (torch version).
    min_level = torch.exp(Config.min_level_db / 20 * torch.log(torch.tensor(10.0)))
    min_level = min_level.type_as(x)
    return 20 * torch.log10(torch.maximum(min_level, x))

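The floor keeps log10 away from zero (or negative) amplitudes: anything below 10 ** (Config.min_level_db / 20) comes out as Config.min_level_db dB. A quick illustration, assuming the usual negative min_level_db (the tiny value is arbitrary):

    db = tr_amp_to_db(torch.tensor([1.0, 1e-12]))
    # db[0] == 0.0 (unit amplitude); db[1] == Config.min_level_db (floored)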
def normalize(S):
    # NumPy twin of tr_normalize; kept in sync so CPU preprocessing matches the torch path.
    if Config.allow_clipping_in_normalization:
        if Config.symmetric_mels:
            return np.clip(
                (2 * Config.max_abs_value) * ((S - Config.min_db) / (-Config.min_db))
                - Config.max_abs_value,
                -Config.max_abs_value,
                Config.max_abs_value,
            )
        else:
            return np.clip(
                Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db)),
                0,
                Config.max_abs_value,
            )

    assert S.max() <= 0 and S.min() - Config.min_db >= 0
    if Config.symmetric_mels:
        return (2 * Config.max_abs_value) * (
            (S - Config.min_db) / (-Config.min_db)
        ) - Config.max_abs_value
    else:
        return Config.max_abs_value * ((S - Config.min_db) / (-Config.min_db))

def amp_to_db(x):
    # NumPy twin of tr_amp_to_db.
    min_level = np.exp(Config.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def tr_pre(npy):
    # Transpose batched mel conditions to (batch, num_mels, frames) and pad the
    # tail with a low constant (-4.0) so the frame count comes out even.
    # conditions = torch.FloatTensor(npy).type_as(npy)  # to(device)
    conditions = npy.transpose(1, 2)
    l = conditions.size(-1)
    pad_tail = l % 2 + 4
    zeros = (
        torch.zeros([conditions.size()[0], Config.num_mels, pad_tail]).type_as(
            conditions
        )
        - 4.0
    )
    return torch.cat([conditions, zeros], dim=-1)

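A shape sketch for tr_pre, assuming batched input of shape (batch, frames, num_mels); the even frame count is presumably what the vocoder's downsampling expects:

    mel = torch.randn(2, 101, Config.num_mels)
    cond = tr_pre(mel)
    # cond.shape == (2, Config.num_mels, 106): 101 frames + 5 padded frames
    assert cond.size(-1) % 2 == 0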
def pre(npy):
    # Single-example variant of tr_pre: accepts a NumPy array or tensor of shape
    # (frames, num_mels), adds a batch dimension, then pads the tail.
    conditions = npy
    if isinstance(conditions, np.ndarray):
        conditions = torch.FloatTensor(conditions).unsqueeze(0)
    else:
        conditions = torch.FloatTensor(conditions.float()).unsqueeze(0)
    conditions = conditions.transpose(1, 2)
    l = conditions.size(-1)
    pad_tail = l % 2 + 4
    zeros = torch.zeros([1, Config.num_mels, pad_tail]) - 4.0
    return torch.cat([conditions, zeros], dim=-1)

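And the single-example variant, which also accepts NumPy input (hypothetical shapes):

    mel_np = np.random.randn(100, Config.num_mels).astype(np.float32)
    cond = pre(mel_np)  # torch.Size([1, Config.num_mels, 104])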
def load_try(state, model):
    # Try a regular merged state-dict load first; on a RuntimeError, fall back
    # to copying the checkpoint keys one by one and loading once at the end.
    model_dict = model.state_dict()
    try:
        model_dict.update(state)
        model.load_state_dict(model_dict)
    except RuntimeError as e:
        print(str(e))
        model_dict = model.state_dict()
        for k, v in state.items():
            model_dict[k] = v
        # Load once after all keys are copied (the original called this inside
        # the loop, reloading the model on every key).
        model.load_state_dict(model_dict)

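A hypothetical usage sketch; the model class, checkpoint path, and the assumption that the file stores a flat state dict are illustrative, not taken from this repo:

    model = SomeGenerator()                            # hypothetical model class
    state = load_checkpoint("ckpt/model.pth", torch.device("cpu"))
    load_try(state, model)                             # survives harmless key mismatches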
def load_checkpoint(checkpoint_path, device):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    return checkpoint

def build_mel_basis():
    # Keyword arguments keep this call compatible with librosa >= 0.10, where
    # positional use of sr/n_fft was removed.
    return librosa.filters.mel(
        sr=Config.sample_rate,
        n_fft=Config.n_fft,
        htk=True,
        n_mels=Config.num_mels,
        fmin=0,
        fmax=int(Config.sample_rate // 2),
    )

def linear_to_mel(spectrogram):
    # Note: the mel filter bank is rebuilt on every call; cache it if this path is hot.
    _mel_basis = build_mel_basis()
    return np.dot(_mel_basis, spectrogram)

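Chaining the helpers in this file gives a rough waveform-to-condition sketch. Config.hop_length and the input file name are assumptions here (this module never references a hop length itself), and allow_clipping_in_normalization is assumed to be on:

    wav, _ = librosa.load("speech.wav", sr=Config.sample_rate)  # hypothetical input file
    linear = np.abs(librosa.stft(wav, n_fft=Config.n_fft, hop_length=Config.hop_length))
    mel = linear_to_mel(linear)           # (num_mels, frames)
    cond = normalize(amp_to_db(mel))      # ready for pre() / the vocoder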
if __name__ == "__main__":
    # Sanity check: the NumPy and torch dB/normalization paths should agree.
    data = torch.randn((3, 5, 100))
    b = normalize(amp_to_db(data.numpy()))
    a = tr_normalize(tr_amp_to_db(data)).numpy()
    print(a - b)
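The printed difference should be numerically negligible; a stricter variant of the same self-test (a sketch, not in the original file) would assert it instead of relying on eyeballing:

    assert np.allclose(a, b, atol=1e-4), "torch and NumPy paths diverged"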