v0.5.0
- .gitignore +214 -0
- LICENSE +201 -0
- SoniTranslate_Colab.ipynb +124 -0
- app.py +2 -0
- app_rvc.py +0 -0
- assets/logo.jpeg +0 -0
- docs/windows_install.md +150 -0
- lib/audio.py +21 -0
- lib/infer_pack/attentions.py +417 -0
- lib/infer_pack/commons.py +166 -0
- lib/infer_pack/models.py +1142 -0
- lib/infer_pack/modules.py +522 -0
- lib/infer_pack/transforms.py +209 -0
- lib/rmvpe.py +422 -0
- mdx_models/data.json +354 -0
- packages.txt +3 -0
- pre-requirements.txt +15 -0
- requirements.txt +19 -0
- requirements_xtts.txt +58 -0
- soni_translate/audio_segments.py +141 -0
- soni_translate/language_configuration.py +551 -0
- soni_translate/languages_gui.py +0 -0
- soni_translate/logging_setup.py +68 -0
- soni_translate/mdx_net.py +582 -0
- soni_translate/postprocessor.py +229 -0
- soni_translate/preprocessor.py +308 -0
- soni_translate/speech_segmentation.py +499 -0
- soni_translate/text_multiformat_processor.py +987 -0
- soni_translate/text_to_speech.py +1574 -0
- soni_translate/translate_segments.py +457 -0
- soni_translate/utils.py +487 -0
- vci_pipeline.py +454 -0
- voice_main.py +732 -0
    	
.gitignore ADDED
@@ -0,0 +1,214 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ignore
sub_tra.*
sub_ori.*
SPEAKER_00.*
SPEAKER_01.*
SPEAKER_02.*
SPEAKER_03.*
SPEAKER_04.*
SPEAKER_05.*
SPEAKER_06.*
SPEAKER_07.*
SPEAKER_08.*
SPEAKER_09.*
SPEAKER_10.*
SPEAKER_11.*
task_subtitle.*
*.mp3
*.mp4
*.ogg
*.wav
*.mkv
*.webm
*.avi
*.mpg
*.mov
*.ogv
*.wmv
test.py
list.txt
text_preprocessor.txt
text_translation.txt
*.srt
*.vtt
*.tsv
*.aud
*.ass
*.pt
.vscode/
mdx_models/*.onnx
_XTTS_/
downloads/
logs/
weights/
clean_song_output/
audio2/
audio/
outputs/
processed/
OPENVOICE_MODELS/
PIPER_MODELS/
WHISPER_MODELS/
whisper_api_audio_parts/
uroman/
pdf_images/
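A note on the project-specific `# Ignore` block above: it keeps generated subtitles, media files, and downloaded model folders out of version control. If you want to confirm a particular path is covered, `git check-ignore -v` reports the matching rule; below is a small sketch driving it from Python (the example paths are hypothetical and only chosen to match the patterns above):

```python
# Ask git which .gitignore rule (if any) matches each path.
import subprocess

for path in ["outputs/result.mp4", "sub_tra.srt", "weights/model.pth"]:
    result = subprocess.run(
        ["git", "check-ignore", "-v", path],  # exits 1 when the path is not ignored
        capture_output=True, text=True,
    )
    print(path, "->", result.stdout.strip() or "not ignored")
```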
    	
LICENSE ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
    	
SoniTranslate_Colab.ipynb ADDED
@@ -0,0 +1,124 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/R3gm/SoniTranslate/blob/main/SoniTranslate_Colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# SoniTranslate\n",
        "\n",
        "| Description | Link |\n",
        "| ----------- | ---- |\n",
        "| 🎉 Repository | [](https://github.com/R3gm/SoniTranslate/) |\n",
        "| 🚀 Online Demo in HF | [](https://huggingface.co/spaces/r3gm/SoniTranslate_translate_audio_of_a_video_content) |\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "8lw0EgLex-YZ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LUgwm0rfx0_J",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "# @title Install requirements for SoniTranslate\n",
        "!git clone https://github.com/r3gm/SoniTranslate.git\n",
        "%cd SoniTranslate\n",
        "\n",
        "!apt install git-lfs\n",
        "!git lfs install\n",
        "\n",
        "!sed -i 's|git+https://github.com/R3gm/whisperX.git@cuda_11_8|git+https://github.com/R3gm/whisperX.git@cuda_12_x|' requirements_base.txt\n",
        "!pip install -q -r requirements_base.txt\n",
        "!pip install -q -r requirements_extra.txt\n",
        "!pip install -q ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/\n",
        "\n",
        "Install_PIPER_TTS = True # @param {type:\"boolean\"}\n",
        "\n",
        "if Install_PIPER_TTS:\n",
        "    !pip install -q piper-tts==1.2.0\n",
        "\n",
        "Install_Coqui_XTTS = True # @param {type:\"boolean\"}\n",
        "\n",
        "if Install_Coqui_XTTS:\n",
        "    !pip install -q -r requirements_xtts.txt\n",
        "    !pip install -q TTS==0.21.1  --no-deps"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "One important step is to accept the license agreement for using Pyannote. You need to have an account on Hugging Face and `accept the license to use the models`: https://huggingface.co/pyannote/speaker-diarization and https://huggingface.co/pyannote/segmentation\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "Get your KEY TOKEN here: https://hf.co/settings/tokens"
      ],
      "metadata": {
        "id": "LTaTstXPXNg2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown # `RUN THE WEB APP`\n",
        "YOUR_HF_TOKEN = \"\" #@param {type:'string'}\n",
        "%env YOUR_HF_TOKEN={YOUR_HF_TOKEN}\n",
        "theme = \"Taithrah/Minimal\" # @param [\"Taithrah/Minimal\", \"aliabid94/new-theme\", \"gstaff/xkcd\", \"ParityError/LimeFace\", \"abidlabs/pakistan\", \"rottenlittlecreature/Moon_Goblin\", \"ysharma/llamas\", \"gradio/dracula_revamped\"]\n",
        "interface_language = \"english\" # @param ['arabic', 'azerbaijani', 'chinese_zh_cn', 'english', 'french', 'german', 'hindi', 'indonesian', 'italian', 'japanese', 'korean', 'marathi', 'polish', 'portuguese', 'russian', 'spanish', 'swedish', 'turkish', 'ukrainian', 'vietnamese']\n",
        "verbosity_level = \"info\" # @param [\"debug\", \"info\", \"warning\", \"error\", \"critical\"]\n",
        "\n",
        "\n",
        "%cd /content/SoniTranslate\n",
        "!python app_rvc.py --theme {theme} --verbosity_level {verbosity_level} --language {interface_language} --public_url"
      ],
      "metadata": {
        "id": "XkhXfaFw4R4J",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Open the `public URL` when it appears"
      ],
      "metadata": {
        "id": "KJW3KrhZJh0u"
      }
    }
  ]
}
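The run cell hands the token to the app through the `YOUR_HF_TOKEN` environment variable (the `%env` line above). A minimal sketch of the receiving side of that handoff, assuming app_rvc.py reads the variable under this name:

```python
# Sketch: pick up the token exported by the notebook's %env line.
# How app_rvc.py actually consumes YOUR_HF_TOKEN is an assumption here.
import os

hf_token = os.getenv("YOUR_HF_TOKEN", "")
if not hf_token:
    raise RuntimeError(
        "Set YOUR_HF_TOKEN (https://hf.co/settings/tokens) and accept the "
        "pyannote/speaker-diarization and pyannote/segmentation licenses."
    )
```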
    	
app.py ADDED
@@ -0,0 +1,2 @@
import os
os.system("python app_rvc.py --language french --theme aliabid94/new-theme")
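This entry point simply shells out to the main script with a fixed interface language and theme. A minimal sketch of an equivalent launcher using `subprocess` instead of `os.system`, which surfaces a failing exit status as an exception (same flags as above; the Colab notebook shows further flags such as `--verbosity_level` and `--public_url`):

```python
# Equivalent launcher with clearer error handling than os.system.
import subprocess
import sys

subprocess.run(
    [sys.executable, "app_rvc.py",
     "--language", "french",
     "--theme", "aliabid94/new-theme"],
    check=True,  # raises CalledProcessError if app_rvc.py exits non-zero
)
```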
    	
app_rvc.py ADDED
The diff for this file is too large to render. See raw diff.
    	
assets/logo.jpeg ADDED
    	
        docs/windows_install.md
    ADDED
    
    | @@ -0,0 +1,150 @@ | |
## Install Locally on Windows

### Before You Start

Before you start installing and using SoniTranslate, there are a few things you need to do:

1. Install Microsoft Visual C++ Build Tools, MSVC, and the Windows 10 SDK:

    * Go to the [Visual Studio downloads page](https://visualstudio.microsoft.com/visual-cpp-build-tools/), or open **Visual Studio Installer** if you already have it and click Modify.
    * Download and install the "Build Tools for Visual Studio" if you don't have it.
    * During installation, under "Workloads", select "C++ build tools" and ensure the latest versions of "MSVC v142 - VS 2019 C++ x64/x86 build tools" and "Windows 10 SDK" are selected ("Windows 11 SDK" if you are using Windows 11); or go to "Individual components" and select those two there.
    * Complete the installation.

2. Verify the NVIDIA driver on Windows using the command line:

    * **Open Command Prompt:** Press `Win + R`, type `cmd`, then press `Enter`.

    * **Type the command:** `nvidia-smi` and press `Enter`.

    * **Look for "CUDA Version"** in the output.

```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 522.25       Driver Version: 522.25       CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
```

3. If the reported CUDA version is lower than 11.8, update your NVIDIA driver. Visit the NVIDIA driver download page (https://www.nvidia.com/Download/index.aspx) and enter your graphics card information.

4. Accept the license agreement for using Pyannote. You need to have an account on Hugging Face and `accept the license to use the models`: https://huggingface.co/pyannote/speaker-diarization and https://huggingface.co/pyannote/segmentation
5. Create a [huggingface token](https://huggingface.co/settings/tokens). Hugging Face is a natural language processing platform that provides access to state-of-the-art models and tools. You will need a token to use some of the automatic model download features in SoniTranslate. Follow the instructions on the Hugging Face website to create one.
6. Install [Anaconda](https://www.anaconda.com/) or [Miniconda](https://docs.anaconda.com/free/miniconda/miniconda-install/). Anaconda is a free and open-source distribution of Python and R. It includes a package manager called conda that makes it easy to install and manage Python environments and packages. Follow the instructions on the Anaconda website to download and install it on your system.
7. Install Git for your system. Git is a version control system that helps you track changes to your code and collaborate with other developers. You can install Git with Anaconda by running `conda install -c anaconda git -y` in your terminal (do this after step 1 in the following section). If you have trouble installing Git via Anaconda, you can use the following link instead:
   - [Git for Windows](https://git-scm.com/download/win)

Once you have completed these steps, you will be ready to install SoniTranslate.

### Getting Started

To install SoniTranslate, follow these steps:

1. Create a suitable anaconda environment for SoniTranslate and activate it:

```
conda create -n sonitr python=3.10 -y
conda activate sonitr
```

2. Clone this GitHub repository and navigate to it:
```
git clone https://github.com/r3gm/SoniTranslate.git
cd SoniTranslate
```
3. Install CUDA Toolkit 11.8.0

```
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit -y
```

4. Install PyTorch using conda
```
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
```
| 64 | 
            +
             | 
| 65 | 
            +
            5. Install required packages:
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            ```
         | 
| 68 | 
            +
            pip install -r requirements_base.txt -v
         | 
| 69 | 
            +
            pip install -r requirements_extra.txt -v
         | 
| 70 | 
            +
            pip install onnxruntime-gpu
         | 
| 71 | 
            +
            ```
         | 
| 72 | 
            +
             | 
| 73 | 
            +
6. Install [ffmpeg](https://ffmpeg.org/download.html). FFmpeg is a free software project that produces libraries and programs for handling multimedia data; SoniTranslate needs it to process audio and video files. You can install ffmpeg with Anaconda by running `conda install -y ffmpeg` in your terminal (recommended). If you have trouble installing it via Anaconda, use the builds from the [download page](https://ffmpeg.org/download.html) instead. Once it is installed, make sure it is on your PATH by running `ffmpeg -h` in your terminal; if you don't get an error message, you're good to go.
         | 
| 74 | 
            +
             | 
| 75 | 
            +
7. Optional installs:
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            After installing FFmpeg, you can install these optional packages.
         | 
| 78 | 
            +
             | 
| 79 | 
            +
[Coqui XTTS](https://github.com/coqui-ai/TTS) is a text-to-speech (TTS) model that lets you generate realistic voices in different languages. It can clone a voice from just a short audio clip and even speak in a different language! It's like having a personal voice mimic for any text you need spoken.
         | 
| 80 | 
            +
             | 
| 81 | 
            +
            ```
         | 
| 82 | 
            +
            pip install -q -r requirements_xtts.txt
         | 
| 83 | 
            +
pip install -q TTS==0.21.1 --no-deps
         | 
| 84 | 
            +
            ```
         | 
| 85 | 
            +
             | 
| 86 | 
            +
[Piper TTS](https://github.com/rhasspy/piper) is a fast, local neural text-to-speech system that sounds great and is optimized for the Raspberry Pi 4. Piper is used in a variety of projects. Voices are trained with VITS and exported for onnxruntime.
         | 
| 87 | 
            +
             | 
| 88 | 
            +
🚧 For Windows users: the Python module piper-tts is not fully supported on this operating system. While it works smoothly on Linux, Windows compatibility is currently experimental. If you still wish to install it on Windows, you can use this experimental method:
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            ```
         | 
| 91 | 
            +
            pip install https://github.com/R3gm/piper-phonemize/releases/download/1.2.0/piper_phonemize-1.2.0-cp310-cp310-win_amd64.whl
         | 
| 92 | 
            +
            pip install sherpa-onnx==1.9.12
         | 
| 93 | 
            +
            pip install piper-tts==1.2.0 --no-deps
         | 
| 94 | 
            +
            ```
         | 
| 95 | 
            +
             | 
| 96 | 
            +
8. Set your [Hugging Face token](https://huggingface.co/settings/tokens) as an environment variable (keep the quotes; a verification sketch follows after this step):
         | 
| 97 | 
            +
             | 
| 98 | 
            +
            ```
         | 
| 99 | 
            +
            conda env config vars set YOUR_HF_TOKEN="YOUR_HUGGING_FACE_TOKEN_HERE"
         | 
| 100 | 
            +
            conda deactivate
         | 
| 101 | 
            +
            ```
         | 
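Before launching the app, it can help to verify the setup. The following is a minimal sanity-check sketch, not part of SoniTranslate itself; it assumes the `huggingface_hub` package was pulled in by the requirements. Reactivate the environment first with `conda activate sonitr`, then run:

```
# Minimal sanity check: CUDA-enabled PyTorch and a valid Hugging Face token.
import os
import torch
from huggingface_hub import HfApi

print(torch.__version__, torch.cuda.is_available())  # expect True on a working CUDA setup
token = os.environ.get("YOUR_HF_TOKEN")  # set via `conda env config vars` above
print(HfApi().whoami(token=token)["name"])  # prints your Hugging Face username
```

If `torch.cuda.is_available()` returns `False`, revisit steps 3 and 4; if `whoami` raises an error, recheck the token from step 8.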
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            ### Running SoniTranslate
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            To run SoniTranslate locally, make sure the `sonitr` conda environment is active:
         | 
| 107 | 
            +
             | 
| 108 | 
            +
            ```
         | 
| 109 | 
            +
            conda activate sonitr
         | 
| 110 | 
            +
            ```
         | 
| 111 | 
            +
             | 
| 112 | 
            +
Then navigate to the `SoniTranslate` folder and run the `app_rvc.py` script:
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            ```
         | 
| 115 | 
            +
            python app_rvc.py
         | 
| 116 | 
            +
            ```
         | 
| 117 | 
            +
When the local URL `http://127.0.0.1:7860` is displayed in the terminal, open it in your web browser to access the SoniTranslate interface.
         | 
| 118 | 
            +
             | 
| 119 | 
            +
### Stopping and Closing SoniTranslate
         | 
| 120 | 
            +
             | 
| 121 | 
            +
In most environments, you can stop execution by pressing Ctrl+C in the terminal where you launched `app_rvc.py`. This interrupts the program and stops the Gradio app.
         | 
| 122 | 
            +
            To deactivate the Conda environment, you can use the following command:
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            ```
         | 
| 125 | 
            +
            conda deactivate
         | 
| 126 | 
            +
            ```
         | 
| 127 | 
            +
             | 
| 128 | 
            +
This deactivates the currently active `sonitr` conda environment and returns you to the base environment or the global Python environment.
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            ### Starting Over
         | 
| 131 | 
            +
             | 
| 132 | 
            +
            If you need to start over from scratch, you can delete the `SoniTranslate` folder and remove the `sonitr` conda environment with the following set of commands:
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            ```
         | 
| 135 | 
            +
            conda deactivate
         | 
| 136 | 
            +
            conda env remove -n sonitr
         | 
| 137 | 
            +
            ```
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            With the `sonitr` environment removed, you can start over with a fresh installation.
         | 
| 140 | 
            +
             | 
| 141 | 
            +
            ### Notes
         | 
| 142 | 
            +
- To use OpenAI's GPT API for translation, set your OpenAI API key as an environment variable (keep the quotes):
         | 
| 143 | 
            +
             | 
| 144 | 
            +
            ```
         | 
| 145 | 
            +
            conda activate sonitr
         | 
| 146 | 
            +
            conda env config vars set OPENAI_API_KEY="your-api-key-here"
         | 
| 147 | 
            +
            conda deactivate
         | 
| 148 | 
            +
            ```
         | 
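The key is read from the environment when SoniTranslate runs. Purely as an illustration (not SoniTranslate code), you can confirm the variable is visible from Python inside the activated environment:

```
# Illustrative only: check that the variable is visible to Python processes.
import os

if not os.environ.get("OPENAI_API_KEY"):
    raise SystemExit("OPENAI_API_KEY is not set in this environment")
print("OPENAI_API_KEY is set")
```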
| 149 | 
            +
             | 
| 150 | 
            +
- Alternatively, you can install [CUDA Toolkit 11.8.0](https://developer.nvidia.com/cuda-11-8-0-download-archive) directly on your system.
         | 
    	
        lib/audio.py
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
| 1 | 
            +
            import ffmpeg
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def load_audio(file, sr):
         | 
| 6 | 
            +
                try:
         | 
| 7 | 
            +
                    # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
         | 
| 8 | 
            +
                    # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
         | 
| 9 | 
            +
                    # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
         | 
| 10 | 
            +
                    file = (
         | 
| 11 | 
            +
                        file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
         | 
| 12 | 
            +
                    )  # To prevent beginners from copying paths with leading or trailing spaces, quotation marks, and line breaks.
         | 
| 13 | 
            +
                    out, _ = (
         | 
| 14 | 
            +
                        ffmpeg.input(file, threads=0)
         | 
| 15 | 
            +
                        .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
         | 
| 16 | 
            +
                        .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
         | 
| 17 | 
            +
                    )
         | 
| 18 | 
            +
                except Exception as e:
         | 
| 19 | 
            +
                    raise RuntimeError(f"Failed to load audio: {e}")
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                return np.frombuffer(out, np.float32).flatten()
         | 
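A minimal usage sketch for `load_audio` (the file name below is a placeholder, and the ffmpeg CLI must be on your PATH):

```
# Hypothetical usage: decode any media file to mono float32 PCM at 16 kHz.
from lib.audio import load_audio

samples = load_audio("example.wav", 16000)  # "example.wav" is a placeholder
print(samples.dtype, len(samples) / 16000)  # float32, duration in seconds
```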
    	
        lib/infer_pack/attentions.py
    ADDED
    
    | @@ -0,0 +1,417 @@ | |
| 1 | 
            +
            import copy
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from torch import nn
         | 
| 6 | 
            +
            from torch.nn import functional as F
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from lib.infer_pack import commons
         | 
| 9 | 
            +
            from lib.infer_pack import modules
         | 
| 10 | 
            +
            from lib.infer_pack.modules import LayerNorm
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class Encoder(nn.Module):
         | 
| 14 | 
            +
                def __init__(
         | 
| 15 | 
            +
                    self,
         | 
| 16 | 
            +
                    hidden_channels,
         | 
| 17 | 
            +
                    filter_channels,
         | 
| 18 | 
            +
                    n_heads,
         | 
| 19 | 
            +
                    n_layers,
         | 
| 20 | 
            +
                    kernel_size=1,
         | 
| 21 | 
            +
                    p_dropout=0.0,
         | 
| 22 | 
            +
                    window_size=10,
         | 
| 23 | 
            +
                    **kwargs
         | 
| 24 | 
            +
                ):
         | 
| 25 | 
            +
                    super().__init__()
         | 
| 26 | 
            +
                    self.hidden_channels = hidden_channels
         | 
| 27 | 
            +
                    self.filter_channels = filter_channels
         | 
| 28 | 
            +
                    self.n_heads = n_heads
         | 
| 29 | 
            +
                    self.n_layers = n_layers
         | 
| 30 | 
            +
                    self.kernel_size = kernel_size
         | 
| 31 | 
            +
                    self.p_dropout = p_dropout
         | 
| 32 | 
            +
                    self.window_size = window_size
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    self.drop = nn.Dropout(p_dropout)
         | 
| 35 | 
            +
                    self.attn_layers = nn.ModuleList()
         | 
| 36 | 
            +
                    self.norm_layers_1 = nn.ModuleList()
         | 
| 37 | 
            +
                    self.ffn_layers = nn.ModuleList()
         | 
| 38 | 
            +
                    self.norm_layers_2 = nn.ModuleList()
         | 
| 39 | 
            +
                    for i in range(self.n_layers):
         | 
| 40 | 
            +
                        self.attn_layers.append(
         | 
| 41 | 
            +
                            MultiHeadAttention(
         | 
| 42 | 
            +
                                hidden_channels,
         | 
| 43 | 
            +
                                hidden_channels,
         | 
| 44 | 
            +
                                n_heads,
         | 
| 45 | 
            +
                                p_dropout=p_dropout,
         | 
| 46 | 
            +
                                window_size=window_size,
         | 
| 47 | 
            +
                            )
         | 
| 48 | 
            +
                        )
         | 
| 49 | 
            +
                        self.norm_layers_1.append(LayerNorm(hidden_channels))
         | 
| 50 | 
            +
                        self.ffn_layers.append(
         | 
| 51 | 
            +
                            FFN(
         | 
| 52 | 
            +
                                hidden_channels,
         | 
| 53 | 
            +
                                hidden_channels,
         | 
| 54 | 
            +
                                filter_channels,
         | 
| 55 | 
            +
                                kernel_size,
         | 
| 56 | 
            +
                                p_dropout=p_dropout,
         | 
| 57 | 
            +
                            )
         | 
| 58 | 
            +
                        )
         | 
| 59 | 
            +
                        self.norm_layers_2.append(LayerNorm(hidden_channels))
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                def forward(self, x, x_mask):
         | 
| 62 | 
            +
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)  # [b, 1, t, t] pairwise valid-frame mask
         | 
| 63 | 
            +
                    x = x * x_mask
         | 
| 64 | 
            +
                    for i in range(self.n_layers):
         | 
| 65 | 
            +
                        y = self.attn_layers[i](x, x, attn_mask)
         | 
| 66 | 
            +
                        y = self.drop(y)
         | 
| 67 | 
            +
                        x = self.norm_layers_1[i](x + y)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                        y = self.ffn_layers[i](x, x_mask)
         | 
| 70 | 
            +
                        y = self.drop(y)
         | 
| 71 | 
            +
                        x = self.norm_layers_2[i](x + y)
         | 
| 72 | 
            +
                    x = x * x_mask
         | 
| 73 | 
            +
                    return x
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            class Decoder(nn.Module):
         | 
| 77 | 
            +
                def __init__(
         | 
| 78 | 
            +
                    self,
         | 
| 79 | 
            +
                    hidden_channels,
         | 
| 80 | 
            +
                    filter_channels,
         | 
| 81 | 
            +
                    n_heads,
         | 
| 82 | 
            +
                    n_layers,
         | 
| 83 | 
            +
                    kernel_size=1,
         | 
| 84 | 
            +
                    p_dropout=0.0,
         | 
| 85 | 
            +
                    proximal_bias=False,
         | 
| 86 | 
            +
                    proximal_init=True,
         | 
| 87 | 
            +
                    **kwargs
         | 
| 88 | 
            +
                ):
         | 
| 89 | 
            +
                    super().__init__()
         | 
| 90 | 
            +
                    self.hidden_channels = hidden_channels
         | 
| 91 | 
            +
                    self.filter_channels = filter_channels
         | 
| 92 | 
            +
                    self.n_heads = n_heads
         | 
| 93 | 
            +
                    self.n_layers = n_layers
         | 
| 94 | 
            +
                    self.kernel_size = kernel_size
         | 
| 95 | 
            +
                    self.p_dropout = p_dropout
         | 
| 96 | 
            +
                    self.proximal_bias = proximal_bias
         | 
| 97 | 
            +
                    self.proximal_init = proximal_init
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    self.drop = nn.Dropout(p_dropout)
         | 
| 100 | 
            +
                    self.self_attn_layers = nn.ModuleList()
         | 
| 101 | 
            +
                    self.norm_layers_0 = nn.ModuleList()
         | 
| 102 | 
            +
                    self.encdec_attn_layers = nn.ModuleList()
         | 
| 103 | 
            +
                    self.norm_layers_1 = nn.ModuleList()
         | 
| 104 | 
            +
                    self.ffn_layers = nn.ModuleList()
         | 
| 105 | 
            +
                    self.norm_layers_2 = nn.ModuleList()
         | 
| 106 | 
            +
                    for i in range(self.n_layers):
         | 
| 107 | 
            +
                        self.self_attn_layers.append(
         | 
| 108 | 
            +
                            MultiHeadAttention(
         | 
| 109 | 
            +
                                hidden_channels,
         | 
| 110 | 
            +
                                hidden_channels,
         | 
| 111 | 
            +
                                n_heads,
         | 
| 112 | 
            +
                                p_dropout=p_dropout,
         | 
| 113 | 
            +
                                proximal_bias=proximal_bias,
         | 
| 114 | 
            +
                                proximal_init=proximal_init,
         | 
| 115 | 
            +
                            )
         | 
| 116 | 
            +
                        )
         | 
| 117 | 
            +
                        self.norm_layers_0.append(LayerNorm(hidden_channels))
         | 
| 118 | 
            +
                        self.encdec_attn_layers.append(
         | 
| 119 | 
            +
                            MultiHeadAttention(
         | 
| 120 | 
            +
                                hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
         | 
| 121 | 
            +
                            )
         | 
| 122 | 
            +
                        )
         | 
| 123 | 
            +
                        self.norm_layers_1.append(LayerNorm(hidden_channels))
         | 
| 124 | 
            +
                        self.ffn_layers.append(
         | 
| 125 | 
            +
                            FFN(
         | 
| 126 | 
            +
                                hidden_channels,
         | 
| 127 | 
            +
                                hidden_channels,
         | 
| 128 | 
            +
                                filter_channels,
         | 
| 129 | 
            +
                                kernel_size,
         | 
| 130 | 
            +
                                p_dropout=p_dropout,
         | 
| 131 | 
            +
                                causal=True,
         | 
| 132 | 
            +
                            )
         | 
| 133 | 
            +
                        )
         | 
| 134 | 
            +
                        self.norm_layers_2.append(LayerNorm(hidden_channels))
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def forward(self, x, x_mask, h, h_mask):
         | 
| 137 | 
            +
                    """
         | 
| 138 | 
            +
                    x: decoder input
         | 
| 139 | 
            +
                    h: encoder output
         | 
| 140 | 
            +
                    """
         | 
| 141 | 
            +
                    self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
         | 
| 142 | 
            +
                        device=x.device, dtype=x.dtype
         | 
| 143 | 
            +
                    )
         | 
| 144 | 
            +
                    encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         | 
| 145 | 
            +
                    x = x * x_mask
         | 
| 146 | 
            +
                    for i in range(self.n_layers):
         | 
| 147 | 
            +
                        y = self.self_attn_layers[i](x, x, self_attn_mask)
         | 
| 148 | 
            +
                        y = self.drop(y)
         | 
| 149 | 
            +
                        x = self.norm_layers_0[i](x + y)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                        y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
         | 
| 152 | 
            +
                        y = self.drop(y)
         | 
| 153 | 
            +
                        x = self.norm_layers_1[i](x + y)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                        y = self.ffn_layers[i](x, x_mask)
         | 
| 156 | 
            +
                        y = self.drop(y)
         | 
| 157 | 
            +
                        x = self.norm_layers_2[i](x + y)
         | 
| 158 | 
            +
                    x = x * x_mask
         | 
| 159 | 
            +
                    return x
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            class MultiHeadAttention(nn.Module):
         | 
| 163 | 
            +
                def __init__(
         | 
| 164 | 
            +
                    self,
         | 
| 165 | 
            +
                    channels,
         | 
| 166 | 
            +
                    out_channels,
         | 
| 167 | 
            +
                    n_heads,
         | 
| 168 | 
            +
                    p_dropout=0.0,
         | 
| 169 | 
            +
                    window_size=None,
         | 
| 170 | 
            +
                    heads_share=True,
         | 
| 171 | 
            +
                    block_length=None,
         | 
| 172 | 
            +
                    proximal_bias=False,
         | 
| 173 | 
            +
                    proximal_init=False,
         | 
| 174 | 
            +
                ):
         | 
| 175 | 
            +
                    super().__init__()
         | 
| 176 | 
            +
                    assert channels % n_heads == 0
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    self.channels = channels
         | 
| 179 | 
            +
                    self.out_channels = out_channels
         | 
| 180 | 
            +
                    self.n_heads = n_heads
         | 
| 181 | 
            +
                    self.p_dropout = p_dropout
         | 
| 182 | 
            +
                    self.window_size = window_size
         | 
| 183 | 
            +
                    self.heads_share = heads_share
         | 
| 184 | 
            +
                    self.block_length = block_length
         | 
| 185 | 
            +
                    self.proximal_bias = proximal_bias
         | 
| 186 | 
            +
                    self.proximal_init = proximal_init
         | 
| 187 | 
            +
                    self.attn = None
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    self.k_channels = channels // n_heads
         | 
| 190 | 
            +
                    self.conv_q = nn.Conv1d(channels, channels, 1)
         | 
| 191 | 
            +
                    self.conv_k = nn.Conv1d(channels, channels, 1)
         | 
| 192 | 
            +
                    self.conv_v = nn.Conv1d(channels, channels, 1)
         | 
| 193 | 
            +
                    self.conv_o = nn.Conv1d(channels, out_channels, 1)
         | 
| 194 | 
            +
                    self.drop = nn.Dropout(p_dropout)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    if window_size is not None:
         | 
| 197 | 
            +
                        n_heads_rel = 1 if heads_share else n_heads
         | 
| 198 | 
            +
                        rel_stddev = self.k_channels**-0.5
         | 
| 199 | 
            +
                        self.emb_rel_k = nn.Parameter(
         | 
| 200 | 
            +
                            torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
         | 
| 201 | 
            +
                            * rel_stddev
         | 
| 202 | 
            +
                        )
         | 
| 203 | 
            +
                        self.emb_rel_v = nn.Parameter(
         | 
| 204 | 
            +
                            torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
         | 
| 205 | 
            +
                            * rel_stddev
         | 
| 206 | 
            +
                        )
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    nn.init.xavier_uniform_(self.conv_q.weight)
         | 
| 209 | 
            +
                    nn.init.xavier_uniform_(self.conv_k.weight)
         | 
| 210 | 
            +
                    nn.init.xavier_uniform_(self.conv_v.weight)
         | 
| 211 | 
            +
                    if proximal_init:
         | 
| 212 | 
            +
                        with torch.no_grad():
         | 
| 213 | 
            +
                            self.conv_k.weight.copy_(self.conv_q.weight)
         | 
| 214 | 
            +
                            self.conv_k.bias.copy_(self.conv_q.bias)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                def forward(self, x, c, attn_mask=None):
         | 
| 217 | 
            +
                    q = self.conv_q(x)
         | 
| 218 | 
            +
                    k = self.conv_k(c)
         | 
| 219 | 
            +
                    v = self.conv_v(c)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    x, self.attn = self.attention(q, k, v, mask=attn_mask)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    x = self.conv_o(x)
         | 
| 224 | 
            +
                    return x
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                def attention(self, query, key, value, mask=None):
         | 
| 227 | 
            +
                    # reshape [b, d, t] -> [b, n_h, t, d_k]
         | 
| 228 | 
            +
                    b, d, t_s, t_t = (*key.size(), query.size(2))
         | 
| 229 | 
            +
                    query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
         | 
| 230 | 
            +
                    key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
         | 
| 231 | 
            +
                    value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
         | 
| 234 | 
            +
                    if self.window_size is not None:
         | 
| 235 | 
            +
                        assert (
         | 
| 236 | 
            +
                            t_s == t_t
         | 
| 237 | 
            +
                        ), "Relative attention is only available for self-attention."
         | 
| 238 | 
            +
                        key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
         | 
| 239 | 
            +
                        rel_logits = self._matmul_with_relative_keys(
         | 
| 240 | 
            +
                            query / math.sqrt(self.k_channels), key_relative_embeddings
         | 
| 241 | 
            +
                        )
         | 
| 242 | 
            +
                        scores_local = self._relative_position_to_absolute_position(rel_logits)
         | 
| 243 | 
            +
                        scores = scores + scores_local
         | 
| 244 | 
            +
                    if self.proximal_bias:
         | 
| 245 | 
            +
                        assert t_s == t_t, "Proximal bias is only available for self-attention."
         | 
| 246 | 
            +
                        scores = scores + self._attention_bias_proximal(t_s).to(
         | 
| 247 | 
            +
                            device=scores.device, dtype=scores.dtype
         | 
| 248 | 
            +
                        )
         | 
| 249 | 
            +
                    if mask is not None:
         | 
| 250 | 
            +
            scores = scores.masked_fill(mask == 0, -1e4)  # -1e4 (rather than -inf) stays finite in fp16
         | 
| 251 | 
            +
                        if self.block_length is not None:
         | 
| 252 | 
            +
                            assert (
         | 
| 253 | 
            +
                                t_s == t_t
         | 
| 254 | 
            +
                            ), "Local attention is only available for self-attention."
         | 
| 255 | 
            +
                            block_mask = (
         | 
| 256 | 
            +
                                torch.ones_like(scores)
         | 
| 257 | 
            +
                                .triu(-self.block_length)
         | 
| 258 | 
            +
                                .tril(self.block_length)
         | 
| 259 | 
            +
                            )
         | 
| 260 | 
            +
                            scores = scores.masked_fill(block_mask == 0, -1e4)
         | 
| 261 | 
            +
                    p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         | 
| 262 | 
            +
                    p_attn = self.drop(p_attn)
         | 
| 263 | 
            +
                    output = torch.matmul(p_attn, value)
         | 
| 264 | 
            +
                    if self.window_size is not None:
         | 
| 265 | 
            +
                        relative_weights = self._absolute_position_to_relative_position(p_attn)
         | 
| 266 | 
            +
                        value_relative_embeddings = self._get_relative_embeddings(
         | 
| 267 | 
            +
                            self.emb_rel_v, t_s
         | 
| 268 | 
            +
                        )
         | 
| 269 | 
            +
                        output = output + self._matmul_with_relative_values(
         | 
| 270 | 
            +
                            relative_weights, value_relative_embeddings
         | 
| 271 | 
            +
                        )
         | 
| 272 | 
            +
                    output = (
         | 
| 273 | 
            +
                        output.transpose(2, 3).contiguous().view(b, d, t_t)
         | 
| 274 | 
            +
                    )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
         | 
| 275 | 
            +
                    return output, p_attn
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                def _matmul_with_relative_values(self, x, y):
         | 
| 278 | 
            +
                    """
         | 
| 279 | 
            +
                    x: [b, h, l, m]
         | 
| 280 | 
            +
                    y: [h or 1, m, d]
         | 
| 281 | 
            +
                    ret: [b, h, l, d]
         | 
| 282 | 
            +
                    """
         | 
| 283 | 
            +
                    ret = torch.matmul(x, y.unsqueeze(0))
         | 
| 284 | 
            +
                    return ret
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                def _matmul_with_relative_keys(self, x, y):
         | 
| 287 | 
            +
                    """
         | 
| 288 | 
            +
                    x: [b, h, l, d]
         | 
| 289 | 
            +
                    y: [h or 1, m, d]
         | 
| 290 | 
            +
                    ret: [b, h, l, m]
         | 
| 291 | 
            +
                    """
         | 
| 292 | 
            +
                    ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
         | 
| 293 | 
            +
                    return ret
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                def _get_relative_embeddings(self, relative_embeddings, length):
         | 
| 296 | 
            +
                    max_relative_position = 2 * self.window_size + 1
         | 
| 297 | 
            +
                    # Pad first before slice to avoid using cond ops.
         | 
| 298 | 
            +
                    pad_length = max(length - (self.window_size + 1), 0)
         | 
| 299 | 
            +
                    slice_start_position = max((self.window_size + 1) - length, 0)
         | 
| 300 | 
            +
                    slice_end_position = slice_start_position + 2 * length - 1
         | 
| 301 | 
            +
                    if pad_length > 0:
         | 
| 302 | 
            +
                        padded_relative_embeddings = F.pad(
         | 
| 303 | 
            +
                            relative_embeddings,
         | 
| 304 | 
            +
                            commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
         | 
| 305 | 
            +
                        )
         | 
| 306 | 
            +
                    else:
         | 
| 307 | 
            +
                        padded_relative_embeddings = relative_embeddings
         | 
| 308 | 
            +
                    used_relative_embeddings = padded_relative_embeddings[
         | 
| 309 | 
            +
                        :, slice_start_position:slice_end_position
         | 
| 310 | 
            +
                    ]
         | 
| 311 | 
            +
                    return used_relative_embeddings
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                def _relative_position_to_absolute_position(self, x):
         | 
| 314 | 
            +
                    """
         | 
| 315 | 
            +
                    x: [b, h, l, 2*l-1]
         | 
| 316 | 
            +
                    ret: [b, h, l, l]
         | 
| 317 | 
            +
                    """
         | 
| 318 | 
            +
                    batch, heads, length, _ = x.size()
         | 
| 319 | 
            +
                    # Concat columns of pad to shift from relative to absolute indexing.
         | 
| 320 | 
            +
                    x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
         | 
| 321 | 
            +
             | 
| 322 | 
            +
        # Concat extra elements so that it adds up to shape (len+1, 2*len-1).
         | 
| 323 | 
            +
                    x_flat = x.view([batch, heads, length * 2 * length])
         | 
| 324 | 
            +
                    x_flat = F.pad(
         | 
| 325 | 
            +
                        x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
         | 
| 326 | 
            +
                    )
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    # Reshape and slice out the padded elements.
         | 
| 329 | 
            +
                    x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
         | 
| 330 | 
            +
                        :, :, :length, length - 1 :
         | 
| 331 | 
            +
                    ]
         | 
| 332 | 
            +
                    return x_final
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                def _absolute_position_to_relative_position(self, x):
         | 
| 335 | 
            +
                    """
         | 
| 336 | 
            +
                    x: [b, h, l, l]
         | 
| 337 | 
            +
                    ret: [b, h, l, 2*l-1]
         | 
| 338 | 
            +
                    """
         | 
| 339 | 
            +
                    batch, heads, length, _ = x.size()
         | 
| 340 | 
            +
        # pad along column
         | 
| 341 | 
            +
                    x = F.pad(
         | 
| 342 | 
            +
                        x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
         | 
| 343 | 
            +
                    )
         | 
| 344 | 
            +
                    x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         | 
| 345 | 
            +
        # add zeros at the beginning that will skew the elements after the reshape
         | 
| 346 | 
            +
                    x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
         | 
| 347 | 
            +
                    x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         | 
| 348 | 
            +
                    return x_final
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                def _attention_bias_proximal(self, length):
         | 
| 351 | 
            +
                    """Bias for self-attention to encourage attention to close positions.
         | 
| 352 | 
            +
                    Args:
         | 
| 353 | 
            +
                      length: an integer scalar.
         | 
| 354 | 
            +
                    Returns:
         | 
| 355 | 
            +
                      a Tensor with shape [1, 1, length, length]
         | 
| 356 | 
            +
                    """
         | 
| 357 | 
            +
                    r = torch.arange(length, dtype=torch.float32)
         | 
| 358 | 
            +
                    diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
         | 
| 359 | 
            +
                    return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
         | 
| 360 | 
            +
             | 
| 361 | 
            +
             | 
| 362 | 
            +
            class FFN(nn.Module):
         | 
| 363 | 
            +
                def __init__(
         | 
| 364 | 
            +
                    self,
         | 
| 365 | 
            +
                    in_channels,
         | 
| 366 | 
            +
                    out_channels,
         | 
| 367 | 
            +
                    filter_channels,
         | 
| 368 | 
            +
                    kernel_size,
         | 
| 369 | 
            +
                    p_dropout=0.0,
         | 
| 370 | 
            +
                    activation=None,
         | 
| 371 | 
            +
                    causal=False,
         | 
| 372 | 
            +
                ):
         | 
| 373 | 
            +
                    super().__init__()
         | 
| 374 | 
            +
                    self.in_channels = in_channels
         | 
| 375 | 
            +
                    self.out_channels = out_channels
         | 
| 376 | 
            +
                    self.filter_channels = filter_channels
         | 
| 377 | 
            +
                    self.kernel_size = kernel_size
         | 
| 378 | 
            +
                    self.p_dropout = p_dropout
         | 
| 379 | 
            +
                    self.activation = activation
         | 
| 380 | 
            +
                    self.causal = causal
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                    if causal:
         | 
| 383 | 
            +
                        self.padding = self._causal_padding
         | 
| 384 | 
            +
                    else:
         | 
| 385 | 
            +
                        self.padding = self._same_padding
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
         | 
| 388 | 
            +
                    self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
         | 
| 389 | 
            +
                    self.drop = nn.Dropout(p_dropout)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                def forward(self, x, x_mask):
         | 
| 392 | 
            +
                    x = self.conv_1(self.padding(x * x_mask))
         | 
| 393 | 
            +
                    if self.activation == "gelu":
         | 
| 394 | 
            +
                        x = x * torch.sigmoid(1.702 * x)
         | 
| 395 | 
            +
                    else:
         | 
| 396 | 
            +
                        x = torch.relu(x)
         | 
| 397 | 
            +
                    x = self.drop(x)
         | 
| 398 | 
            +
                    x = self.conv_2(self.padding(x * x_mask))
         | 
| 399 | 
            +
                    return x * x_mask
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                def _causal_padding(self, x):
         | 
| 402 | 
            +
                    if self.kernel_size == 1:
         | 
| 403 | 
            +
                        return x
         | 
| 404 | 
            +
                    pad_l = self.kernel_size - 1
         | 
| 405 | 
            +
                    pad_r = 0
         | 
| 406 | 
            +
                    padding = [[0, 0], [0, 0], [pad_l, pad_r]]
         | 
| 407 | 
            +
                    x = F.pad(x, commons.convert_pad_shape(padding))
         | 
| 408 | 
            +
                    return x
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                def _same_padding(self, x):
         | 
| 411 | 
            +
                    if self.kernel_size == 1:
         | 
| 412 | 
            +
                        return x
         | 
| 413 | 
            +
                    pad_l = (self.kernel_size - 1) // 2
         | 
| 414 | 
            +
                    pad_r = self.kernel_size // 2
         | 
| 415 | 
            +
                    padding = [[0, 0], [0, 0], [pad_l, pad_r]]
         | 
| 416 | 
            +
                    x = F.pad(x, commons.convert_pad_shape(padding))
         | 
| 417 | 
            +
                    return x
         | 
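A minimal sketch of how this relative-attention `Encoder` can be exercised on a dummy batch (the hyperparameters below are illustrative, not the values SoniTranslate uses):

```
# Illustrative smoke test for the Encoder; shapes follow [batch, channels, time].
import torch
from lib.infer_pack.attentions import Encoder

enc = Encoder(
    hidden_channels=192, filter_channels=768, n_heads=2,
    n_layers=6, kernel_size=3, p_dropout=0.1,
)
x = torch.randn(2, 192, 100)    # dummy hidden states
x_mask = torch.ones(2, 1, 100)  # 1 = valid frame, 0 = padding
out = enc(x, x_mask)
print(out.shape)                # torch.Size([2, 192, 100])
```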
    	
        lib/infer_pack/commons.py
    ADDED
    
    | @@ -0,0 +1,166 @@ | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torch import nn
         | 
| 5 | 
            +
            from torch.nn import functional as F
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def init_weights(m, mean=0.0, std=0.01):
         | 
| 9 | 
            +
                classname = m.__class__.__name__
         | 
| 10 | 
            +
                if classname.find("Conv") != -1:
         | 
| 11 | 
            +
                    m.weight.data.normal_(mean, std)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def get_padding(kernel_size, dilation=1):
         | 
| 15 | 
            +
                return int((kernel_size * dilation - dilation) / 2)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def convert_pad_shape(pad_shape):
         | 
| 19 | 
            +
                l = pad_shape[::-1]
         | 
| 20 | 
            +
                pad_shape = [item for sublist in l for item in sublist]
         | 
| 21 | 
            +
                return pad_shape
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def kl_divergence(m_p, logs_p, m_q, logs_q):
         | 
| 25 | 
            +
                """KL(P||Q)"""
         | 
| 26 | 
            +
                kl = (logs_q - logs_p) - 0.5
         | 
| 27 | 
            +
                kl += (
         | 
| 28 | 
            +
                    0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
         | 
| 29 | 
            +
                )
         | 
| 30 | 
            +
                return kl
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def rand_gumbel(shape):
         | 
| 34 | 
            +
                """Sample from the Gumbel distribution, protect from overflows."""
         | 
| 35 | 
            +
                uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
         | 
| 36 | 
            +
                return -torch.log(-torch.log(uniform_samples))
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def rand_gumbel_like(x):
         | 
| 40 | 
            +
                g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
         | 
| 41 | 
            +
                return g
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def slice_segments(x, ids_str, segment_size=4):
         | 
| 45 | 
            +
                ret = torch.zeros_like(x[:, :, :segment_size])
         | 
| 46 | 
            +
                for i in range(x.size(0)):
         | 
| 47 | 
            +
                    idx_str = ids_str[i]
         | 
| 48 | 
            +
                    idx_end = idx_str + segment_size
         | 
| 49 | 
            +
                    ret[i] = x[i, :, idx_str:idx_end]
         | 
| 50 | 
            +
                return ret
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def slice_segments2(x, ids_str, segment_size=4):
         | 
| 54 | 
            +
                ret = torch.zeros_like(x[:, :segment_size])
         | 
| 55 | 
            +
                for i in range(x.size(0)):
         | 
| 56 | 
            +
                    idx_str = ids_str[i]
         | 
| 57 | 
            +
                    idx_end = idx_str + segment_size
         | 
| 58 | 
            +
                    ret[i] = x[i, idx_str:idx_end]
         | 
| 59 | 
            +
                return ret
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def rand_slice_segments(x, x_lengths=None, segment_size=4):
         | 
| 63 | 
            +
                b, d, t = x.size()
         | 
| 64 | 
            +
                if x_lengths is None:
         | 
| 65 | 
            +
                    x_lengths = t
         | 
| 66 | 
            +
                ids_str_max = x_lengths - segment_size + 1
         | 
| 67 | 
            +
                ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
         | 
| 68 | 
            +
                ret = slice_segments(x, ids_str, segment_size)
         | 
| 69 | 
            +
                return ret, ids_str
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
         | 
| 73 | 
            +
                position = torch.arange(length, dtype=torch.float)
         | 
| 74 | 
            +
                num_timescales = channels // 2
         | 
| 75 | 
            +
                log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
         | 
| 76 | 
            +
                    num_timescales - 1
         | 
| 77 | 
            +
                )
         | 
| 78 | 
            +
                inv_timescales = min_timescale * torch.exp(
         | 
| 79 | 
            +
                    torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
         | 
| 80 | 
            +
                )
         | 
| 81 | 
            +
                scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
         | 
| 82 | 
            +
                signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
         | 
| 83 | 
            +
                signal = F.pad(signal, [0, 0, 0, channels % 2])
         | 
| 84 | 
            +
                signal = signal.view(1, channels, length)
         | 
| 85 | 
            +
                return signal
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
         | 
| 89 | 
            +
                b, channels, length = x.size()
         | 
| 90 | 
            +
                signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
         | 
| 91 | 
            +
                return x + signal.to(dtype=x.dtype, device=x.device)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
         | 
| 95 | 
            +
                b, channels, length = x.size()
         | 
| 96 | 
            +
                signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
         | 
| 97 | 
            +
                return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            def subsequent_mask(length):
         | 
| 101 | 
            +
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)  # causal lower-triangular mask, [1, 1, length, length]
         | 
| 102 | 
            +
                return mask
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            @torch.jit.script
         | 
| 106 | 
            +
            def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
         | 
| 107 | 
            +
                n_channels_int = n_channels[0]
         | 
| 108 | 
            +
                in_act = input_a + input_b
         | 
| 109 | 
            +
                t_act = torch.tanh(in_act[:, :n_channels_int, :])
         | 
| 110 | 
            +
                s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
         | 
| 111 | 
            +
                acts = t_act * s_act
         | 
| 112 | 
            +
                return acts
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def convert_pad_shape(pad_shape):
         | 
| 116 | 
            +
                l = pad_shape[::-1]
         | 
| 117 | 
            +
                pad_shape = [item for sublist in l for item in sublist]
         | 
| 118 | 
            +
                return pad_shape
         | 
| 119 | 
            +
             | 
| 120 | 
            +
             | 
| 121 | 
            +
            def shift_1d(x):
         | 
| 122 | 
            +
                x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
         | 
| 123 | 
            +
                return x
         | 
| 124 | 
            +
             | 
| 125 | 
            +
             | 
| 126 | 
            +
            def sequence_mask(length, max_length=None):
         | 
| 127 | 
            +
                if max_length is None:
         | 
| 128 | 
            +
                    max_length = length.max()
         | 
| 129 | 
            +
                x = torch.arange(max_length, dtype=length.dtype, device=length.device)
         | 
| 130 | 
            +
                return x.unsqueeze(0) < length.unsqueeze(1)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
            def generate_path(duration, mask):
         | 
| 134 | 
            +
                """
         | 
| 135 | 
            +
                duration: [b, 1, t_x]
         | 
| 136 | 
            +
                mask: [b, 1, t_y, t_x]
         | 
| 137 | 
            +
                """
         | 
| 138 | 
            +
                device = duration.device
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                b, _, t_y, t_x = mask.shape
         | 
| 141 | 
            +
                cum_duration = torch.cumsum(duration, -1)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                cum_duration_flat = cum_duration.view(b * t_x)
         | 
| 144 | 
            +
                path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
         | 
| 145 | 
            +
                path = path.view(b, t_x, t_y)
         | 
| 146 | 
            +
                path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
         | 
| 147 | 
            +
                path = path.unsqueeze(1).transpose(2, 3) * mask
         | 
| 148 | 
            +
                return path
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
            def clip_grad_value_(parameters, clip_value, norm_type=2):
         | 
| 152 | 
            +
                if isinstance(parameters, torch.Tensor):
         | 
| 153 | 
            +
                    parameters = [parameters]
         | 
| 154 | 
            +
                parameters = list(filter(lambda p: p.grad is not None, parameters))
         | 
| 155 | 
            +
                norm_type = float(norm_type)
         | 
| 156 | 
            +
                if clip_value is not None:
         | 
| 157 | 
            +
                    clip_value = float(clip_value)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                total_norm = 0
         | 
| 160 | 
            +
                for p in parameters:
         | 
| 161 | 
            +
                    param_norm = p.grad.data.norm(norm_type)
         | 
| 162 | 
            +
                    total_norm += param_norm.item() ** norm_type
         | 
| 163 | 
            +
                    if clip_value is not None:
         | 
| 164 | 
            +
                        p.grad.data.clamp_(min=-clip_value, max=clip_value)
         | 
| 165 | 
            +
                total_norm = total_norm ** (1.0 / norm_type)
         | 
| 166 | 
            +
                return total_norm
         | 
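
Usage sketch (editor's illustration, not part of the committed file): sequence_mask expands per-token durations into cumulative boolean masks, and generate_path differences them into a hard, monotonic frame-to-token alignment, following the shapes in the docstring above. The values below are arbitrary.

import torch
from lib.infer_pack import commons

duration = torch.tensor([[[2.0, 1.0, 3.0]]])            # [b=1, 1, t_x=3]; 6 frames total
x_mask = torch.ones(1, 1, 3)                            # valid token positions
y_mask = torch.ones(1, 1, 6)                            # valid frame positions
attn_mask = y_mask.unsqueeze(-1) * x_mask.unsqueeze(2)  # [1, 1, t_y=6, t_x=3]
path = commons.generate_path(duration, attn_mask)       # one-hot alignment [1, 1, 6, 3]
# frames 0-1 -> token 0, frame 2 -> token 1, frames 3-5 -> token 2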
    	
lib/infer_pack/models.py
ADDED
@@ -0,0 +1,1142 @@
import math, pdb, os
from time import time as ttime
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from lib.infer_pack import modules
from lib.infer_pack import attentions
from lib.infer_pack import commons
from lib.infer_pack.commons import init_weights, get_padding

class TextEncoder256(nn.Module):
    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        if pitch is None:
            x = self.emb_phone(phone)
        else:
            x = self.emb_phone(phone) + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask


class TextEncoder768(nn.Module):
    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(768, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        if pitch is None:
            x = self.emb_phone(phone)
        else:
            x = self.emb_phone(phone) + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask

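Shape sketch (editor's illustration, assuming lib/infer_pack is importable): the two encoders differ only in the expected width of the content features (256 vs. 768); both add an optional coarse-pitch embedding and return the prior statistics m and logs. The layer sizes below are arbitrary.

import torch

enc = TextEncoder256(
    out_channels=192, hidden_channels=192, filter_channels=768,
    n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.0,
)
phone = torch.randn(1, 100, 256)              # [b, t, 256] content features
pitch = torch.randint(0, 256, (1, 100))       # coarse f0 bucket per frame, or None
lengths = torch.tensor([100])
m, logs, x_mask = enc(phone, pitch, lengths)  # m, logs: [1, 192, 100]
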
class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x

    def remove_weight_norm(self):
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()

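Flow sketch (editor's illustration; relies on modules.WN and modules.ResidualCouplingLayer from modules.py): PosteriorEncoder samples the VITS-style posterior z from a spectrogram, and ResidualCouplingBlock maps it to the prior space (forward) or back (reverse=True). Sizes below are arbitrary.

import torch

post = PosteriorEncoder(1025, 192, 192, 5, 1, 16, gin_channels=256)
flow = ResidualCouplingBlock(192, 192, 5, 1, 3, gin_channels=256)
spec = torch.randn(1, 1025, 100)                           # [b, spec_channels, t]
g = torch.randn(1, 256, 1)                                 # speaker conditioning
z, m, logs, y_mask = post(spec, torch.tensor([100]), g=g)  # z: [1, 192, 100]
z_p = flow(z, y_mask, g=g)                                 # to prior space
z_rec = flow(z_p, y_mask, g=g, reverse=True)               # back again
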
class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()

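Sanity-check sketch (editor's illustration): Generator is a HiFi-GAN-style decoder, so the output length is the input length times the product of upsample_rates, and each stage averages num_kernels multi-receptive-field resblocks. The config below is a dummy, not a shipped preset.

import torch

gen = Generator(
    initial_channel=192, resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5]] * 3,
    upsample_rates=[10, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
)
z = torch.randn(1, 192, 50)
wav = gen(z)          # [1, 1, 50 * 10 * 8 * 2 * 2] = [1, 1, 16000]
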
class SineGen(torch.nn.Module):
    """Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(
        self,
        samp_rate,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate a voiced/unvoiced (uv) flag: 1 where f0 exceeds the threshold
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def forward(self, f0, upp):
        """sine_tensor, uv, noise = forward(f0, upp)
        input f0: tensor(batchsize, length); f0 for unvoiced steps should be 0
        upp: upsampling factor from the frame rate to the sample rate
        output sine_tensor: tensor(batchsize, length * upp, dim)
        output uv: tensor(batchsize, length * upp, 1)
        """
        with torch.no_grad():
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in np.arange(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
                    idx + 2
                )  # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
            # per-frame phase increment in cycles; % 1 keeps it in [0, 1)
            rad_values = (f0_buf / self.sampling_rate) % 1
            # random initial phase per harmonic (the fundamental keeps phase 0)
            rand_ini = torch.rand(
                f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
            )
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
            # cumulative phase at the frame rate, upsampled to the sample rate
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            tmp_over_one = F.interpolate(
                tmp_over_one.transpose(2, 1),
                scale_factor=upp,
                mode="linear",
                align_corners=True,
            ).transpose(2, 1)
            rad_values = F.interpolate(
                rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
            ).transpose(2, 1)
            # detect phase wrap-arounds and compensate the cumulative sum
            tmp_over_one %= 1
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
            sine_waves = torch.sin(
                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
            )
            sine_waves = sine_waves * self.sine_amp
            uv = self._f02uv(f0)
            uv = F.interpolate(
                uv.transpose(2, 1), scale_factor=upp, mode="nearest"
            ).transpose(2, 1)
            # voiced regions get small noise; unvoiced regions get sine_amp / 3 noise
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise

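Usage sketch (editor's illustration): SineGen converts a frame-rate f0 contour into a sample-rate harmonic source; upp is the frame-to-sample upsampling factor (the decoder's product of upsample_rates). Values below are arbitrary.

import torch

sine_gen = SineGen(samp_rate=16000, harmonic_num=0)
f0 = torch.full((1, 50), 220.0)           # 50 frames of a steady 220 Hz pitch
sine, uv, noise = sine_gen(f0, upp=320)   # sine: [1, 16000, 1]
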
class SourceModuleHnNSF(torch.nn.Module):
    """SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(
        self,
        sampling_rate,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
        is_half=True,
    ):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.is_half = is_half
        # to produce sine waveforms
        self.l_sin_gen = SineGen(
            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
        )

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upp=None):
        sine_wavs, uv, _ = self.l_sin_gen(x, upp)
        if self.is_half:
            sine_wavs = sine_wavs.half()
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        return sine_merge, None, None  # noise, uv

class GeneratorNSF(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels,
        sr,
        is_half=False,
    ):
        super(GeneratorNSF, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sr, harmonic_num=0, is_half=is_half
        )
        self.noise_convs = nn.ModuleList()
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

        self.upp = np.prod(upsample_rates)

    def forward(self, x, f0, g=None):
        har_source, noi_source, uv = self.m_source(f0, self.upp)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            x_source = self.noise_convs[i](har_source)
            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()


sr2sr = {
    "32k": 32000,
    "40k": 40000,
    "48k": 48000,
}

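Usage sketch (editor's illustration): GeneratorNSF is the decoder above plus a neural-source-filter excitation; SourceModuleHnNSF renders f0 into a sample-rate sine source that each upsampling stage re-injects through a matching noise_convs downsampler, and sr2sr resolves the string sample-rate tags used in configs. The config below is a dummy.

import torch

dec = GeneratorNSF(
    initial_channel=192, resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5]] * 3,
    upsample_rates=[10, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
    gin_channels=256, sr=sr2sr["32k"], is_half=False,
)
z = torch.randn(1, 192, 50)
f0 = torch.full((1, 50), 220.0)           # frame-rate pitch track
g = torch.randn(1, 256, 1)                # speaker embedding
wav = dec(z, f0, g=g)                     # [1, 1, 50 * 320]
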
class SynthesizerTrnMs256NSFsid(nn.Module):
    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr,
        **kwargs
    ):
        super().__init__()
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        # self.hop_length = hop_length
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder256(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = GeneratorNSF(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
            sr=sr,
            is_half=kwargs["is_half"],
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def forward(
        self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
         | 
| 619 | 
            +
                ):  # Here ds is id, [bs,1]
         | 
| 620 | 
            +
                    # print(1,pitch.shape)#[bs,t]
         | 
| 621 | 
            +
                    g = self.emb_g(ds).unsqueeze(-1)  # [b, 256, 1]##1 is t, broadcast
         | 
| 622 | 
            +
                    m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
         | 
| 623 | 
            +
                    z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
         | 
| 624 | 
            +
                    z_p = self.flow(z, y_mask, g=g)
         | 
| 625 | 
            +
                    z_slice, ids_slice = commons.rand_slice_segments(
         | 
| 626 | 
            +
                        z, y_lengths, self.segment_size
         | 
| 627 | 
            +
                    )
         | 
| 628 | 
            +
                    # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
         | 
| 629 | 
            +
                    pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
         | 
| 630 | 
            +
                    # print(-2,pitchf.shape,z_slice.shape)
         | 
| 631 | 
            +
                    o = self.dec(z_slice, pitchf, g=g)
         | 
| 632 | 
            +
                    return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
         | 
| 633 | 
            +
             | 
| 634 | 
            +
                def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
         | 
| 635 | 
            +
                    g = self.emb_g(sid).unsqueeze(-1)
         | 
| 636 | 
            +
                    m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
         | 
| 637 | 
            +
                    z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
         | 
| 638 | 
            +
                    if rate:
         | 
| 639 | 
            +
                        head = int(z_p.shape[2] * rate)
         | 
| 640 | 
            +
                        z_p = z_p[:, :, -head:]
         | 
| 641 | 
            +
                        x_mask = x_mask[:, :, -head:]
         | 
| 642 | 
            +
                        nsff0 = nsff0[:, -head:]
         | 
| 643 | 
            +
                    z = self.flow(z_p, x_mask, g=g, reverse=True)
         | 
| 644 | 
            +
                    o = self.dec(z * x_mask, nsff0, g=g)
         | 
| 645 | 
            +
                    return o, x_mask, (z, z_p, m_p, logs_p)
         | 
| 646 | 
            +
             | 
| 647 | 
            +
             | 
| 648 | 
            +
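A shape smoke test for the f0 synthesizer above; this is a sketch, not part of the diff. It assumes the repo root is on PYTHONPATH so the file imports as lib.infer_pack.models, and the hyperparameters are illustrative values loosely modeled on a common 40k v1 configuration rather than authoritative ones.

import torch
from lib.infer_pack.models import SynthesizerTrnMs256NSFsid

net_g = SynthesizerTrnMs256NSFsid(
    spec_channels=1025,
    segment_size=32,
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[10, 10, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
    spk_embed_dim=109,
    gin_channels=256,
    sr="40k",
    is_half=False,
).eval()

t = 100                                # number of content frames
phone = torch.randn(1, t, 256)         # HuBERT-style features, [b, t, 256]
phone_lengths = torch.LongTensor([t])
pitch = torch.randint(1, 255, (1, t))  # coarse f0 bins for the pitch embedding
nsff0 = torch.full((1, t), 220.0)      # f0 curve in Hz for the NSF decoder
sid = torch.LongTensor([0])            # speaker id
with torch.no_grad():
    audio, x_mask, _ = net_g.infer(phone, phone_lengths, pitch, nsff0, sid)
print(audio.shape)                     # [1, 1, t * prod(upsample_rates)] = [1, 1, 40000]
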
class SynthesizerTrnMs768NSFsid(nn.Module):
    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr,
        **kwargs
    ):
        super().__init__()
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder768(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = GeneratorNSF(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
            sr=sr,
            is_half=kwargs["is_half"],
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def forward(
        self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
    ):  # ds is the speaker id, shape [bs, 1]
        g = self.emb_g(ds).unsqueeze(-1)  # [b, gin_channels, 1]; last dim broadcasts over t
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
        o = self.dec(z_slice, pitchf, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate:
            head = int(z_p.shape[2] * rate)
            z_p = z_p[:, :, -head:]
            x_mask = x_mask[:, :, -head:]
            nsff0 = nsff0[:, -head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(z * x_mask, nsff0, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)

class SynthesizerTrnMs256NSFsid_nono(nn.Module):
    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr=None,
        **kwargs
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder256(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            f0=False,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def forward(self, phone, phone_lengths, y, y_lengths, ds):  # ds is the speaker id, shape [bs, 1]
        g = self.emb_g(ds).unsqueeze(-1)  # [b, gin_channels, 1]; last dim broadcasts over t
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, phone, phone_lengths, sid, rate=None):
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate:
            head = int(z_p.shape[2] * rate)
            z_p = z_p[:, :, -head:]
            x_mask = x_mask[:, :, -head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(z * x_mask, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)

class SynthesizerTrnMs768NSFsid_nono(nn.Module):
    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr=None,
        **kwargs
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder768(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            f0=False,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def forward(self, phone, phone_lengths, y, y_lengths, ds):  # ds is the speaker id, shape [bs, 1]
        g = self.emb_g(ds).unsqueeze(-1)  # [b, gin_channels, 1]; last dim broadcasts over t
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, phone, phone_lengths, sid, rate=None):
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate:
            head = int(z_p.shape[2] * rate)
            z_p = z_p[:, :, -head:]
            x_mask = x_mask[:, :, -head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec(z * x_mask, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)

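The four synthesizer variants above differ only in the width of the content features they expect (256-dim for "v1" models, 768-dim for "v2") and in whether they take pitch conditioning (the *_nono classes drop f0 and decode with the plain Generator). A small dispatch sketch follows; it is not part of the diff, the helper name is hypothetical, and the checkpoint keys in the usage comment ("version", "f0", "config") are assumptions about the RVC-style checkpoint layout rather than something this file guarantees.

from lib.infer_pack.models import (
    SynthesizerTrnMs256NSFsid,
    SynthesizerTrnMs256NSFsid_nono,
    SynthesizerTrnMs768NSFsid,
    SynthesizerTrnMs768NSFsid_nono,
)

def pick_synth_class(version, if_f0):
    # hypothetical helper: map (feature version, f0 flag) to the matching class
    table = {
        ("v1", 1): SynthesizerTrnMs256NSFsid,
        ("v1", 0): SynthesizerTrnMs256NSFsid_nono,
        ("v2", 1): SynthesizerTrnMs768NSFsid,
        ("v2", 0): SynthesizerTrnMs768NSFsid_nono,
    }
    return table[(version, if_f0)]

# possible usage, assuming an RVC-style checkpoint dict `cpt`:
#   cls = pick_synth_class(cpt.get("version", "v1"), cpt.get("f0", 1))
#   net_g = cls(*cpt["config"], is_half=False)
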
class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11, 17]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

class MultiPeriodDiscriminatorV2(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminatorV2, self).__init__()
        periods = [2, 3, 5, 7, 11, 17, 23, 37]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

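For training, the tuples returned by these wrappers feed the usual HiFi-GAN/VITS least-squares adversarial losses. The functions below restate that standard formulation for illustration only, simplified to return just the scalar loss; they are not defined in this file.

import torch

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    # real outputs are pushed toward 1, generated outputs toward 0
    loss = 0.0
    for d_r, d_g in zip(disc_real_outputs, disc_generated_outputs):
        loss = loss + torch.mean((1 - d_r.float()) ** 2) + torch.mean(d_g.float() ** 2)
    return loss

def generator_loss(disc_generated_outputs):
    # the generator tries to push its outputs toward 1
    loss = 0.0
    for d_g in disc_generated_outputs:
        loss = loss + torch.mean((1 - d_g.float()) ** 2)
    return loss

# possible usage:
#   y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
#   d_loss = discriminator_loss(y_d_rs, y_d_gs)
#   g_loss = generator_loss(y_d_gs)
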
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        convs = [
            norm_f(
                Conv2d(
                    in_ch,
                    out_ch,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
            for in_ch, out_ch in [(1, 32), (32, 128), (128, 512), (512, 1024)]
        ]
        convs.append(
            norm_f(
                Conv2d(
                    1024,
                    1024,
                    (kernel_size, 1),
                    1,
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
        )
        self.convs = nn.ModuleList(convs)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d: pad to a multiple of period, then fold
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
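DiscriminatorP folds the 1d waveform into a 2d grid before convolving, so each kernel column sees samples spaced `period` apart. The standalone snippet below replays that reshape on a toy tensor to make the padding arithmetic concrete; it depends only on torch and is illustrative, not part of the diff.

import torch
import torch.nn.functional as F

period = 5
x = torch.randn(1, 1, 1003)              # [b, c, t]; 1003 is not divisible by 5
b, c, t = x.shape
n_pad = period - (t % period)            # 2 reflect-padded samples
x = F.pad(x, (0, n_pad), "reflect")
x2d = x.view(b, c, (t + n_pad) // period, period)
print(x2d.shape)                         # torch.Size([1, 1, 201, 5])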
    	
        lib/infer_pack/modules.py
    ADDED
    
import copy
import math
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import functional as F

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm

from lib.infer_pack import commons
from lib.infer_pack.commons import init_weights, get_padding
from lib.infer_pack.transforms import piecewise_rational_quadratic_transform


LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask

             | 
| 136 | 
            +
            class WN(torch.nn.Module):
         | 
| 137 | 
            +
                def __init__(
         | 
| 138 | 
            +
                    self,
         | 
| 139 | 
            +
                    hidden_channels,
         | 
| 140 | 
            +
                    kernel_size,
         | 
| 141 | 
            +
                    dilation_rate,
         | 
| 142 | 
            +
                    n_layers,
         | 
| 143 | 
            +
                    gin_channels=0,
         | 
| 144 | 
            +
                    p_dropout=0,
         | 
| 145 | 
            +
                ):
         | 
| 146 | 
            +
                    super(WN, self).__init__()
         | 
| 147 | 
            +
                    assert kernel_size % 2 == 1
         | 
| 148 | 
            +
                    self.hidden_channels = hidden_channels
         | 
| 149 | 
            +
                    self.kernel_size = (kernel_size,)
         | 
| 150 | 
            +
                    self.dilation_rate = dilation_rate
         | 
| 151 | 
            +
                    self.n_layers = n_layers
         | 
| 152 | 
            +
                    self.gin_channels = gin_channels
         | 
| 153 | 
            +
                    self.p_dropout = p_dropout
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    self.in_layers = torch.nn.ModuleList()
         | 
| 156 | 
            +
                    self.res_skip_layers = torch.nn.ModuleList()
         | 
| 157 | 
            +
                    self.drop = nn.Dropout(p_dropout)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    if gin_channels != 0:
         | 
| 160 | 
            +
                        cond_layer = torch.nn.Conv1d(
         | 
| 161 | 
            +
                            gin_channels, 2 * hidden_channels * n_layers, 1
         | 
| 162 | 
            +
                        )
         | 
| 163 | 
            +
                        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    for i in range(n_layers):
         | 
| 166 | 
            +
                        dilation = dilation_rate**i
         | 
| 167 | 
            +
                        padding = int((kernel_size * dilation - dilation) / 2)
         | 
| 168 | 
            +
                        in_layer = torch.nn.Conv1d(
         | 
| 169 | 
            +
                            hidden_channels,
         | 
| 170 | 
            +
                            2 * hidden_channels,
         | 
| 171 | 
            +
                            kernel_size,
         | 
| 172 | 
            +
                            dilation=dilation,
         | 
| 173 | 
            +
                            padding=padding,
         | 
| 174 | 
            +
                        )
         | 
| 175 | 
            +
                        in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
         | 
| 176 | 
            +
                        self.in_layers.append(in_layer)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                        # last one is not necessary
         | 
| 179 | 
            +
                        if i < n_layers - 1:
         | 
| 180 | 
            +
                            res_skip_channels = 2 * hidden_channels
         | 
| 181 | 
            +
                        else:
         | 
| 182 | 
            +
                            res_skip_channels = hidden_channels
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                        res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
         | 
| 185 | 
            +
                        res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
         | 
| 186 | 
            +
                        self.res_skip_layers.append(res_skip_layer)
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                def forward(self, x, x_mask, g=None, **kwargs):
         | 
| 189 | 
            +
                    output = torch.zeros_like(x)
         | 
| 190 | 
            +
                    n_channels_tensor = torch.IntTensor([self.hidden_channels])
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    if g is not None:
         | 
| 193 | 
            +
                        g = self.cond_layer(g)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    for i in range(self.n_layers):
         | 
| 196 | 
            +
                        x_in = self.in_layers[i](x)
         | 
| 197 | 
            +
                        if g is not None:
         | 
| 198 | 
            +
                            cond_offset = i * 2 * self.hidden_channels
         | 
| 199 | 
            +
                            g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
         | 
| 200 | 
            +
                        else:
         | 
| 201 | 
            +
                            g_l = torch.zeros_like(x_in)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                        acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
         | 
| 204 | 
            +
                        acts = self.drop(acts)
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                        res_skip_acts = self.res_skip_layers[i](acts)
         | 
| 207 | 
            +
                        if i < self.n_layers - 1:
         | 
| 208 | 
            +
                            res_acts = res_skip_acts[:, : self.hidden_channels, :]
         | 
| 209 | 
            +
                            x = (x + res_acts) * x_mask
         | 
| 210 | 
            +
                            output = output + res_skip_acts[:, self.hidden_channels :, :]
         | 
| 211 | 
            +
                        else:
         | 
| 212 | 
            +
                            output = output + res_skip_acts
         | 
| 213 | 
            +
                    return output * x_mask
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                def remove_weight_norm(self):
         | 
| 216 | 
            +
                    if self.gin_channels != 0:
         | 
| 217 | 
            +
                        torch.nn.utils.remove_weight_norm(self.cond_layer)
         | 
| 218 | 
            +
                    for l in self.in_layers:
         | 
| 219 | 
            +
                        torch.nn.utils.remove_weight_norm(l)
         | 
| 220 | 
            +
                    for l in self.res_skip_layers:
         | 
| 221 | 
            +
                        torch.nn.utils.remove_weight_norm(l)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
             | 
| 224 | 
            +
            class ResBlock1(torch.nn.Module):
         | 
| 225 | 
            +
                def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
         | 
| 226 | 
            +
                    super(ResBlock1, self).__init__()
         | 
| 227 | 
            +
                    self.convs1 = nn.ModuleList(
         | 
| 228 | 
            +
                        [
         | 
| 229 | 
            +
                            weight_norm(
         | 
| 230 | 
            +
                                Conv1d(
         | 
| 231 | 
            +
                                    channels,
         | 
| 232 | 
            +
                                    channels,
         | 
| 233 | 
            +
                                    kernel_size,
         | 
| 234 | 
            +
                                    1,
         | 
| 235 | 
            +
                                    dilation=dilation[0],
         | 
| 236 | 
            +
                                    padding=get_padding(kernel_size, dilation[0]),
         | 
| 237 | 
            +
                                )
         | 
| 238 | 
            +
                            ),
         | 
| 239 | 
            +
                            weight_norm(
         | 
| 240 | 
            +
                                Conv1d(
         | 
| 241 | 
            +
                                    channels,
         | 
| 242 | 
            +
                                    channels,
         | 
| 243 | 
            +
                                    kernel_size,
         | 
| 244 | 
            +
                                    1,
         | 
| 245 | 
            +
                                    dilation=dilation[1],
         | 
| 246 | 
            +
                                    padding=get_padding(kernel_size, dilation[1]),
         | 
| 247 | 
            +
                                )
         | 
| 248 | 
            +
                            ),
         | 
| 249 | 
            +
                            weight_norm(
         | 
| 250 | 
            +
                                Conv1d(
         | 
| 251 | 
            +
                                    channels,
         | 
| 252 | 
            +
                                    channels,
         | 
| 253 | 
            +
                                    kernel_size,
         | 
| 254 | 
            +
                                    1,
         | 
| 255 | 
            +
                                    dilation=dilation[2],
         | 
| 256 | 
            +
                                    padding=get_padding(kernel_size, dilation[2]),
         | 
| 257 | 
            +
                                )
         | 
| 258 | 
            +
                            ),
         | 
| 259 | 
            +
                        ]
         | 
| 260 | 
            +
                    )
         | 
| 261 | 
            +
                    self.convs1.apply(init_weights)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    self.convs2 = nn.ModuleList(
         | 
| 264 | 
            +
                        [
         | 
| 265 | 
            +
                            weight_norm(
         | 
| 266 | 
            +
                                Conv1d(
         | 
| 267 | 
            +
                                    channels,
         | 
| 268 | 
            +
                                    channels,
         | 
| 269 | 
            +
                                    kernel_size,
         | 
| 270 | 
            +
                                    1,
         | 
| 271 | 
            +
                                    dilation=1,
         | 
| 272 | 
            +
                                    padding=get_padding(kernel_size, 1),
         | 
| 273 | 
            +
                                )
         | 
| 274 | 
            +
                            ),
         | 
| 275 | 
            +
                            weight_norm(
         | 
| 276 | 
            +
                                Conv1d(
         | 
| 277 | 
            +
                                    channels,
         | 
| 278 | 
            +
                                    channels,
         | 
| 279 | 
            +
                                    kernel_size,
         | 
| 280 | 
            +
                                    1,
         | 
| 281 | 
            +
                                    dilation=1,
         | 
| 282 | 
            +
                                    padding=get_padding(kernel_size, 1),
         | 
| 283 | 
            +
                                )
         | 
| 284 | 
            +
                            ),
         | 
| 285 | 
            +
                            weight_norm(
         | 
| 286 | 
            +
                                Conv1d(
         | 
| 287 | 
            +
                                    channels,
         | 
| 288 | 
            +
                                    channels,
         | 
| 289 | 
            +
                                    kernel_size,
         | 
| 290 | 
            +
                                    1,
         | 
| 291 | 
            +
                                    dilation=1,
         | 
| 292 | 
            +
                                    padding=get_padding(kernel_size, 1),
         | 
| 293 | 
            +
                                )
         | 
| 294 | 
            +
                            ),
         | 
| 295 | 
            +
                        ]
         | 
| 296 | 
            +
                    )
         | 
| 297 | 
            +
                    self.convs2.apply(init_weights)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                def forward(self, x, x_mask=None):
         | 
| 300 | 
            +
                    for c1, c2 in zip(self.convs1, self.convs2):
         | 
| 301 | 
            +
                        xt = F.leaky_relu(x, LRELU_SLOPE)
         | 
| 302 | 
            +
                        if x_mask is not None:
         | 
| 303 | 
            +
                            xt = xt * x_mask
         | 
| 304 | 
            +
                        xt = c1(xt)
         | 
| 305 | 
            +
                        xt = F.leaky_relu(xt, LRELU_SLOPE)
         | 
| 306 | 
            +
                        if x_mask is not None:
         | 
| 307 | 
            +
                            xt = xt * x_mask
         | 
| 308 | 
            +
                        xt = c2(xt)
         | 
| 309 | 
            +
                        x = xt + x
         | 
| 310 | 
            +
                    if x_mask is not None:
         | 
| 311 | 
            +
                        x = x * x_mask
         | 
| 312 | 
            +
                    return x
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                def remove_weight_norm(self):
         | 
| 315 | 
            +
                    for l in self.convs1:
         | 
| 316 | 
            +
                        remove_weight_norm(l)
         | 
| 317 | 
            +
                    for l in self.convs2:
         | 
| 318 | 
            +
                        remove_weight_norm(l)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
             | 
| 321 | 
            +
            class ResBlock2(torch.nn.Module):
         | 
| 322 | 
            +
                def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
         | 
| 323 | 
            +
                    super(ResBlock2, self).__init__()
         | 
| 324 | 
            +
                    self.convs = nn.ModuleList(
         | 
| 325 | 
            +
                        [
         | 
| 326 | 
            +
                            weight_norm(
         | 
| 327 | 
            +
                                Conv1d(
         | 
| 328 | 
            +
                                    channels,
         | 
| 329 | 
            +
                                    channels,
         | 
| 330 | 
            +
                                    kernel_size,
         | 
| 331 | 
            +
                                    1,
         | 
| 332 | 
            +
                                    dilation=dilation[0],
         | 
| 333 | 
            +
                                    padding=get_padding(kernel_size, dilation[0]),
         | 
| 334 | 
            +
                                )
         | 
| 335 | 
            +
                            ),
         | 
| 336 | 
            +
                            weight_norm(
         | 
| 337 | 
            +
                                Conv1d(
         | 
| 338 | 
            +
                                    channels,
         | 
| 339 | 
            +
                                    channels,
         | 
| 340 | 
            +
                                    kernel_size,
         | 
| 341 | 
            +
                                    1,
         | 
| 342 | 
            +
                                    dilation=dilation[1],
         | 
| 343 | 
            +
                                    padding=get_padding(kernel_size, dilation[1]),
         | 
| 344 | 
            +
                                )
         | 
| 345 | 
            +
                            ),
         | 
| 346 | 
            +
                        ]
         | 
| 347 | 
            +
                    )
         | 
| 348 | 
            +
                    self.convs.apply(init_weights)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                def forward(self, x, x_mask=None):
         | 
| 351 | 
            +
                    for c in self.convs:
         | 
| 352 | 
            +
                        xt = F.leaky_relu(x, LRELU_SLOPE)
         | 
| 353 | 
            +
                        if x_mask is not None:
         | 
| 354 | 
            +
                            xt = xt * x_mask
         | 
| 355 | 
            +
                        xt = c(xt)
         | 
| 356 | 
            +
                        x = xt + x
         | 
| 357 | 
            +
                    if x_mask is not None:
         | 
| 358 | 
            +
                        x = x * x_mask
         | 
| 359 | 
            +
                    return x
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                def remove_weight_norm(self):
         | 
| 362 | 
            +
                    for l in self.convs:
         | 
| 363 | 
            +
                        remove_weight_norm(l)
         | 
| 364 | 
            +
             | 
| 365 | 
            +
             | 
| 366 | 
            +
            class Log(nn.Module):
         | 
| 367 | 
            +
                def forward(self, x, x_mask, reverse=False, **kwargs):
         | 
| 368 | 
            +
                    if not reverse:
         | 
| 369 | 
            +
                        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
         | 
| 370 | 
            +
                        logdet = torch.sum(-y, [1, 2])
         | 
| 371 | 
            +
                        return y, logdet
         | 
| 372 | 
            +
                    else:
         | 
| 373 | 
            +
                        x = torch.exp(x) * x_mask
         | 
| 374 | 
            +
                        return x
         | 
| 375 | 
            +
             | 
| 376 | 
            +
             | 
| 377 | 
            +
            class Flip(nn.Module):
         | 
| 378 | 
            +
                def forward(self, x, *args, reverse=False, **kwargs):
         | 
| 379 | 
            +
                    x = torch.flip(x, [1])
         | 
| 380 | 
            +
                    if not reverse:
         | 
| 381 | 
            +
                        logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
         | 
| 382 | 
            +
                        return x, logdet
         | 
| 383 | 
            +
                    else:
         | 
| 384 | 
            +
                        return x
         | 
| 385 | 
            +
             | 
| 386 | 
            +
             | 
| 387 | 
            +
            class ElementwiseAffine(nn.Module):
         | 
| 388 | 
            +
                def __init__(self, channels):
         | 
| 389 | 
            +
                    super().__init__()
         | 
| 390 | 
            +
                    self.channels = channels
         | 
| 391 | 
            +
                    self.m = nn.Parameter(torch.zeros(channels, 1))
         | 
| 392 | 
            +
                    self.logs = nn.Parameter(torch.zeros(channels, 1))
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                def forward(self, x, x_mask, reverse=False, **kwargs):
         | 
| 395 | 
            +
                    if not reverse:
         | 
| 396 | 
            +
                        y = self.m + torch.exp(self.logs) * x
         | 
| 397 | 
            +
                        y = y * x_mask
         | 
| 398 | 
            +
                        logdet = torch.sum(self.logs * x_mask, [1, 2])
         | 
| 399 | 
            +
                        return y, logdet
         | 
| 400 | 
            +
                    else:
         | 
| 401 | 
            +
                        x = (x - self.m) * torch.exp(-self.logs) * x_mask
         | 
| 402 | 
            +
                        return x
         | 
| 403 | 
            +
             | 
| 404 | 
            +
             | 
| 405 | 
            +
            class ResidualCouplingLayer(nn.Module):
         | 
| 406 | 
            +
                def __init__(
         | 
| 407 | 
            +
                    self,
         | 
| 408 | 
            +
                    channels,
         | 
| 409 | 
            +
                    hidden_channels,
         | 
| 410 | 
            +
                    kernel_size,
         | 
| 411 | 
            +
                    dilation_rate,
         | 
| 412 | 
            +
                    n_layers,
         | 
| 413 | 
            +
                    p_dropout=0,
         | 
| 414 | 
            +
                    gin_channels=0,
         | 
| 415 | 
            +
                    mean_only=False,
         | 
| 416 | 
            +
                ):
         | 
| 417 | 
            +
                    assert channels % 2 == 0, "channels should be divisible by 2"
         | 
| 418 | 
            +
                    super().__init__()
         | 
| 419 | 
            +
                    self.channels = channels
         | 
| 420 | 
            +
                    self.hidden_channels = hidden_channels
         | 
| 421 | 
            +
                    self.kernel_size = kernel_size
         | 
| 422 | 
            +
                    self.dilation_rate = dilation_rate
         | 
| 423 | 
            +
                    self.n_layers = n_layers
         | 
| 424 | 
            +
                    self.half_channels = channels // 2
         | 
| 425 | 
            +
                    self.mean_only = mean_only
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                    self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
         | 
| 428 | 
            +
                    self.enc = WN(
         | 
| 429 | 
            +
                        hidden_channels,
         | 
| 430 | 
            +
                        kernel_size,
         | 
| 431 | 
            +
                        dilation_rate,
         | 
| 432 | 
            +
                        n_layers,
         | 
| 433 | 
            +
                        p_dropout=p_dropout,
         | 
| 434 | 
            +
                        gin_channels=gin_channels,
         | 
| 435 | 
            +
                    )
         | 
| 436 | 
            +
                    self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
         | 
| 437 | 
            +
                    self.post.weight.data.zero_()
         | 
| 438 | 
            +
                    self.post.bias.data.zero_()
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                def forward(self, x, x_mask, g=None, reverse=False):
         | 
| 441 | 
            +
                    x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
         | 
| 442 | 
            +
                    h = self.pre(x0) * x_mask
         | 
| 443 | 
            +
                    h = self.enc(h, x_mask, g=g)
         | 
| 444 | 
            +
                    stats = self.post(h) * x_mask
         | 
| 445 | 
            +
                    if not self.mean_only:
         | 
| 446 | 
            +
                        m, logs = torch.split(stats, [self.half_channels] * 2, 1)
         | 
| 447 | 
            +
                    else:
         | 
| 448 | 
            +
                        m = stats
         | 
| 449 | 
            +
                        logs = torch.zeros_like(m)
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    if not reverse:
         | 
| 452 | 
            +
                        x1 = m + x1 * torch.exp(logs) * x_mask
         | 
| 453 | 
            +
                        x = torch.cat([x0, x1], 1)
         | 
| 454 | 
            +
                        logdet = torch.sum(logs, [1, 2])
         | 
| 455 | 
            +
                        return x, logdet
         | 
| 456 | 
            +
                    else:
         | 
| 457 | 
            +
                        x1 = (x1 - m) * torch.exp(-logs) * x_mask
         | 
| 458 | 
            +
                        x = torch.cat([x0, x1], 1)
         | 
| 459 | 
            +
                        return x
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                def remove_weight_norm(self):
         | 
| 462 | 
            +
                    self.enc.remove_weight_norm()
         | 
| 463 | 
            +
             | 
| 464 | 
            +
             | 
| 465 | 
            +
            class ConvFlow(nn.Module):
         | 
| 466 | 
            +
                def __init__(
         | 
| 467 | 
            +
                    self,
         | 
| 468 | 
            +
                    in_channels,
         | 
| 469 | 
            +
                    filter_channels,
         | 
| 470 | 
            +
                    kernel_size,
         | 
| 471 | 
            +
                    n_layers,
         | 
| 472 | 
            +
                    num_bins=10,
         | 
| 473 | 
            +
                    tail_bound=5.0,
         | 
| 474 | 
            +
                ):
         | 
| 475 | 
            +
                    super().__init__()
         | 
| 476 | 
            +
                    self.in_channels = in_channels
         | 
| 477 | 
            +
                    self.filter_channels = filter_channels
         | 
| 478 | 
            +
                    self.kernel_size = kernel_size
         | 
| 479 | 
            +
                    self.n_layers = n_layers
         | 
| 480 | 
            +
                    self.num_bins = num_bins
         | 
| 481 | 
            +
                    self.tail_bound = tail_bound
         | 
| 482 | 
            +
                    self.half_channels = in_channels // 2
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                    self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
         | 
| 485 | 
            +
                    self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
         | 
| 486 | 
            +
                    self.proj = nn.Conv1d(
         | 
| 487 | 
            +
                        filter_channels, self.half_channels * (num_bins * 3 - 1), 1
         | 
| 488 | 
            +
                    )
         | 
| 489 | 
            +
                    self.proj.weight.data.zero_()
         | 
| 490 | 
            +
                    self.proj.bias.data.zero_()
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                def forward(self, x, x_mask, g=None, reverse=False):
         | 
| 493 | 
            +
                    x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
         | 
| 494 | 
            +
                    h = self.pre(x0)
         | 
| 495 | 
            +
                    h = self.convs(h, x_mask, g=g)
         | 
| 496 | 
            +
                    h = self.proj(h) * x_mask
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                    b, c, t = x0.shape
         | 
| 499 | 
            +
                    h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                    unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
         | 
| 502 | 
            +
                    unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
         | 
| 503 | 
            +
                        self.filter_channels
         | 
| 504 | 
            +
                    )
         | 
| 505 | 
            +
                    unnormalized_derivatives = h[..., 2 * self.num_bins :]
         | 
| 506 | 
            +
             | 
| 507 | 
            +
                    x1, logabsdet = piecewise_rational_quadratic_transform(
         | 
| 508 | 
            +
                        x1,
         | 
| 509 | 
            +
                        unnormalized_widths,
         | 
| 510 | 
            +
                        unnormalized_heights,
         | 
| 511 | 
            +
                        unnormalized_derivatives,
         | 
| 512 | 
            +
                        inverse=reverse,
         | 
| 513 | 
            +
                        tails="linear",
         | 
| 514 | 
            +
                        tail_bound=self.tail_bound,
         | 
| 515 | 
            +
                    )
         | 
| 516 | 
            +
             | 
| 517 | 
            +
                    x = torch.cat([x0, x1], 1) * x_mask
         | 
| 518 | 
            +
                    logdet = torch.sum(logabsdet * x_mask, [1, 2])
         | 
| 519 | 
            +
                    if not reverse:
         | 
| 520 | 
            +
                        return x, logdet
         | 
| 521 | 
            +
                    else:
         | 
| 522 | 
            +
                        return x
         | 
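The flow layers above are invertible by construction: forward mode returns (output, log-determinant), while reverse=True maps the output back to the input. A minimal smoke test of that property (not part of this commit; class names as defined above, tensor shapes [batch, channels, time] chosen purely for illustration):

    import torch
    from lib.infer_pack.modules import ResidualCouplingLayer

    layer = ResidualCouplingLayer(
        channels=4, hidden_channels=8, kernel_size=5, dilation_rate=1, n_layers=2
    )
    x = torch.randn(3, 4, 20)
    x_mask = torch.ones(3, 1, 20)               # all frames valid
    y, logdet = layer(x, x_mask)                # forward: (output, log-determinant)
    x_rec = layer(y, x_mask, reverse=True)      # reverse: output only
    print(torch.allclose(x, x_rec, atol=1e-5))  # expected: True

Note that the reverse pass is exact because the (m, logs) statistics depend only on x0, which the coupling layer leaves unchanged; the default p_dropout=0 keeps the mapping deterministic.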
    	
lib/infer_pack/transforms.py
ADDED
@@ -0,0 +1,209 @@
import torch
from torch.nn import functional as F

import numpy as np


DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if tails is None:
        spline_fn = rational_quadratic_spline
        spline_kwargs = {}
    else:
        spline_fn = unconstrained_rational_quadratic_spline
        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **spline_kwargs
    )
    return outputs, logabsdet


def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet


def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
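A hypothetical round-trip check of the spline above (not part of this commit). The shapes mirror what ConvFlow feeds in: num_bins widths, num_bins heights, and num_bins - 1 interior derivatives per element; with tails="linear", values outside ±tail_bound pass through unchanged in both directions.

    import torch
    from lib.infer_pack.transforms import piecewise_rational_quadratic_transform

    num_bins = 10
    x = torch.randn(4, 100)
    uw = torch.randn(4, 100, num_bins)        # unnormalized bin widths
    uh = torch.randn(4, 100, num_bins)        # unnormalized bin heights
    ud = torch.randn(4, 100, num_bins - 1)    # unnormalized knot derivatives (padded internally)
    y, logabsdet = piecewise_rational_quadratic_transform(
        x, uw, uh, ud, inverse=False, tails="linear", tail_bound=5.0
    )
    x_rec, _ = piecewise_rational_quadratic_transform(
        y, uw, uh, ud, inverse=True, tails="linear", tail_bound=5.0
    )
    print(torch.allclose(x, x_rec, atol=1e-4))  # expected: True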
    	
lib/rmvpe.py
ADDED
@@ -0,0 +1,422 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa.filters import mel


class BiGRU(nn.Module):
    """Bidirectional GRU that returns only the output sequence."""

    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        return self.gru(x)[0]


class ConvBlockRes(nn.Module):
    """Two 3x3 conv + BN + ReLU layers with a residual connection."""

    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        if in_channels != out_channels:
            # 1x1 conv to match channel counts on the residual path
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
            self.is_shortcut = True
        else:
            self.is_shortcut = False

    def forward(self, x):
        if self.is_shortcut:
            return self.conv(x) + self.shortcut(x)
        else:
            return self.conv(x) + x

            +
            class Encoder(nn.Module):
         | 
| 61 | 
            +
                def __init__(
         | 
| 62 | 
            +
                    self,
         | 
| 63 | 
            +
                    in_channels,
         | 
| 64 | 
            +
                    in_size,
         | 
| 65 | 
            +
                    n_encoders,
         | 
| 66 | 
            +
                    kernel_size,
         | 
| 67 | 
            +
                    n_blocks,
         | 
| 68 | 
            +
                    out_channels=16,
         | 
| 69 | 
            +
                    momentum=0.01,
         | 
| 70 | 
            +
                ):
         | 
| 71 | 
            +
                    super(Encoder, self).__init__()
         | 
| 72 | 
            +
                    self.n_encoders = n_encoders
         | 
| 73 | 
            +
                    self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
         | 
| 74 | 
            +
                    self.layers = nn.ModuleList()
         | 
| 75 | 
            +
                    self.latent_channels = []
         | 
| 76 | 
            +
                    for i in range(self.n_encoders):
         | 
| 77 | 
            +
                        self.layers.append(
         | 
| 78 | 
            +
                            ResEncoderBlock(
         | 
| 79 | 
            +
                                in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
         | 
| 80 | 
            +
                            )
         | 
| 81 | 
            +
                        )
         | 
| 82 | 
            +
                        self.latent_channels.append([out_channels, in_size])
         | 
| 83 | 
            +
                        in_channels = out_channels
         | 
| 84 | 
            +
                        out_channels *= 2
         | 
| 85 | 
            +
                        in_size //= 2
         | 
| 86 | 
            +
                    self.out_size = in_size
         | 
| 87 | 
            +
                    self.out_channel = out_channels
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def forward(self, x):
         | 
| 90 | 
            +
                    concat_tensors = []
         | 
| 91 | 
            +
                    x = self.bn(x)
         | 
| 92 | 
            +
                    for i in range(self.n_encoders):
         | 
| 93 | 
            +
                        _, x = self.layers[i](x)
         | 
| 94 | 
            +
                        concat_tensors.append(_)
         | 
| 95 | 
            +
                    return x, concat_tensors
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            class ResEncoderBlock(nn.Module):
         | 
| 99 | 
            +
                def __init__(
         | 
| 100 | 
            +
                    self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
         | 
| 101 | 
            +
                ):
         | 
| 102 | 
            +
                    super(ResEncoderBlock, self).__init__()
         | 
| 103 | 
            +
                    self.n_blocks = n_blocks
         | 
| 104 | 
            +
                    self.conv = nn.ModuleList()
         | 
| 105 | 
            +
                    self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
         | 
| 106 | 
            +
                    for i in range(n_blocks - 1):
         | 
| 107 | 
            +
                        self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
         | 
| 108 | 
            +
                    self.kernel_size = kernel_size
         | 
| 109 | 
            +
                    if self.kernel_size is not None:
         | 
| 110 | 
            +
                        self.pool = nn.AvgPool2d(kernel_size=kernel_size)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                def forward(self, x):
         | 
| 113 | 
            +
                    for i in range(self.n_blocks):
         | 
| 114 | 
            +
                        x = self.conv[i](x)
         | 
| 115 | 
            +
                    if self.kernel_size is not None:
         | 
| 116 | 
            +
                        return x, self.pool(x)
         | 
| 117 | 
            +
                    else:
         | 
| 118 | 
            +
                        return x
         | 
| 119 | 
            +
             | 
| 120 | 
            +
             | 
class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.n_inters = n_inters
        self.layers = nn.ModuleList()
        # kernel_size=None disables pooling, so the spatial resolution is kept
        self.layers.append(
            ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
        )
        for i in range(self.n_inters - 1):
            self.layers.append(
                ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
            )

    def forward(self, x):
        for i in range(self.n_inters):
            x = self.layers[i](x)
        return x


class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.n_blocks = n_blocks
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=stride,
                padding=(1, 1),
                output_padding=out_padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.conv2 = nn.ModuleList()
        # the first block sees upsampled features concatenated with the skip tensor
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        x = torch.cat((x, concat_tensor), dim=1)
        for i in range(self.n_blocks):
            x = self.conv2[i](x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        for i in range(self.n_decoders):
            out_channels = in_channels // 2
            self.layers.append(
                ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
            )
            in_channels = out_channels

    def forward(self, x, concat_tensors):
        for i in range(self.n_decoders):
            # consume encoder skip tensors in reverse order (deepest first)
            x = self.layers[i](x, concat_tensors[-1 - i])
        return x

class DeepUnet(nn.Module):
    def __init__(
        self,
        kernel_size,
        n_blocks,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(
            self.encoder.out_channel // 2,
            self.encoder.out_channel,
            inter_layers,
            n_blocks,
        )
        # the encoder's pooling kernel doubles as the decoder's upsampling stride
        self.decoder = Decoder(
            self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
        )

    def forward(self, x):
        x, concat_tensors = self.encoder(x)
        x = self.intermediate(x)
        x = self.decoder(x, concat_tensors)
        return x


class E2E(nn.Module):
    def __init__(
        self,
        n_blocks,
        n_gru,
        kernel_size,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(E2E, self).__init__()
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * 128, 256, n_gru),
                nn.Linear(512, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            # Fixed: the original referenced nn.N_MELS and nn.N_CLASS, which do
            # not exist in torch.nn and would crash this branch; 128 mel bins
            # and 360 pitch classes match the GRU branch above.
            self.fc = nn.Sequential(
                nn.Linear(3 * 128, 360), nn.Dropout(0.25), nn.Sigmoid()
            )

    def forward(self, mel):
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x

class MelSpectrogram(torch.nn.Module):
    def __init__(
        self,
        is_half,
        n_mel_channels,
        sampling_rate,
        win_length,
        hop_length,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp=1e-5,
    ):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft  # already resolved above; the original re-derived it redundantly
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        # a key shift in semitones scales the FFT/window sizes; speed scales the hop
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))
        # cache one Hann window per (keyshift, device) pair
        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                audio.device
            )
        fft = torch.stft(
            audio,
            n_fft=n_fft_new,
            hop_length=hop_length_new,
            win_length=win_length_new,
            window=self.hann_window[keyshift_key],
            center=center,
            return_complex=True,
        )
        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
        if keyshift != 0:
            # resize the shifted spectrogram back to the nominal bin count
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec

class RMVPE:
    def __init__(self, model_path, is_half, device=None):
        model = E2E(4, 1, (2, 2))
        ckpt = torch.load(model_path, map_location="cpu")
        model.load_state_dict(ckpt)
        model.eval()
        if is_half:
            model = model.half()
        self.model = model
        self.resample_kernel = {}
        self.is_half = is_half
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.mel_extractor = MelSpectrogram(
            is_half, 128, 16000, 1024, 160, None, 30, 8000
        ).to(device)
        self.model = self.model.to(device)
        # 360 pitch classes spaced 20 cents apart (first class ~= 31.7 Hz),
        # padded to 368 so a 9-bin window around any argmax stays in bounds
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            # pad the frame axis up to a multiple of 32 for the U-Net strides
            mel = F.pad(
                mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
            )
            hidden = self.model(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        # cents -> Hz relative to 10 Hz; zero cents marks an unvoiced frame,
        # which maps to exactly 10 Hz and is masked back to 0
        f0 = 10 * (2 ** (cents_pred / 1200))
        f0[f0 == 10] = 0
        return f0

    def infer_from_audio(self, audio, thred=0.03):
        audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
        mel = self.mel_extractor(audio, center=True)
        hidden = self.mel2hidden(mel)
        hidden = hidden.squeeze(0).cpu().numpy()
        if self.is_half:
            hidden = hidden.astype("float32")
        f0 = self.decode(hidden, thred=thred)
        return f0

    def pitch_based_audio_inference(self, audio, thred=0.03, f0_min=50, f0_max=1100):
        audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
        mel = self.mel_extractor(audio, center=True)
        hidden = self.mel2hidden(mel)
        hidden = hidden.squeeze(0).cpu().numpy()
        if self.is_half:
            hidden = hidden.astype("float32")
        f0 = self.decode(hidden, thred=thred)
        # zero out estimates outside the plausible pitch range
        f0[(f0 < f0_min) | (f0 > f0_max)] = 0
        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        # salience: (n_frames, 360); take a salience-weighted average of the
        # cents values in a 9-bin window around each frame's argmax
        center = np.argmax(salience, axis=1)  # (n_frames,)
        salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
        center += 4
        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5
        for idx in range(salience.shape[0]):
            todo_salience.append(salience[idx, starts[idx] : ends[idx]])
            todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
        todo_salience = np.array(todo_salience)  # (n_frames, 9)
        todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
        divided = product_sum / weight_sum  # (n_frames,)
        # frames whose peak salience is below the threshold are unvoiced
        maxx = np.max(salience, axis=1)
        divided[maxx <= thred] = 0
        return divided
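
The class above loads an E2E checkpoint and exposes `infer_from_audio`, which returns one F0 value (Hz) per 10 ms frame: the network emits a 360-bin salience over 20-cent pitch classes, `to_local_average_cents` averages cents around each peak, and `decode` maps cents to Hz via f0 = 10 * 2^(cents/1200), forcing sub-threshold frames to 0. A minimal usage sketch, assuming a local checkpoint named "rmvpe.pt" and soundfile for I/O (both illustrative, not part of this file):

# Hypothetical usage sketch: "speech.wav" and "rmvpe.pt" are placeholder
# paths. RMVPE expects mono float audio at 16 kHz, since the mel extractor
# above is built with sampling_rate=16000 and hop_length=160.
import soundfile as sf
import torch

from lib.rmvpe import RMVPE

audio, sr = sf.read("speech.wav", dtype="float32")
if audio.ndim > 1:
    audio = audio.mean(axis=1)  # downmix stereo to mono
assert sr == 16000, "resample first; the extractor is fixed at 16 kHz"

rmvpe = RMVPE("rmvpe.pt", is_half=torch.cuda.is_available())
f0 = rmvpe.infer_from_audio(audio, thred=0.03)  # one Hz value per 10 ms frame
print(f"{len(f0)} frames, {(f0 > 0).mean():.1%} voiced")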
    	
mdx_models/data.json
ADDED
@@ -0,0 +1,354 @@
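
The registry below keys each MDX-Net checkpoint by an MD5 checksum and records either its spectrogram geometry (`mdx_dim_f_set`, `mdx_dim_t_set`, `mdx_n_fft_scale_set`), an output `compensate` gain, and the `primary_stem` it isolates, or a `config_yaml` pointing to an external configuration. A minimal sketch of how such a registry can be consumed, assuming the consumer hashes the full model file (the helper and file names below are illustrative; the exact hashing scheme used by `soni_translate/mdx_net.py` may differ):

# Hypothetical lookup sketch: hash a downloaded model file and fetch its
# parameters from the registry. "UVR-MDX-NET-Voc_FT.onnx" is a placeholder.
import hashlib
import json

def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

with open("mdx_models/data.json") as f:
    registry = json.load(f)

params = registry.get(file_md5("mdx_models/UVR-MDX-NET-Voc_FT.onnx"))
if params is None:
    raise KeyError("no registry entry for this model checksum")
print(params.get("primary_stem", "see config_yaml"), params.get("compensate"))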
{
    "0ddfc0eb5792638ad5dc27850236c246": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Vocals"
    },
    "26d308f91f3423a67dc69a6d12a8793d": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 9,
        "mdx_n_fft_scale_set": 8192,
        "primary_stem": "Other"
    },
    "2cdd429caac38f0194b133884160f2c6": {
        "compensate": 1.045,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Instrumental"
    },
    "2f5501189a2f6db6349916fabe8c90de": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Vocals"
    },
    "398580b6d5d973af3120df54cee6759d": {
        "compensate": 1.75,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "488b3e6f8bd3717d9d7c428476be2d75": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Instrumental"
    },
    "4910e7827f335048bdac11fa967772f9": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 7,
        "mdx_n_fft_scale_set": 4096,
        "primary_stem": "Drums"
    },
    "53c4baf4d12c3e6c3831bb8f5b532b93": {
        "compensate": 1.043,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "5d343409ef0df48c7d78cce9f0106781": {
        "compensate": 1.075,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "5f6483271e1efb9bfb59e4a3e6d4d098": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 9,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Vocals"
    },
    "65ab5919372a128e4167f5e01a8fda85": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 8192,
        "primary_stem": "Other"
    },
    "6703e39f36f18aa7855ee1047765621d": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 9,
        "mdx_n_fft_scale_set": 16384,
        "primary_stem": "Bass"
    },
    "6b31de20e84392859a3d09d43f089515": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Vocals"
    },
    "867595e9de46f6ab699008295df62798": {
        "compensate": 1.03,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "a3cd63058945e777505c01d2507daf37": {
        "compensate": 1.03,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Vocals"
    },
    "b33d9b3950b6cbf5fe90a32608924700": {
        "compensate": 1.03,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "c3b29bdce8c4fa17ec609e16220330ab": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 16384,
        "primary_stem": "Bass"
    },
    "ceed671467c1f64ebdfac8a2490d0d52": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Instrumental"
    },
    "d2a1376f310e4f7fa37fb9b5774eb701": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Instrumental"
    },
    "d7bff498db9324db933d913388cba6be": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Vocals"
    },
    "d94058f8c7f1fae4164868ae8ae66b20": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Vocals"
    },
    "dc41ede5961d50f277eb846db17f5319": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 9,
        "mdx_n_fft_scale_set": 4096,
        "primary_stem": "Drums"
    },
    "e5572e58abf111f80d8241d2e44e7fa4": {
        "compensate": 1.028,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Instrumental"
    },
    "e7324c873b1f615c35c1967f912db92a": {
        "compensate": 1.03,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "1c56ec0224f1d559c42fd6fd2a67b154": {
        "compensate": 1.025,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 5120,
        "primary_stem": "Instrumental"
    },
    "f2df6d6863d8f435436d8b561594ff49": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Instrumental"
    },
    "b06327a00d5e5fbc7d96e1781bbdb596": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Instrumental"
    },
    "94ff780b977d3ca07c7a343dab2e25dd": {
        "compensate": 1.039,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Instrumental"
    },
    "73492b58195c3b52d34590d5474452f6": {
        "compensate": 1.043,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "970b3f9492014d18fefeedfe4773cb42": {
        "compensate": 1.009,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "1d64a6d2c30f709b8c9b4ce1366d96ee": {
        "compensate": 1.035,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 5120,
        "primary_stem": "Instrumental"
    },
    "203f2a3955221b64df85a41af87cf8f0": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Instrumental"
    },
    "291c2049608edb52648b96e27eb80e95": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Instrumental"
    },
    "ead8d05dab12ec571d67549b3aab03fc": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Instrumental"
    },
    "cc63408db3d80b4d85b0287d1d7c9632": {
        "compensate": 1.033,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Instrumental"
    },
    "cd5b2989ad863f116c855db1dfe24e39": {
        "compensate": 1.035,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 9,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Other"
    },
    "55657dd70583b0fedfba5f67df11d711": {
        "compensate": 1.022,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 6144,
        "primary_stem": "Instrumental"
    },
    "b6bccda408a436db8500083ef3491e8b": {
        "compensate": 1.02,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Instrumental"
    },
    "8a88db95c7fb5dbe6a095ff2ffb428b1": {
        "compensate": 1.026,
        "mdx_dim_f_set": 2048,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 5120,
        "primary_stem": "Instrumental"
    },
    "b78da4afc6512f98e4756f5977f5c6b9": {
        "compensate": 1.021,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Instrumental"
    },
    "77d07b2667ddf05b9e3175941b4454a0": {
        "compensate": 1.021,
        "mdx_dim_f_set": 3072,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 7680,
        "primary_stem": "Vocals"
    },
    "0f2a6bc5b49d87d64728ee40e23bceb1": {
        "compensate": 1.019,
        "mdx_dim_f_set": 2560,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 5120,
        "primary_stem": "Instrumental"
    },
    "b02be2d198d4968a121030cf8950b492": {
        "compensate": 1.020,
        "mdx_dim_f_set": 2560,
        "mdx_dim_t_set": 8,
        "mdx_n_fft_scale_set": 5120,
        "primary_stem": "No Crowd"
    },
    "2154254ee89b2945b97a7efed6e88820": {
        "config_yaml": "model_2_stem_061321.yaml"
    },
    "063aadd735d58150722926dcbf5852a9": {
        "config_yaml": "model_2_stem_061321.yaml"
    },
    "fe96801369f6a148df2720f5ced88c19": {
        "config_yaml": "model3.yaml"
    },
    "02e8b226f85fb566e5db894b9931c640": {
        "config_yaml": "model2.yaml"
    },
    "e3de6d861635ab9c1d766149edd680d6": {
        "config_yaml": "model1.yaml"
    },
    "3f2936c554ab73ce2e396d54636bd373": {
        "config_yaml": "modelB.yaml"
    },
    "890d0f6f82d7574bca741a9e8bcb8168": {
        "config_yaml": "modelB.yaml"
    },
    "63a3cb8c37c474681049be4ad1ba8815": {
        "config_yaml": "modelB.yaml"
    },
    "a7fc5d719743c7fd6b61bd2b4d48b9f0": {
        "config_yaml": "modelA.yaml"
    },
    "3567f3dee6e77bf366fcb1c7b8bc3745": {
        "config_yaml": "modelA.yaml"
    },
    "a28f4d717bd0d34cd2ff7a3b0a3d065e": {
        "config_yaml": "modelA.yaml"
    },
    "c9971a18da20911822593dc81caa8be9": {
        "config_yaml": "sndfx.yaml"
    },
    "57d94d5ed705460d21c75a5ac829a605": {
        "config_yaml": "sndfx.yaml"
    },
    "e7a25f8764f25a52c1b96c4946e66ba2": {
         | 
| 343 | 
            +
                    "config_yaml": "sndfx.yaml"
         | 
| 344 | 
            +
                },
         | 
| 345 | 
            +
                "104081d24e37217086ce5fde09147ee1": {
         | 
| 346 | 
            +
                    "config_yaml": "model_2_stem_061321.yaml"
         | 
| 347 | 
            +
                },
         | 
| 348 | 
            +
                "1e6165b601539f38d0a9330f3facffeb": {
         | 
| 349 | 
            +
                    "config_yaml": "model_2_stem_061321.yaml"
         | 
| 350 | 
            +
                },
         | 
| 351 | 
            +
                "fe0108464ce0d8271be5ab810891bd7c": {
         | 
| 352 | 
            +
                    "config_yaml": "model_2_stem_full_band.yaml"
         | 
| 353 | 
            +
                }
         | 
| 354 | 
            +
            }
         | 
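Note: data.json keys are 32-character MD5-style digests of the separation model files, each mapped either to MDX-Net spectrogram settings (compensate, mdx_dim_f_set, mdx_dim_t_set, mdx_n_fft_scale_set, primary_stem) or, for newer models, to a config_yaml. A minimal sketch of how such a table can be consulted follows; it assumes the key is the MD5 of the whole file (some UVR-style loaders hash only the file tail), and load_mdx_params is a hypothetical name, not the function used in soni_translate/mdx_net.py.

import hashlib
import json


def load_mdx_params(model_path, table_path="mdx_models/data.json"):
    # Load the hash-keyed parameter table.
    with open(table_path, "r") as f:
        table = json.load(f)
    # Hash the downloaded model file (assumption: full-file MD5).
    with open(model_path, "rb") as f:
        digest = hashlib.md5(f.read()).hexdigest()
    # Returns e.g. {"compensate": 1.022, "mdx_dim_f_set": 3072, ...},
    # {"config_yaml": "model2.yaml"}, or None for an unknown model.
    return table.get(digest)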
    	
packages.txt ADDED
@@ -0,0 +1,3 @@
+git-lfs
+aria2 -y
+ffmpeg
    	
pre-requirements.txt ADDED
@@ -0,0 +1,15 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch>=2.1.0+cu118
+torchvision>=0.16.0+cu118
+torchaudio>=2.1.0+cu118
+yt-dlp
+gradio==4.19.2
+pydub==0.25.1
+edge_tts==6.1.7
+deep_translator==1.11.4
+git+https://github.com/R3gm/[email protected]
+git+https://github.com/R3gm/whisperX.git@cuda_11_8
+nest_asyncio
+gTTS
+gradio_client==0.10.1
+IPython
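The extra index URL makes pip resolve the CUDA 11.8 builds of the PyTorch stack rather than the CPU-only wheels from PyPI. A quick post-install sanity check, as a minimal sketch:

# Confirm the +cu118 wheels from pre-requirements.txt were actually resolved.
import torch

print(torch.__version__)          # expected to end with "+cu118"
print(torch.version.cuda)         # expected "11.8"
print(torch.cuda.is_available())  # True only with a compatible GPU/driver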
    	
requirements.txt ADDED
@@ -0,0 +1,19 @@
+praat-parselmouth>=0.4.3
+pyworld==0.3.2
+faiss-cpu==1.7.3
+torchcrepe==0.0.20
+ffmpeg-python>=0.2.0
+fairseq==0.12.2
+gdown
+rarfile
+transformers
+accelerate
+optimum
+sentencepiece
+srt
+git+https://github.com/R3gm/openvoice_package.git@lite
+openai==1.14.3
+tiktoken==0.6.0
+# Documents
+pypdf==4.2.0
+python-docx
    	
requirements_xtts.txt ADDED
@@ -0,0 +1,58 @@
+# core deps
+numpy==1.23.5
+cython>=0.29.30
+scipy>=1.11.2
+torch
+torchaudio
+soundfile
+librosa
+scikit-learn
+numba
+inflect>=5.6.0
+tqdm>=4.64.1
+anyascii>=0.3.0
+pyyaml>=6.0
+fsspec>=2023.6.0  # <= 2023.9.1 makes aux tests fail
+aiohttp>=3.8.1
+packaging>=23.1
+# deps for examples
+flask>=2.0.1
+# deps for inference
+pysbd>=0.3.4
+# deps for notebooks
+umap-learn>=0.5.1
+pandas
+# deps for training
+matplotlib
+# coqui stack
+trainer>=0.0.32
+# config management
+coqpit>=0.0.16
+# chinese g2p deps
+jieba
+pypinyin
+# korean
+hangul_romanize
+# gruut+supported langs
+gruut[de,es,fr]==2.2.3
+# deps for korean
+jamo
+nltk
+g2pkk>=0.1.1
+# deps for bangla
+bangla
+bnnumerizer
+bnunicodenormalizer
+# deps for tortoise
+einops>=0.6.0
+transformers
+# deps for bark
+encodec>=0.1.1
+# deps for XTTS
+unidecode>=1.3.2
+num2words
+spacy[ja]>=3
+
+# after this
+# pip install -r requirements_xtts.txt
+# pip install TTS==0.21.1 --no-deps
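The trailing comments describe a two-step install: this file first, then TTS pinned at 0.21.1 with --no-deps so its own requirements cannot override the versions pinned above. A minimal sketch of that sequence driven from Python, under the assumption that it runs from the repository root:

import subprocess
import sys

# Step 1: install the pinned XTTS dependencies.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-r", "requirements_xtts.txt"]
)
# Step 2: install TTS itself without dependencies, keeping the pins intact.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "TTS==0.21.1", "--no-deps"]
)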
    	
soni_translate/audio_segments.py ADDED
@@ -0,0 +1,141 @@
+from pydub import AudioSegment
+from tqdm import tqdm
+from .utils import run_command
+from .logging_setup import logger
+import numpy as np
+
+
+class Mixer:
+    def __init__(self):
+        self.parts = []
+
+    def __len__(self):
+        parts = self._sync()
+        seg = parts[0][1]
+        frame_count = max(offset + seg.frame_count() for offset, seg in parts)
+        return int(1000.0 * frame_count / seg.frame_rate)
+
+    def overlay(self, sound, position=0):
+        self.parts.append((position, sound))
+        return self
+
+    def _sync(self):
+        positions, segs = zip(*self.parts)
+
+        frame_rate = segs[0].frame_rate
+        array_type = segs[0].array_type  # noqa
+
+        offsets = [int(frame_rate * pos / 1000.0) for pos in positions]
+        segs = AudioSegment.empty()._sync(*segs)
+        return list(zip(offsets, segs))
+
+    def append(self, sound):
+        self.overlay(sound, position=len(self))
+
+    def to_audio_segment(self):
+        parts = self._sync()
+        seg = parts[0][1]
+        channels = seg.channels
+
+        frame_count = max(offset + seg.frame_count() for offset, seg in parts)
+        sample_count = int(frame_count * seg.channels)
+
+        output = np.zeros(sample_count, dtype="int32")
+        for offset, seg in parts:
+            sample_offset = offset * channels
+            samples = np.frombuffer(seg.get_array_of_samples(), dtype="int32")
+            samples = np.int16(samples / np.max(np.abs(samples)) * 32767)
+            start = sample_offset
+            end = start + len(samples)
+            output[start:end] += samples
+
+        return seg._spawn(
+            output, overrides={"sample_width": 4}).normalize(headroom=0.0)
+
+
+def create_translated_audio(
+    result_diarize, audio_files, final_file, concat=False, avoid_overlap=False,
+):
+    total_duration = result_diarize["segments"][-1]["end"]  # in seconds
+
+    if concat:
+        """
+        file .\audio\1.ogg
+        file .\audio\2.ogg
+        file .\audio\3.ogg
+        file .\audio\4.ogg
+        ...
+        """
+
+        # Write the file paths to list.txt
+        with open("list.txt", "w") as file:
+            for i, audio_file in enumerate(audio_files):
+                if i == len(audio_files) - 1:  # Check if it's the last item
+                    file.write(f"file {audio_file}")
+                else:
+                    file.write(f"file {audio_file}\n")
+
+        # command = f"ffmpeg -f concat -safe 0 -i list.txt {final_file}"
+        command = (
+            f"ffmpeg -f concat -safe 0 -i list.txt -c:a pcm_s16le {final_file}"
+        )
+        run_command(command)
+
+    else:
+        # silent audio with total_duration
+        base_audio = AudioSegment.silent(
+            duration=int(total_duration * 1000), frame_rate=41000
+        )
+        combined_audio = Mixer()
+        combined_audio.overlay(base_audio)
+
+        logger.debug(
+            f"Audio duration: {total_duration // 60} "
+            f"minutes and {int(total_duration % 60)} seconds"
+        )
+
+        last_end_time = 0
+        previous_speaker = ""
+        for line, audio_file in tqdm(
+            zip(result_diarize["segments"], audio_files)
+        ):
+            start = float(line["start"])
+
+            # Overlay each audio at the corresponding time
+            try:
+                audio = AudioSegment.from_file(audio_file)
+                # audio_a = audio.speedup(playback_speed=1.5)
+
+                if avoid_overlap:
+                    speaker = line["speaker"]
+                    if (last_end_time - 0.500) > start:
+                        overlap_time = last_end_time - start
+                        if previous_speaker and previous_speaker != speaker:
+                            start = (last_end_time - 0.500)
+                        else:
+                            start = (last_end_time - 0.200)
+                        if overlap_time > 2.5:
+                            start = start - 0.3
+                        logger.info(
+                            f"Avoid overlap for {str(audio_file)} "
+                            f"with {str(start)}"
+                        )
+
+                    previous_speaker = speaker
+
+                    duration_tts_seconds = len(audio) / 1000.0  # to sec
+                    last_end_time = (start + duration_tts_seconds)
+
+                start_time = start * 1000  # to ms
+                combined_audio = combined_audio.overlay(
+                    audio, position=start_time
+                )
+            except Exception as error:
+                logger.debug(str(error))
+                logger.error(f"Error audio file {audio_file}")
+
+        # combined audio as a file
+        combined_audio_data = combined_audio.to_audio_segment()
+        combined_audio_data.export(
+            final_file, format="wav"
+        )  # better than ogg, change if the audio is anomalous
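A short usage sketch for the module above: Mixer collects (position, segment) pairs and renders them into one track, which is how create_translated_audio lays TTS clips over a silent base. The file names and timings here are placeholders:

from pydub import AudioSegment

from soni_translate.audio_segments import Mixer

# 10-second silent canvas at the frame rate create_translated_audio uses.
base = AudioSegment.silent(duration=10_000, frame_rate=41000)
mix = Mixer()
mix.overlay(base)
# Positions are in milliseconds, mirroring start_time = start * 1000 above.
mix.overlay(AudioSegment.from_file("seg_1.ogg"), position=1500)
mix.overlay(AudioSegment.from_file("seg_2.ogg"), position=6200)
mix.to_audio_segment().export("dubbed.wav", format="wav")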
    	
soni_translate/language_configuration.py ADDED
@@ -0,0 +1,551 @@
+from .logging_setup import logger
+
+LANGUAGES_UNIDIRECTIONAL = {
+    "Aymara (ay)": "ay",
+    "Bambara (bm)": "bm",
+    "Cebuano (ceb)": "ceb",
+    "Chichewa (ny)": "ny",
+    "Divehi (dv)": "dv",
+    "Dogri (doi)": "doi",
+    "Ewe (ee)": "ee",
+    "Guarani (gn)": "gn",
+    "Iloko (ilo)": "ilo",
+    "Kinyarwanda (rw)": "rw",
+    "Krio (kri)": "kri",
+    "Kurdish (ku)": "ku",
+    "Kirghiz (ky)": "ky",
+    "Ganda (lg)": "lg",
+    "Maithili (mai)": "mai",
+    "Oriya (or)": "or",
+    "Oromo (om)": "om",
+    "Quechua (qu)": "qu",
+    "Samoan (sm)": "sm",
+    "Tigrinya (ti)": "ti",
+    "Tsonga (ts)": "ts",
+    "Akan (ak)": "ak",
+    "Uighur (ug)": "ug"
+}
+
+UNIDIRECTIONAL_L_LIST = LANGUAGES_UNIDIRECTIONAL.keys()
+
+LANGUAGES = {
+    "Automatic detection": "Automatic detection",
+    "Arabic (ar)": "ar",
+    "Chinese - Simplified (zh-CN)": "zh",
+    "Czech (cs)": "cs",
+    "Danish (da)": "da",
+    "Dutch (nl)": "nl",
+    "English (en)": "en",
+    "Finnish (fi)": "fi",
+    "French (fr)": "fr",
+    "German (de)": "de",
+    "Greek (el)": "el",
+    "Hebrew (he)": "he",
+    "Hungarian (hu)": "hu",
+    "Italian (it)": "it",
+    "Japanese (ja)": "ja",
+    "Korean (ko)": "ko",
+    "Persian (fa)": "fa",  # no aux gTTS
+    "Polish (pl)": "pl",
+    "Portuguese (pt)": "pt",
+    "Russian (ru)": "ru",
+    "Spanish (es)": "es",
+    "Turkish (tr)": "tr",
+    "Ukrainian (uk)": "uk",
+    "Urdu (ur)": "ur",
+    "Vietnamese (vi)": "vi",
+    "Hindi (hi)": "hi",
+    "Indonesian (id)": "id",
+    "Bengali (bn)": "bn",
+    "Telugu (te)": "te",
+    "Marathi (mr)": "mr",
+    "Tamil (ta)": "ta",
+    "Javanese (jw|jv)": "jw",
+    "Catalan (ca)": "ca",
+    "Nepali (ne)": "ne",
+    "Thai (th)": "th",
+    "Swedish (sv)": "sv",
+    "Amharic (am)": "am",
+    "Welsh (cy)": "cy",  # no aux gTTS
+    "Estonian (et)": "et",
+    "Croatian (hr)": "hr",
+    "Icelandic (is)": "is",
+    "Georgian (ka)": "ka",  # no aux gTTS
+    "Khmer (km)": "km",
+    "Slovak (sk)": "sk",
+    "Albanian (sq)": "sq",
+    "Serbian (sr)": "sr",
+    "Azerbaijani (az)": "az",  # no aux gTTS
+    "Bulgarian (bg)": "bg",
+    "Galician (gl)": "gl",  # no aux gTTS
+    "Gujarati (gu)": "gu",
+    "Kazakh (kk)": "kk",  # no aux gTTS
+    "Kannada (kn)": "kn",
+    "Lithuanian (lt)": "lt",  # no aux gTTS
+    "Latvian (lv)": "lv",
+    "Macedonian (mk)": "mk",  # no aux gTTS # error get align model
+    "Malayalam (ml)": "ml",
+    "Malay (ms)": "ms",  # error get align model
+    "Romanian (ro)": "ro",
+    "Sinhala (si)": "si",
+    "Sundanese (su)": "su",
+    "Swahili (sw)": "sw",  # error align
+    "Afrikaans (af)": "af",
+    "Bosnian (bs)": "bs",
+    "Latin (la)": "la",
+    "Myanmar Burmese (my)": "my",
+    "Norwegian (no|nb)": "no",
+    "Chinese - Traditional (zh-TW)": "zh-TW",
+    "Assamese (as)": "as",
+    "Basque (eu)": "eu",
+    "Hausa (ha)": "ha",
+    "Haitian Creole (ht)": "ht",
+    "Armenian (hy)": "hy",
+    "Lao (lo)": "lo",
+    "Malagasy (mg)": "mg",
+    "Mongolian (mn)": "mn",
+    "Maltese (mt)": "mt",
+    "Punjabi (pa)": "pa",
+    "Pashto (ps)": "ps",
+    "Slovenian (sl)": "sl",
+    "Shona (sn)": "sn",
+    "Somali (so)": "so",
+    "Tajik (tg)": "tg",
+    "Turkmen (tk)": "tk",
+    "Tatar (tt)": "tt",
+    "Uzbek (uz)": "uz",
+    "Yoruba (yo)": "yo",
+    **LANGUAGES_UNIDIRECTIONAL
+}
+
+BASE_L_LIST = LANGUAGES.keys()
+LANGUAGES_LIST = [list(BASE_L_LIST)[0]] + sorted(list(BASE_L_LIST)[1:])
+INVERTED_LANGUAGES = {value: key for key, value in LANGUAGES.items()}
+
+EXTRA_ALIGN = {
+    "id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
+    "bn": "arijitx/wav2vec2-large-xlsr-bengali",
+    "mr": "sumedh/wav2vec2-large-xlsr-marathi",
+    "ta": "Amrrs/wav2vec2-large-xlsr-53-tamil",
+    "jw": "cahya/wav2vec2-large-xlsr-javanese",
+    "ne": "shniranjan/wav2vec2-large-xlsr-300m-nepali",
+    "th": "sakares/wav2vec2-large-xlsr-thai-demo",
+    "sv": "KBLab/wav2vec2-large-voxrex-swedish",
+    "am": "agkphysics/wav2vec2-large-xlsr-53-amharic",
+    "cy": "Srulikbdd/Wav2Vec2-large-xlsr-welsh",
+    "et": "anton-l/wav2vec2-large-xlsr-53-estonian",
+    "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
+    "is": "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h",
+    "ka": "MehdiHosseiniMoghadam/wav2vec2-large-xlsr-53-Georgian",
+    "km": "vitouphy/wav2vec2-xls-r-300m-khmer",
+    "sk": "infinitejoy/wav2vec2-large-xls-r-300m-slovak",
+    "sq": "Alimzhan/wav2vec2-large-xls-r-300m-albanian-colab",
+    "sr": "dnikolic/wav2vec2-xlsr-530-serbian-colab",
+    "az": "nijatzeynalov/wav2vec2-large-mms-1b-azerbaijani-common_voice15.0",
+    "bg": "infinitejoy/wav2vec2-large-xls-r-300m-bulgarian",
+    "gl": "ifrz/wav2vec2-large-xlsr-galician",
+    "gu": "Harveenchadha/vakyansh-wav2vec2-gujarati-gnm-100",
+    "kk": "aismlv/wav2vec2-large-xlsr-kazakh",
+    "kn": "Harveenchadha/vakyansh-wav2vec2-kannada-knm-560",
+    "lt": "DeividasM/wav2vec2-large-xlsr-53-lithuanian",
+    "lv": "anton-l/wav2vec2-large-xlsr-53-latvian",
+    "mk": "",  # Konstantin-Bogdanoski/wav2vec2-macedonian-base
+    "ml": "gvs/wav2vec2-large-xlsr-malayalam",
+    "ms": "",  # Duy/wav2vec2_malay
+    "ro": "anton-l/wav2vec2-large-xlsr-53-romanian",
+    "si": "IAmNotAnanth/wav2vec2-large-xls-r-300m-sinhala",
+    "su": "cahya/wav2vec2-large-xlsr-sundanese",
+    "sw": "",  # Lians/fine-tune-wav2vec2-large-swahili
+    "af": "",  # ylacombe/wav2vec2-common_voice-af-demo
+    "bs": "",
+    "la": "",
+    "my": "",
+    "no": "NbAiLab/wav2vec2-xlsr-300m-norwegian",
+    "zh-TW": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
+    "as": "",
+    "eu": "",  # cahya/wav2vec2-large-xlsr-basque # verify
+    "ha": "infinitejoy/wav2vec2-large-xls-r-300m-hausa",
+    "ht": "",
+    "hy": "infinitejoy/wav2vec2-large-xls-r-300m-armenian",  # no (.)
+    "lo": "",
+    "mg": "",
+    "mn": "tugstugi/wav2vec2-large-xlsr-53-mongolian",
+    "mt": "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-maltese-64h",
+    "pa": "kingabzpro/wav2vec2-large-xlsr-53-punjabi",
+    "ps": "aamirhs/wav2vec2-large-xls-r-300m-pashto-colab",
+    "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
+    "sn": "",
+    "so": "",
+    "tg": "",
+    "tk": "",  # Ragav/wav2vec2-tk
+    "tt": "anton-l/wav2vec2-large-xlsr-53-tatar",
+    "uz": "",  # Mekhriddin/wav2vec2-large-xls-r-300m-uzbek-colab
+    "yo": "ogbi/wav2vec2-large-mms-1b-yoruba-test",
+}
+
+
+def fix_code_language(translate_to, syntax="google"):
+    if syntax == "google":
+        # google-translator, gTTS
+        replace_lang_code = {"zh": "zh-CN", "he": "iw", "zh-cn": "zh-CN"}
+    elif syntax == "coqui":
+        # coqui-xtts
+        replace_lang_code = {"zh": "zh-cn", "zh-CN": "zh-cn", "zh-TW": "zh-cn"}
+
+    new_code_lang = replace_lang_code.get(translate_to, translate_to)
+    logger.debug(f"Fix code {translate_to} -> {new_code_lang}")
+    return new_code_lang
+
+
+BARK_VOICES_LIST = {
+    "de_speaker_0-Male BARK": "v2/de_speaker_0",
+    "de_speaker_1-Male BARK": "v2/de_speaker_1",
+    "de_speaker_2-Male BARK": "v2/de_speaker_2",
+    "de_speaker_3-Female BARK": "v2/de_speaker_3",
+    "de_speaker_4-Male BARK": "v2/de_speaker_4",
+    "de_speaker_5-Male BARK": "v2/de_speaker_5",
+    "de_speaker_6-Male BARK": "v2/de_speaker_6",
+    "de_speaker_7-Male BARK": "v2/de_speaker_7",
+    "de_speaker_8-Female BARK": "v2/de_speaker_8",
+    "de_speaker_9-Male BARK": "v2/de_speaker_9",
+    "en_speaker_0-Male BARK": "v2/en_speaker_0",
+    "en_speaker_1-Male BARK": "v2/en_speaker_1",
+    "en_speaker_2-Male BARK": "v2/en_speaker_2",
+    "en_speaker_3-Male BARK": "v2/en_speaker_3",
+    "en_speaker_4-Male BARK": "v2/en_speaker_4",
+    "en_speaker_5-Male BARK": "v2/en_speaker_5",
+    "en_speaker_6-Male BARK": "v2/en_speaker_6",
+    "en_speaker_7-Male BARK": "v2/en_speaker_7",
+    "en_speaker_8-Male BARK": "v2/en_speaker_8",
+    "en_speaker_9-Female BARK": "v2/en_speaker_9",
+    "es_speaker_0-Male BARK": "v2/es_speaker_0",
+    "es_speaker_1-Male BARK": "v2/es_speaker_1",
+    "es_speaker_2-Male BARK": "v2/es_speaker_2",
+    "es_speaker_3-Male BARK": "v2/es_speaker_3",
+    "es_speaker_4-Male BARK": "v2/es_speaker_4",
+    "es_speaker_5-Male BARK": "v2/es_speaker_5",
+    "es_speaker_6-Male BARK": "v2/es_speaker_6",
+    "es_speaker_7-Male BARK": "v2/es_speaker_7",
+    "es_speaker_8-Female BARK": "v2/es_speaker_8",
+    "es_speaker_9-Female BARK": "v2/es_speaker_9",
+    "fr_speaker_0-Male BARK": "v2/fr_speaker_0",
+    "fr_speaker_1-Female BARK": "v2/fr_speaker_1",
+    "fr_speaker_2-Female BARK": "v2/fr_speaker_2",
+    "fr_speaker_3-Male BARK": "v2/fr_speaker_3",
+    "fr_speaker_4-Male BARK": "v2/fr_speaker_4",
+    "fr_speaker_5-Female BARK": "v2/fr_speaker_5",
+    "fr_speaker_6-Male BARK": "v2/fr_speaker_6",
+    "fr_speaker_7-Male BARK": "v2/fr_speaker_7",
+    "fr_speaker_8-Male BARK": "v2/fr_speaker_8",
+    "fr_speaker_9-Male BARK": "v2/fr_speaker_9",
+    "hi_speaker_0-Female BARK": "v2/hi_speaker_0",
+    "hi_speaker_1-Female BARK": "v2/hi_speaker_1",
+    "hi_speaker_2-Male BARK": "v2/hi_speaker_2",
+    "hi_speaker_3-Female BARK": "v2/hi_speaker_3",
+    "hi_speaker_4-Female BARK": "v2/hi_speaker_4",
+    "hi_speaker_5-Male BARK": "v2/hi_speaker_5",
+    "hi_speaker_6-Male BARK": "v2/hi_speaker_6",
+    "hi_speaker_7-Male BARK": "v2/hi_speaker_7",
+    "hi_speaker_8-Male BARK": "v2/hi_speaker_8",
+    "hi_speaker_9-Female BARK": "v2/hi_speaker_9",
+    "it_speaker_0-Male BARK": "v2/it_speaker_0",
+    "it_speaker_1-Male BARK": "v2/it_speaker_1",
+    "it_speaker_2-Female BARK": "v2/it_speaker_2",
+    "it_speaker_3-Male BARK": "v2/it_speaker_3",
+    "it_speaker_4-Male BARK": "v2/it_speaker_4",
+    "it_speaker_5-Male BARK": "v2/it_speaker_5",
+    "it_speaker_6-Male BARK": "v2/it_speaker_6",
+    "it_speaker_7-Female BARK": "v2/it_speaker_7",
+    "it_speaker_8-Male BARK": "v2/it_speaker_8",
+    "it_speaker_9-Female BARK": "v2/it_speaker_9",
+    "ja_speaker_0-Female BARK": "v2/ja_speaker_0",
+    "ja_speaker_1-Female BARK": "v2/ja_speaker_1",
+    "ja_speaker_2-Male BARK": "v2/ja_speaker_2",
+    "ja_speaker_3-Female BARK": "v2/ja_speaker_3",
+    "ja_speaker_4-Female BARK": "v2/ja_speaker_4",
+    "ja_speaker_5-Female BARK": "v2/ja_speaker_5",
+    "ja_speaker_6-Male BARK": "v2/ja_speaker_6",
+    "ja_speaker_7-Female BARK": "v2/ja_speaker_7",
+    "ja_speaker_8-Female BARK": "v2/ja_speaker_8",
+    "ja_speaker_9-Female BARK": "v2/ja_speaker_9",
+    "ko_speaker_0-Female BARK": "v2/ko_speaker_0",
+    "ko_speaker_1-Male BARK": "v2/ko_speaker_1",
+    "ko_speaker_2-Male BARK": "v2/ko_speaker_2",
+    "ko_speaker_3-Male BARK": "v2/ko_speaker_3",
+    "ko_speaker_4-Male BARK": "v2/ko_speaker_4",
+    "ko_speaker_5-Male BARK": "v2/ko_speaker_5",
+    "ko_speaker_6-Male BARK": "v2/ko_speaker_6",
+    "ko_speaker_7-Male BARK": "v2/ko_speaker_7",
+    "ko_speaker_8-Male BARK": "v2/ko_speaker_8",
+    "ko_speaker_9-Male BARK": "v2/ko_speaker_9",
+    "pl_speaker_0-Male BARK": "v2/pl_speaker_0",
+    "pl_speaker_1-Male BARK": "v2/pl_speaker_1",
+    "pl_speaker_2-Male BARK": "v2/pl_speaker_2",
+    "pl_speaker_3-Male BARK": "v2/pl_speaker_3",
+    "pl_speaker_4-Female BARK": "v2/pl_speaker_4",
+    "pl_speaker_5-Male BARK": "v2/pl_speaker_5",
+    "pl_speaker_6-Female BARK": "v2/pl_speaker_6",
+    "pl_speaker_7-Male BARK": "v2/pl_speaker_7",
+    "pl_speaker_8-Male BARK": "v2/pl_speaker_8",
+    "pl_speaker_9-Female BARK": "v2/pl_speaker_9",
+    "pt_speaker_0-Male BARK": "v2/pt_speaker_0",
+    "pt_speaker_1-Male BARK": "v2/pt_speaker_1",
+    "pt_speaker_2-Male BARK": "v2/pt_speaker_2",
+    "pt_speaker_3-Male BARK": "v2/pt_speaker_3",
+    "pt_speaker_4-Male BARK": "v2/pt_speaker_4",
+    "pt_speaker_5-Male BARK": "v2/pt_speaker_5",
+    "pt_speaker_6-Male BARK": "v2/pt_speaker_6",
+    "pt_speaker_7-Male BARK": "v2/pt_speaker_7",
+    "pt_speaker_8-Male BARK": "v2/pt_speaker_8",
+    "pt_speaker_9-Male BARK": "v2/pt_speaker_9",
+    "ru_speaker_0-Male BARK": "v2/ru_speaker_0",
+    "ru_speaker_1-Male BARK": "v2/ru_speaker_1",
+    "ru_speaker_2-Male BARK": "v2/ru_speaker_2",
+    "ru_speaker_3-Male BARK": "v2/ru_speaker_3",
+    "ru_speaker_4-Male BARK": "v2/ru_speaker_4",
+    "ru_speaker_5-Female BARK": "v2/ru_speaker_5",
+    "ru_speaker_6-Female BARK": "v2/ru_speaker_6",
+    "ru_speaker_7-Male BARK": "v2/ru_speaker_7",
+    "ru_speaker_8-Male BARK": "v2/ru_speaker_8",
+    "ru_speaker_9-Female BARK": "v2/ru_speaker_9",
+    "tr_speaker_0-Male BARK": "v2/tr_speaker_0",
+    "tr_speaker_1-Male BARK": "v2/tr_speaker_1",
+    "tr_speaker_2-Male BARK": "v2/tr_speaker_2",
+    "tr_speaker_3-Male BARK": "v2/tr_speaker_3",
+    "tr_speaker_4-Female BARK": "v2/tr_speaker_4",
+    "tr_speaker_5-Female BARK": "v2/tr_speaker_5",
+    "tr_speaker_6-Male BARK": "v2/tr_speaker_6",
+    "tr_speaker_7-Male BARK": "v2/tr_speaker_7",
+    "tr_speaker_8-Male BARK": "v2/tr_speaker_8",
+    "tr_speaker_9-Male BARK": "v2/tr_speaker_9",
+    "zh_speaker_0-Male BARK": "v2/zh_speaker_0",
+    "zh_speaker_1-Male BARK": "v2/zh_speaker_1",
+    "zh_speaker_2-Male BARK": "v2/zh_speaker_2",
+    "zh_speaker_3-Male BARK": "v2/zh_speaker_3",
+    "zh_speaker_4-Female BARK": "v2/zh_speaker_4",
+    "zh_speaker_5-Male BARK": "v2/zh_speaker_5",
+    "zh_speaker_6-Female BARK": "v2/zh_speaker_6",
+    "zh_speaker_7-Female BARK": "v2/zh_speaker_7",
+    "zh_speaker_8-Male BARK": "v2/zh_speaker_8",
+    "zh_speaker_9-Female BARK": "v2/zh_speaker_9",
+}
+
+VITS_VOICES_LIST = {
+    "ar-facebook-mms VITS": "facebook/mms-tts-ara",
+    # 'zh-facebook-mms VITS': 'facebook/mms-tts-cmn',
+    "zh_Hakka-facebook-mms VITS": "facebook/mms-tts-hak",
+    "zh_MinNan-facebook-mms VITS": "facebook/mms-tts-nan",
+    # 'cs-facebook-mms VITS': 'facebook/mms-tts-ces',
+    # 'da-facebook-mms VITS': 'facebook/mms-tts-dan',
+    "nl-facebook-mms VITS": "facebook/mms-tts-nld",
+    "en-facebook-mms VITS": "facebook/mms-tts-eng",
+    "fi-facebook-mms VITS": "facebook/mms-tts-fin",
+    "fr-facebook-mms VITS": "facebook/mms-tts-fra",
+    "de-facebook-mms VITS": "facebook/mms-tts-deu",
+    "el-facebook-mms VITS": "facebook/mms-tts-ell",
+    "el_Ancient-facebook-mms VITS": "facebook/mms-tts-grc",
+    "he-facebook-mms VITS": "facebook/mms-tts-heb",
+    "hu-facebook-mms VITS": "facebook/mms-tts-hun",
+    # 'it-facebook-mms VITS': 'facebook/mms-tts-ita',
+    # 'ja-facebook-mms VITS': 'facebook/mms-tts-jpn',
+    "ko-facebook-mms VITS": "facebook/mms-tts-kor",
+    "fa-facebook-mms VITS": "facebook/mms-tts-fas",
+    "pl-facebook-mms VITS": "facebook/mms-tts-pol",
+    "pt-facebook-mms VITS": "facebook/mms-tts-por",
+    "ru-facebook-mms VITS": "facebook/mms-tts-rus",
+    "es-facebook-mms VITS": "facebook/mms-tts-spa",
+    "tr-facebook-mms VITS": "facebook/mms-tts-tur",
+    "uk-facebook-mms VITS": "facebook/mms-tts-ukr",
+    "ur_arabic-facebook-mms VITS": "facebook/mms-tts-urd-script_arabic",
+    "ur_devanagari-facebook-mms VITS": "facebook/mms-tts-urd-script_devanagari",
+    "ur_latin-facebook-mms VITS": "facebook/mms-tts-urd-script_latin",
+    "vi-facebook-mms VITS": "facebook/mms-tts-vie",
+    "hi-facebook-mms VITS": "facebook/mms-tts-hin",
+    "hi_Fiji-facebook-mms VITS": "facebook/mms-tts-hif",
+    "id-facebook-mms VITS": "facebook/mms-tts-ind",
+    "bn-facebook-mms VITS": "facebook/mms-tts-ben",
+    "te-facebook-mms VITS": "facebook/mms-tts-tel",
+    "mr-facebook-mms VITS": "facebook/mms-tts-mar",
+    "ta-facebook-mms VITS": "facebook/mms-tts-tam",
+    "jw-facebook-mms VITS": "facebook/mms-tts-jav",
+    "jw_Suriname-facebook-mms VITS": "facebook/mms-tts-jvn",
+    "ca-facebook-mms VITS": "facebook/mms-tts-cat",
+    "ne-facebook-mms VITS": "facebook/mms-tts-nep",
+    "th-facebook-mms VITS": "facebook/mms-tts-tha",
+    "th_Northern-facebook-mms VITS": "facebook/mms-tts-nod",
            +
                "th_Northern-facebook-mms VITS": "facebook/mms-tts-nod",
         | 
| 376 | 
            +
                "sv-facebook-mms VITS": "facebook/mms-tts-swe",
         | 
| 377 | 
            +
                "am-facebook-mms VITS": "facebook/mms-tts-amh",
         | 
| 378 | 
            +
                "cy-facebook-mms VITS": "facebook/mms-tts-cym",
         | 
| 379 | 
            +
                # "et-facebook-mms VITS": "facebook/mms-tts-est",
         | 
| 380 | 
            +
                # "ht-facebook-mms VITS": "facebook/mms-tts-hrv",
         | 
| 381 | 
            +
                "is-facebook-mms VITS": "facebook/mms-tts-isl",
         | 
| 382 | 
            +
                "km-facebook-mms VITS": "facebook/mms-tts-khm",
         | 
| 383 | 
            +
                "km_Northern-facebook-mms VITS": "facebook/mms-tts-kxm",
         | 
| 384 | 
            +
                # "sk-facebook-mms VITS": "facebook/mms-tts-slk",
         | 
| 385 | 
            +
                "sq_Northern-facebook-mms VITS": "facebook/mms-tts-sqi",
         | 
| 386 | 
            +
                "az_South-facebook-mms VITS": "facebook/mms-tts-azb",
         | 
| 387 | 
            +
                "az_North_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-azj-script_cyrillic",
         | 
| 388 | 
            +
                "az_North_script_latin-facebook-mms VITS": "facebook/mms-tts-azj-script_latin",
         | 
| 389 | 
            +
                "bg-facebook-mms VITS": "facebook/mms-tts-bul",
         | 
| 390 | 
            +
                # "gl-facebook-mms VITS": "facebook/mms-tts-glg",
         | 
| 391 | 
            +
                "gu-facebook-mms VITS": "facebook/mms-tts-guj",
         | 
| 392 | 
            +
                "kk-facebook-mms VITS": "facebook/mms-tts-kaz",
         | 
| 393 | 
            +
                "kn-facebook-mms VITS": "facebook/mms-tts-kan",
         | 
| 394 | 
            +
                # "lt-facebook-mms VITS": "facebook/mms-tts-lit",
         | 
| 395 | 
            +
                "lv-facebook-mms VITS": "facebook/mms-tts-lav",
         | 
| 396 | 
            +
                # "mk-facebook-mms VITS": "facebook/mms-tts-mkd",
         | 
| 397 | 
            +
                "ml-facebook-mms VITS": "facebook/mms-tts-mal",
         | 
| 398 | 
            +
                "ms-facebook-mms VITS": "facebook/mms-tts-zlm",
         | 
| 399 | 
            +
                "ms_Central-facebook-mms VITS": "facebook/mms-tts-pse",
         | 
| 400 | 
            +
                "ms_Manado-facebook-mms VITS": "facebook/mms-tts-xmm",
         | 
| 401 | 
            +
                "ro-facebook-mms VITS": "facebook/mms-tts-ron",
         | 
| 402 | 
            +
                # "si-facebook-mms VITS": "facebook/mms-tts-sin",
         | 
| 403 | 
            +
                "sw-facebook-mms VITS": "facebook/mms-tts-swh",
         | 
| 404 | 
            +
                # "af-facebook-mms VITS": "facebook/mms-tts-afr",
         | 
| 405 | 
            +
                # "bs-facebook-mms VITS": "facebook/mms-tts-bos",
         | 
| 406 | 
            +
                "la-facebook-mms VITS": "facebook/mms-tts-lat",
         | 
| 407 | 
            +
                "my-facebook-mms VITS": "facebook/mms-tts-mya",
         | 
| 408 | 
            +
                # "no_Bokmål-facebook-mms VITS": "thomasht86/mms-tts-nob",  # verify
         | 
| 409 | 
            +
                "as-facebook-mms VITS": "facebook/mms-tts-asm",
         | 
| 410 | 
            +
                "as_Nagamese-facebook-mms VITS": "facebook/mms-tts-nag",
         | 
| 411 | 
            +
                "eu-facebook-mms VITS": "facebook/mms-tts-eus",
         | 
| 412 | 
            +
                "ha-facebook-mms VITS": "facebook/mms-tts-hau",
         | 
| 413 | 
            +
                "ht-facebook-mms VITS": "facebook/mms-tts-hat",
         | 
| 414 | 
            +
                "hy_Western-facebook-mms VITS": "facebook/mms-tts-hyw",
         | 
| 415 | 
            +
                "lo-facebook-mms VITS": "facebook/mms-tts-lao",
         | 
| 416 | 
            +
                "mg-facebook-mms VITS": "facebook/mms-tts-mlg",
         | 
| 417 | 
            +
                "mn-facebook-mms VITS": "facebook/mms-tts-mon",
         | 
| 418 | 
            +
                # "mt-facebook-mms VITS": "facebook/mms-tts-mlt",
         | 
| 419 | 
            +
                "pa_Eastern-facebook-mms VITS": "facebook/mms-tts-pan",
         | 
| 420 | 
            +
                # "pa_Western-facebook-mms VITS": "facebook/mms-tts-pnb",
         | 
| 421 | 
            +
                # "ps-facebook-mms VITS": "facebook/mms-tts-pus",
         | 
| 422 | 
            +
                # "sl-facebook-mms VITS": "facebook/mms-tts-slv",
         | 
| 423 | 
            +
                "sn-facebook-mms VITS": "facebook/mms-tts-sna",
         | 
| 424 | 
            +
                "so-facebook-mms VITS": "facebook/mms-tts-son",
         | 
| 425 | 
            +
                "tg-facebook-mms VITS": "facebook/mms-tts-tgk",
         | 
| 426 | 
            +
                "tk_script_arabic-facebook-mms VITS": "facebook/mms-tts-tuk-script_arabic",
         | 
| 427 | 
            +
                "tk_script_latin-facebook-mms VITS": "facebook/mms-tts-tuk-script_latin",
         | 
| 428 | 
            +
                "tt-facebook-mms VITS": "facebook/mms-tts-tat",
         | 
| 429 | 
            +
                "tt_Crimean-facebook-mms VITS": "facebook/mms-tts-crh",
         | 
| 430 | 
            +
                "uz_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-uzb-script_cyrillic",
         | 
| 431 | 
            +
                "yo-facebook-mms VITS": "facebook/mms-tts-yor",
         | 
| 432 | 
            +
                "ay-facebook-mms VITS": "facebook/mms-tts-ayr",
         | 
| 433 | 
            +
                "bm-facebook-mms VITS": "facebook/mms-tts-bam",
         | 
| 434 | 
            +
                "ceb-facebook-mms VITS": "facebook/mms-tts-ceb",
         | 
| 435 | 
            +
                "ny-facebook-mms VITS": "facebook/mms-tts-nya",
         | 
| 436 | 
            +
                "dv-facebook-mms VITS": "facebook/mms-tts-div",
         | 
| 437 | 
            +
                "doi-facebook-mms VITS": "facebook/mms-tts-dgo",
         | 
| 438 | 
            +
                "ee-facebook-mms VITS": "facebook/mms-tts-ewe",
         | 
| 439 | 
            +
                "gn-facebook-mms VITS": "facebook/mms-tts-grn",
         | 
| 440 | 
            +
                "ilo-facebook-mms VITS": "facebook/mms-tts-ilo",
         | 
| 441 | 
            +
                "rw-facebook-mms VITS": "facebook/mms-tts-kin",
         | 
| 442 | 
            +
                "kri-facebook-mms VITS": "facebook/mms-tts-kri",
         | 
| 443 | 
            +
                "ku_script_arabic-facebook-mms VITS": "facebook/mms-tts-kmr-script_arabic",
         | 
| 444 | 
            +
                "ku_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-kmr-script_cyrillic",
         | 
| 445 | 
            +
                "ku_script_latin-facebook-mms VITS": "facebook/mms-tts-kmr-script_latin",
         | 
| 446 | 
            +
                "ckb-facebook-mms VITS": "razhan/mms-tts-ckb",  # Verify w
         | 
| 447 | 
            +
                "ky-facebook-mms VITS": "facebook/mms-tts-kir",
         | 
| 448 | 
            +
                "lg-facebook-mms VITS": "facebook/mms-tts-lug",
         | 
| 449 | 
            +
                "mai-facebook-mms VITS": "facebook/mms-tts-mai",
         | 
| 450 | 
            +
                "or-facebook-mms VITS": "facebook/mms-tts-ory",
         | 
| 451 | 
            +
                "om-facebook-mms VITS": "facebook/mms-tts-orm",
         | 
| 452 | 
            +
                "qu_Huallaga-facebook-mms VITS": "facebook/mms-tts-qub",
         | 
| 453 | 
            +
                "qu_Lambayeque-facebook-mms VITS": "facebook/mms-tts-quf",
         | 
| 454 | 
            +
                "qu_South_Bolivian-facebook-mms VITS": "facebook/mms-tts-quh",
         | 
| 455 | 
            +
                "qu_North_Bolivian-facebook-mms VITS": "facebook/mms-tts-qul",
         | 
| 456 | 
            +
                "qu_Tena_Lowland-facebook-mms VITS": "facebook/mms-tts-quw",
         | 
| 457 | 
            +
                "qu_Ayacucho-facebook-mms VITS": "facebook/mms-tts-quy",
         | 
| 458 | 
            +
                "qu_Cusco-facebook-mms VITS": "facebook/mms-tts-quz",
         | 
| 459 | 
            +
                "qu_Cajamarca-facebook-mms VITS": "facebook/mms-tts-qvc",
         | 
| 460 | 
            +
                "qu_Eastern_Apurímac-facebook-mms VITS": "facebook/mms-tts-qve",
         | 
| 461 | 
            +
                "qu_Huamalíes_Dos_de_Mayo_Huánuco-facebook-mms VITS": "facebook/mms-tts-qvh",
         | 
| 462 | 
            +
                "qu_Margos_Yarowilca_Lauricocha-facebook-mms VITS": "facebook/mms-tts-qvm",
         | 
| 463 | 
            +
                "qu_North_Junín-facebook-mms VITS": "facebook/mms-tts-qvn",
         | 
| 464 | 
            +
                "qu_Napo-facebook-mms VITS": "facebook/mms-tts-qvo",
         | 
| 465 | 
            +
                "qu_San_Martín-facebook-mms VITS": "facebook/mms-tts-qvs",
         | 
| 466 | 
            +
                "qu_Huaylla_Wanca-facebook-mms VITS": "facebook/mms-tts-qvw",
         | 
| 467 | 
            +
                "qu_Northern_Pastaza-facebook-mms VITS": "facebook/mms-tts-qvz",
         | 
| 468 | 
            +
                "qu_Huaylas_Ancash-facebook-mms VITS": "facebook/mms-tts-qwh",
         | 
| 469 | 
            +
                "qu_Panao-facebook-mms VITS": "facebook/mms-tts-qxh",
         | 
| 470 | 
            +
                "qu_Salasaca_Highland-facebook-mms VITS": "facebook/mms-tts-qxl",
         | 
| 471 | 
            +
                "qu_Northern_Conchucos_Ancash-facebook-mms VITS": "facebook/mms-tts-qxn",
         | 
| 472 | 
            +
                "qu_Southern_Conchucos-facebook-mms VITS": "facebook/mms-tts-qxo",
         | 
| 473 | 
            +
                "qu_Cañar_Highland-facebook-mms VITS": "facebook/mms-tts-qxr",
         | 
| 474 | 
            +
                "sm-facebook-mms VITS": "facebook/mms-tts-smo",
         | 
| 475 | 
            +
                "ti-facebook-mms VITS": "facebook/mms-tts-tir",
         | 
| 476 | 
            +
                "ts-facebook-mms VITS": "facebook/mms-tts-tso",
         | 
| 477 | 
            +
                "ak-facebook-mms VITS": "facebook/mms-tts-aka",
         | 
| 478 | 
            +
                "ug_script_arabic-facebook-mms VITS": "facebook/mms-tts-uig-script_arabic",
         | 
| 479 | 
            +
                "ug_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-uig-script_cyrillic",
         | 
| 480 | 
            +
            }
         | 
| 481 | 
            +
             | 
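Each value above is a Hugging Face model id for an MMS-TTS (VITS) checkpoint.
A minimal sketch with the standard transformers API (the surrounding pipeline
may load these differently):

    import torch
    from transformers import VitsModel, AutoTokenizer

    model = VitsModel.from_pretrained("facebook/mms-tts-eng")
    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

    inputs = tokenizer("Hello from MMS-TTS.", return_tensors="pt")
    with torch.no_grad():
        waveform = model(**inputs).waveform  # mono audio at model.config.sampling_rate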
| 482 | 
            +
            OPENAI_TTS_CODES = [
         | 
| 483 | 
            +
                "af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da",
         | 
| 484 | 
            +
                "nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is",
         | 
| 485 | 
            +
                "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi",
         | 
| 486 | 
            +
                "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw",
         | 
| 487 | 
            +
                "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy", "zh-TW"
         | 
| 488 | 
            +
            ]
         | 
| 489 | 
            +
             | 
| 490 | 
            +
            OPENAI_TTS_MODELS = [
         | 
| 491 | 
            +
                ">alloy OpenAI-TTS",
         | 
| 492 | 
            +
                ">echo OpenAI-TTS",
         | 
| 493 | 
            +
                ">fable OpenAI-TTS",
         | 
| 494 | 
            +
                ">onyx OpenAI-TTS",
         | 
| 495 | 
            +
                ">nova OpenAI-TTS",
         | 
| 496 | 
            +
                ">shimmer OpenAI-TTS",
         | 
| 497 | 
            +
                ">alloy HD OpenAI-TTS",
         | 
| 498 | 
            +
                ">echo HD OpenAI-TTS",
         | 
| 499 | 
            +
                ">fable HD OpenAI-TTS",
         | 
| 500 | 
            +
                ">onyx HD OpenAI-TTS",
         | 
| 501 | 
            +
                ">nova HD OpenAI-TTS",
         | 
| 502 | 
            +
                ">shimmer HD OpenAI-TTS"
         | 
| 503 | 
            +
            ]
         | 
| 504 | 
            +
             | 
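The ">voice [HD] OpenAI-TTS" labels presumably select a voice plus the tts-1 or
tts-1-hd model of the OpenAI speech endpoint. A hedged parsing sketch
(illustrative only, not this repo's exact code; needs OPENAI_API_KEY set):

    from openai import OpenAI

    def synthesize(label: str, text: str, out_path: str = "speech.mp3"):
        # e.g. ">nova HD OpenAI-TTS" -> voice "nova", model "tts-1-hd"
        name = label.lstrip(">").replace(" OpenAI-TTS", "")
        hd = name.endswith(" HD")
        voice = name[:-3] if hd else name
        response = OpenAI().audio.speech.create(
            model="tts-1-hd" if hd else "tts-1",
            voice=voice,
            input=text,
        )
        response.stream_to_file(out_path)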
| 505 | 
            +
            LANGUAGE_CODE_IN_THREE_LETTERS = {
         | 
| 506 | 
            +
                "Automatic detection": "aut",
         | 
| 507 | 
            +
                "ar": "ara",
         | 
| 508 | 
            +
                "zh": "chi",
         | 
| 509 | 
            +
                "cs": "cze",
         | 
| 510 | 
            +
                "da": "dan",
         | 
| 511 | 
            +
                "nl": "dut",
         | 
| 512 | 
            +
                "en": "eng",
         | 
| 513 | 
            +
                "fi": "fin",
         | 
| 514 | 
            +
                "fr": "fre",
         | 
| 515 | 
            +
                "de": "ger",
         | 
| 516 | 
            +
                "el": "gre",
         | 
| 517 | 
            +
                "he": "heb",
         | 
| 518 | 
            +
                "hu": "hun",
         | 
| 519 | 
            +
                "it": "ita",
         | 
| 520 | 
            +
                "ja": "jpn",
         | 
| 521 | 
            +
                "ko": "kor",
         | 
| 522 | 
            +
                "fa": "per",
         | 
| 523 | 
            +
                "pl": "pol",
         | 
| 524 | 
            +
                "pt": "por",
         | 
| 525 | 
            +
                "ru": "rus",
         | 
| 526 | 
            +
                "es": "spa",
         | 
| 527 | 
            +
                "tr": "tur",
         | 
| 528 | 
            +
                "uk": "ukr",
         | 
| 529 | 
            +
                "ur": "urd",
         | 
| 530 | 
            +
                "vi": "vie",
         | 
| 531 | 
            +
                "hi": "hin",
         | 
| 532 | 
            +
                "id": "ind",
         | 
| 533 | 
            +
                "bn": "ben",
         | 
| 534 | 
            +
                "te": "tel",
         | 
| 535 | 
            +
                "mr": "mar",
         | 
| 536 | 
            +
                "ta": "tam",
         | 
| 537 | 
            +
                "jw": "jav",
         | 
| 538 | 
            +
                "ca": "cat",
         | 
| 539 | 
            +
                "ne": "nep",
         | 
| 540 | 
            +
                "th": "tha",
         | 
| 541 | 
            +
                "sv": "swe",
         | 
| 542 | 
            +
                "am": "amh",
         | 
| 543 | 
            +
                "cy": "cym",
         | 
| 544 | 
            +
                "et": "est",
         | 
| 545 | 
            +
                "hr": "hrv",
         | 
| 546 | 
            +
                "is": "isl",
         | 
| 547 | 
            +
                "km": "khm",
         | 
| 548 | 
            +
                "sk": "slk",
         | 
| 549 | 
            +
                "sq": "sqi",
         | 
| 550 | 
            +
                "sr": "srp",
         | 
| 551 | 
            +
            }
         | 
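These are ISO 639-2/B-style bibliographic codes ("ger", "fre", ...), useful for
tagging output audio or subtitle tracks. A lookup sketch; falling back to "aut"
(the "Automatic detection" entry) is an assumption, not necessarily this
module's behavior:

    def to_three_letter(code: str) -> str:
        return LANGUAGE_CODE_IN_THREE_LETTERS.get(code, "aut")

    assert to_three_letter("de") == "ger"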
    	
        soni_translate/languages_gui.py
    ADDED
    
    | The diff for this file is too large to render. See raw diff. |
    	
        soni_translate/logging_setup.py
    ADDED
    
    | @@ -0,0 +1,68 @@ | |
| 1 | 
            +
            import logging
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import warnings
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            def configure_logging_libs(debug=False):
         | 
| 8 | 
            +
                warnings.filterwarnings(
         | 
| 9 | 
            +
                  action="ignore", category=UserWarning, module="pyannote"
         | 
| 10 | 
            +
                )
         | 
| 11 | 
            +
                modules = [
         | 
| 12 | 
            +
                  "numba", "httpx", "markdown_it", "speechbrain", "fairseq", "pyannote",
         | 
| 13 | 
            +
                  "faiss",
         | 
| 14 | 
            +
                  "pytorch_lightning.utilities.migration.utils",
         | 
| 15 | 
            +
                  "pytorch_lightning.utilities.migration",
         | 
| 16 | 
            +
                  "pytorch_lightning",
         | 
| 17 | 
            +
                  "lightning",
         | 
| 18 | 
            +
                  "lightning.pytorch.utilities.migration.utils",
         | 
| 19 | 
            +
                ]
         | 
| 20 | 
            +
                try:
         | 
| 21 | 
            +
                    for module in modules:
         | 
| 22 | 
            +
                        logging.getLogger(module).setLevel(logging.WARNING)
         | 
| 23 | 
            +
                    os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3" if not debug else "1"
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                    # fix verbose pyannote audio
         | 
| 26 | 
            +
                    def fix_verbose_pyannote(*args, what=""):
         | 
| 27 | 
            +
                        pass
         | 
| 28 | 
            +
                    import pyannote.audio.core.model # noqa
         | 
| 29 | 
            +
                    pyannote.audio.core.model.check_version = fix_verbose_pyannote
         | 
| 30 | 
            +
                except Exception as error:
         | 
| 31 | 
            +
                    logger.error(str(error))
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            def setup_logger(name_log):
         | 
| 35 | 
            +
                logger = logging.getLogger(name_log)
         | 
| 36 | 
            +
                logger.setLevel(logging.INFO)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
         | 
| 39 | 
            +
                _default_handler.flush = sys.stderr.flush
         | 
| 40 | 
            +
                logger.addHandler(_default_handler)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                logger.propagate = False
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                handlers = logger.handlers
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                for handler in handlers:
         | 
| 47 | 
            +
                    formatter = logging.Formatter("[%(levelname)s] >> %(message)s")
         | 
| 48 | 
            +
                    handler.setFormatter(formatter)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                # logger.handlers
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                return logger
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            logger = setup_logger("sonitranslate")
         | 
| 56 | 
            +
            logger.setLevel(logging.INFO)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def set_logging_level(verbosity_level):
         | 
| 60 | 
            +
                logging_level_mapping = {
         | 
| 61 | 
            +
                    "debug": logging.DEBUG,
         | 
| 62 | 
            +
                    "info": logging.INFO,
         | 
| 63 | 
            +
                    "warning": logging.WARNING,
         | 
| 64 | 
            +
                    "error": logging.ERROR,
         | 
| 65 | 
            +
                    "critical": logging.CRITICAL,
         | 
| 66 | 
            +
                }
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                logger.setLevel(logging_level_mapping.get(verbosity_level, logging.INFO))
         | 
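Typical use of this module from the rest of the package (a minimal sketch):

    from soni_translate.logging_setup import (
        logger, configure_logging_libs, set_logging_level,
    )

    configure_logging_libs(debug=False)  # quiet noisy third-party loggers
    set_logging_level("debug")           # unknown levels fall back to INFO
    logger.debug("verbose diagnostics enabled")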
    	
        soni_translate/mdx_net.py
    ADDED
    
    | @@ -0,0 +1,582 @@ | |
| 1 | 
            +
            import gc
         | 
| 2 | 
            +
            import hashlib
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import queue
         | 
| 5 | 
            +
            import threading
         | 
| 6 | 
            +
            import json
         | 
| 7 | 
            +
            import shlex
         | 
| 8 | 
            +
            import sys
         | 
| 9 | 
            +
            import subprocess
         | 
| 10 | 
            +
            import librosa
         | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import soundfile as sf
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            from tqdm import tqdm
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            try:
         | 
| 17 | 
            +
                from .utils import (
         | 
| 18 | 
            +
                    remove_directory_contents,
         | 
| 19 | 
            +
                    create_directories,
         | 
| 20 | 
            +
                )
         | 
| 21 | 
            +
            except ImportError:  # noqa
         | 
| 22 | 
            +
                from utils import (
         | 
| 23 | 
            +
                    remove_directory_contents,
         | 
| 24 | 
            +
                    create_directories,
         | 
| 25 | 
            +
                )
         | 
| 26 | 
            +
            from .logging_setup import logger
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            try:
         | 
| 29 | 
            +
                import onnxruntime as ort
         | 
| 30 | 
            +
            except Exception as error:
         | 
| 31 | 
            +
                logger.error(str(error))
         | 
| 32 | 
            +
            # import warnings
         | 
| 33 | 
            +
            # warnings.filterwarnings("ignore")
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            stem_naming = {
         | 
| 36 | 
            +
                "Vocals": "Instrumental",
         | 
| 37 | 
            +
                "Other": "Instruments",
         | 
| 38 | 
            +
                "Instrumental": "Vocals",
         | 
| 39 | 
            +
                "Drums": "Drumless",
         | 
| 40 | 
            +
                "Bass": "Bassless",
         | 
| 41 | 
            +
            }
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class MDXModel:
         | 
| 45 | 
            +
                def __init__(
         | 
| 46 | 
            +
                    self,
         | 
| 47 | 
            +
                    device,
         | 
| 48 | 
            +
                    dim_f,
         | 
| 49 | 
            +
                    dim_t,
         | 
| 50 | 
            +
                    n_fft,
         | 
| 51 | 
            +
                    hop=1024,
         | 
| 52 | 
            +
                    stem_name=None,
         | 
| 53 | 
            +
                    compensation=1.000,
         | 
| 54 | 
            +
                ):
         | 
| 55 | 
            +
                    self.dim_f = dim_f
         | 
| 56 | 
            +
                    self.dim_t = dim_t
         | 
| 57 | 
            +
                    self.dim_c = 4
         | 
| 58 | 
            +
                    self.n_fft = n_fft
         | 
| 59 | 
            +
                    self.hop = hop
         | 
| 60 | 
            +
                    self.stem_name = stem_name
         | 
| 61 | 
            +
                    self.compensation = compensation
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    self.n_bins = self.n_fft // 2 + 1
         | 
| 64 | 
            +
                    self.chunk_size = hop * (self.dim_t - 1)
         | 
| 65 | 
            +
                    self.window = torch.hann_window(
         | 
| 66 | 
            +
                        window_length=self.n_fft, periodic=True
         | 
| 67 | 
            +
                    ).to(device)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    out_c = self.dim_c
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    self.freq_pad = torch.zeros(
         | 
| 72 | 
            +
                        [1, out_c, self.n_bins - self.dim_f, self.dim_t]
         | 
| 73 | 
            +
                    ).to(device)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def stft(self, x):
         | 
| 76 | 
            +
                    x = x.reshape([-1, self.chunk_size])
         | 
| 77 | 
            +
                    x = torch.stft(
         | 
| 78 | 
            +
                        x,
         | 
| 79 | 
            +
                        n_fft=self.n_fft,
         | 
| 80 | 
            +
                        hop_length=self.hop,
         | 
| 81 | 
            +
                        window=self.window,
         | 
| 82 | 
            +
                        center=True,
         | 
| 83 | 
            +
                        return_complex=True,
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    x = torch.view_as_real(x)
         | 
| 86 | 
            +
                    x = x.permute([0, 3, 1, 2])
         | 
| 87 | 
            +
                    x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
         | 
| 88 | 
            +
                        [-1, 4, self.n_bins, self.dim_t]
         | 
| 89 | 
            +
                    )
         | 
| 90 | 
            +
                    return x[:, :, : self.dim_f]
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                def istft(self, x, freq_pad=None):
         | 
| 93 | 
            +
                    freq_pad = (
         | 
| 94 | 
            +
                        self.freq_pad.repeat([x.shape[0], 1, 1, 1])
         | 
| 95 | 
            +
                        if freq_pad is None
         | 
| 96 | 
            +
                        else freq_pad
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
                    x = torch.cat([x, freq_pad], -2)
         | 
| 99 | 
            +
                    # c = 4*2 if self.target_name=='*' else 2
         | 
| 100 | 
            +
                    x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
         | 
| 101 | 
            +
                        [-1, 2, self.n_bins, self.dim_t]
         | 
| 102 | 
            +
                    )
         | 
| 103 | 
            +
                    x = x.permute([0, 2, 3, 1])
         | 
| 104 | 
            +
                    x = x.contiguous()
         | 
| 105 | 
            +
                    x = torch.view_as_complex(x)
         | 
| 106 | 
            +
                    x = torch.istft(
         | 
| 107 | 
            +
                        x,
         | 
| 108 | 
            +
                        n_fft=self.n_fft,
         | 
| 109 | 
            +
                        hop_length=self.hop,
         | 
| 110 | 
            +
                        window=self.window,
         | 
| 111 | 
            +
                        center=True,
         | 
| 112 | 
            +
                    )
         | 
| 113 | 
            +
                    return x.reshape([-1, 2, self.chunk_size])
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
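stft/istft pack the stereo real and imaginary parts into 4 channels, so
spectrogram batches are [N, 4, dim_f, dim_t] and one chunk holds
hop * (dim_t - 1) samples. A shape check with typical MDX-Net settings
(dim_f=3072, dim_t=256, n_fft=6144 are illustrative values, normally read
from mdx_models/data.json):

    import torch

    params = MDXModel(torch.device("cpu"), dim_f=3072, dim_t=256, n_fft=6144)
    print(params.chunk_size)                  # 1024 * 255 = 261120 samples

    x = torch.zeros(1, 2, params.chunk_size)  # one stereo chunk
    spec = params.stft(x)
    print(spec.shape)                         # torch.Size([1, 4, 3072, 256])
    print(params.istft(spec).shape)           # torch.Size([1, 2, 261120])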
| 116 | 
            +
            class MDX:
         | 
| 117 | 
            +
                DEFAULT_SR = 44100
         | 
| 118 | 
            +
                # Unit: seconds
         | 
| 119 | 
            +
                DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
         | 
| 120 | 
            +
                DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def __init__(
         | 
| 123 | 
            +
                    self, model_path: str, params: MDXModel, processor=0
         | 
| 124 | 
            +
                ):
         | 
| 125 | 
            +
                    # Set the device and the provider (CPU or CUDA)
         | 
| 126 | 
            +
                    self.device = (
         | 
| 127 | 
            +
                        torch.device(f"cuda:{processor}")
         | 
| 128 | 
            +
                        if processor >= 0
         | 
| 129 | 
            +
                        else torch.device("cpu")
         | 
| 130 | 
            +
                    )
         | 
| 131 | 
            +
                    self.provider = (
         | 
| 132 | 
            +
                        ["CUDAExecutionProvider"]
         | 
| 133 | 
            +
                        if processor >= 0
         | 
| 134 | 
            +
                        else ["CPUExecutionProvider"]
         | 
| 135 | 
            +
                    )
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    self.model = params
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    # Load the ONNX model using ONNX Runtime
         | 
| 140 | 
            +
                    self.ort = ort.InferenceSession(model_path, providers=self.provider)
         | 
| 141 | 
            +
                    # Preload the model for faster performance
         | 
| 142 | 
            +
                    self.ort.run(
         | 
| 143 | 
            +
                        None,
         | 
| 144 | 
            +
                        {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()},
         | 
| 145 | 
            +
                    )
         | 
| 146 | 
            +
                    self.process = lambda spec: self.ort.run(
         | 
| 147 | 
            +
                        None, {"input": spec.cpu().numpy()}
         | 
| 148 | 
            +
                    )[0]
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    self.prog = None
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                @staticmethod
         | 
| 153 | 
            +
                def get_hash(model_path):
         | 
| 154 | 
            +
                    try:
         | 
| 155 | 
            +
                        with open(model_path, "rb") as f:
         | 
| 156 | 
            +
                            f.seek(-10000 * 1024, 2)
         | 
| 157 | 
            +
                            model_hash = hashlib.md5(f.read()).hexdigest()
         | 
| 158 | 
            +
                except Exception:  # noqa
         | 
| 159 | 
            +
                        model_hash = hashlib.md5(open(model_path, "rb").read()).hexdigest()
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    return model_hash
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                @staticmethod
         | 
| 164 | 
            +
                def segment(
         | 
| 165 | 
            +
                    wave,
         | 
| 166 | 
            +
                    combine=True,
         | 
| 167 | 
            +
                    chunk_size=DEFAULT_CHUNK_SIZE,
         | 
| 168 | 
            +
                    margin_size=DEFAULT_MARGIN_SIZE,
         | 
| 169 | 
            +
                ):
         | 
| 170 | 
            +
                    """
         | 
| 171 | 
            +
                    Segment or join segmented wave array
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    Args:
         | 
| 174 | 
            +
                        wave: (np.array) Wave array to be segmented or joined
         | 
| 175 | 
            +
                        combine: (bool) If True, combines segmented wave array.
         | 
| 176 | 
            +
                            If False, segments wave array.
         | 
| 177 | 
            +
                        chunk_size: (int) Size of each segment (in samples)
         | 
| 178 | 
            +
                        margin_size: (int) Size of margin between segments (in samples)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                    Returns:
         | 
| 181 | 
            +
                        numpy array: Segmented or joined wave array
         | 
| 182 | 
            +
                    """
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    if combine:
         | 
| 185 | 
            +
                        # Initializing as None instead of [] for later numpy array concatenation
         | 
| 186 | 
            +
                        processed_wave = None
         | 
| 187 | 
            +
                        for segment_count, segment in enumerate(wave):
         | 
| 188 | 
            +
                            start = 0 if segment_count == 0 else margin_size
         | 
| 189 | 
            +
                            end = None if segment_count == len(wave) - 1 else -margin_size
         | 
| 190 | 
            +
                            if margin_size == 0:
         | 
| 191 | 
            +
                                end = None
         | 
| 192 | 
            +
                            if processed_wave is None:  # Create array for first segment
         | 
| 193 | 
            +
                                processed_wave = segment[:, start:end]
         | 
| 194 | 
            +
                            else:  # Concatenate to existing array for subsequent segments
         | 
| 195 | 
            +
                                processed_wave = np.concatenate(
         | 
| 196 | 
            +
                                    (processed_wave, segment[:, start:end]), axis=-1
         | 
| 197 | 
            +
                                )
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    else:
         | 
| 200 | 
            +
                        processed_wave = []
         | 
| 201 | 
            +
                        sample_count = wave.shape[-1]
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                        if chunk_size <= 0 or chunk_size > sample_count:
         | 
| 204 | 
            +
                            chunk_size = sample_count
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                        if margin_size > chunk_size:
         | 
| 207 | 
            +
                            margin_size = chunk_size
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                        for segment_count, skip in enumerate(
         | 
| 210 | 
            +
                            range(0, sample_count, chunk_size)
         | 
| 211 | 
            +
                        ):
         | 
| 212 | 
            +
                            margin = 0 if segment_count == 0 else margin_size
         | 
| 213 | 
            +
                            end = min(skip + chunk_size + margin_size, sample_count)
         | 
| 214 | 
            +
                            start = skip - margin
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                            cut = wave[:, start:end].copy()
         | 
| 217 | 
            +
                            processed_wave.append(cut)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                            if end == sample_count:
         | 
| 220 | 
            +
                                break
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    return processed_wave
         | 
| 223 | 
            +
             | 
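With matching chunk and margin sizes, segment() round-trips: splitting and then
re-joining reproduces the input exactly. A quick standalone self-check:

    import numpy as np

    wave = np.random.randn(2, 5 * 44100).astype(np.float32)  # 5 s stereo
    parts = MDX.segment(wave, combine=False, chunk_size=44100, margin_size=4410)
    rebuilt = MDX.segment(parts, combine=True, chunk_size=44100, margin_size=4410)
    assert rebuilt.shape == wave.shape and np.allclose(rebuilt, wave)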
| 224 | 
            +
                def pad_wave(self, wave):
         | 
| 225 | 
            +
                    """
         | 
| 226 | 
            +
                    Pad the wave array to match the required chunk size
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    Args:
         | 
| 229 | 
            +
                        wave: (np.array) Wave array to be padded
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                    Returns:
         | 
| 232 | 
            +
                        tuple: (padded_wave, pad, trim)
         | 
| 233 | 
            +
                            - padded_wave: Padded wave array
         | 
| 234 | 
            +
                            - pad: Number of samples that were padded
         | 
| 235 | 
            +
                            - trim: Number of samples that were trimmed
         | 
| 236 | 
            +
                    """
         | 
| 237 | 
            +
                    n_sample = wave.shape[1]
         | 
| 238 | 
            +
                    trim = self.model.n_fft // 2
         | 
| 239 | 
            +
                    gen_size = self.model.chunk_size - 2 * trim
         | 
| 240 | 
            +
                    pad = gen_size - n_sample % gen_size
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                    # Padded wave
         | 
| 243 | 
            +
                    wave_p = np.concatenate(
         | 
| 244 | 
            +
                        (
         | 
| 245 | 
            +
                            np.zeros((2, trim)),
         | 
| 246 | 
            +
                            wave,
         | 
| 247 | 
            +
                            np.zeros((2, pad)),
         | 
| 248 | 
            +
                            np.zeros((2, trim)),
         | 
| 249 | 
            +
                        ),
         | 
| 250 | 
            +
                        1,
         | 
| 251 | 
            +
                    )
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                    mix_waves = []
         | 
| 254 | 
            +
                    for i in range(0, n_sample + pad, gen_size):
         | 
| 255 | 
            +
                        waves = np.array(wave_p[:, i:i + self.model.chunk_size])
         | 
| 256 | 
            +
                        mix_waves.append(waves)
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(
         | 
| 259 | 
            +
                        self.device
         | 
| 260 | 
            +
                    )
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    return mix_waves, pad, trim
         | 
| 263 | 
            +
             | 
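pad_wave overlaps neighbouring chunks by trim = n_fft // 2 samples on each
side, so each chunk contributes gen_size = chunk_size - 2 * trim fresh samples
and pad fills out the last one. Worked numbers for the illustrative settings
above:

    n_fft, hop, dim_t = 6144, 1024, 256
    chunk_size = hop * (dim_t - 1)           # 261120
    trim = n_fft // 2                        # 3072
    gen_size = chunk_size - 2 * trim         # 254976 usable samples per chunk
    n_sample = 10 * 44100                    # a 10-second input
    pad = gen_size - n_sample % gen_size     # 68952
    assert (n_sample + pad) % gen_size == 0  # -> 2 chunks of 261120 samples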
| 264 | 
            +
                def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
         | 
| 265 | 
            +
                    """
         | 
| 266 | 
            +
                    Process each wave segment in a multi-threaded environment
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    Args:
         | 
| 269 | 
            +
                        mix_waves: (torch.Tensor) Wave segments to be processed
         | 
| 270 | 
            +
                        trim: (int) Number of samples trimmed during padding
         | 
| 271 | 
            +
                        pad: (int) Number of samples padded during padding
         | 
| 272 | 
            +
                        q: (queue.Queue) Queue to hold the processed wave segments
         | 
| 273 | 
            +
                        _id: (int) Identifier of the processed wave segment
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    Returns:
         | 
| 276 | 
            +
                        numpy array: Processed wave segment
         | 
| 277 | 
            +
                    """
         | 
| 278 | 
            +
                    mix_waves = mix_waves.split(1)
         | 
| 279 | 
            +
                    with torch.no_grad():
         | 
| 280 | 
            +
                        pw = []
         | 
| 281 | 
            +
                        for mix_wave in mix_waves:
         | 
| 282 | 
            +
                            self.prog.update()
         | 
| 283 | 
            +
                            spec = self.model.stft(mix_wave)
         | 
| 284 | 
            +
                            processed_spec = torch.tensor(self.process(spec))
         | 
| 285 | 
            +
                            processed_wav = self.model.istft(
         | 
| 286 | 
            +
                                processed_spec.to(self.device)
         | 
| 287 | 
            +
                            )
         | 
| 288 | 
            +
                            processed_wav = (
         | 
| 289 | 
            +
                                processed_wav[:, :, trim:-trim]
         | 
| 290 | 
            +
                                .transpose(0, 1)
         | 
| 291 | 
            +
                                .reshape(2, -1)
         | 
| 292 | 
            +
                                .cpu()
         | 
| 293 | 
            +
                                .numpy()
         | 
| 294 | 
            +
                            )
         | 
| 295 | 
            +
                            pw.append(processed_wav)
         | 
| 296 | 
            +
                    processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
         | 
| 297 | 
            +
                    q.put({_id: processed_signal})
         | 
| 298 | 
            +
                    return processed_signal
         | 
| 299 | 
            +
             | 
| 300 | 
            +
            def process_wave(self, wave: np.ndarray, mt_threads=1):
         | 
| 301 | 
            +
                    """
         | 
| 302 | 
            +
                    Process the wave array in a multi-threaded environment
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    Args:
         | 
| 305 | 
            +
                        wave: (np.array) Wave array to be processed
         | 
| 306 | 
            +
                        mt_threads: (int) Number of threads to be used for processing
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    Returns:
         | 
| 309 | 
            +
                        numpy array: Processed wave array
         | 
| 310 | 
            +
                    """
         | 
| 311 | 
            +
                    self.prog = tqdm(total=0)
         | 
| 312 | 
            +
                    chunk = wave.shape[-1] // mt_threads
         | 
| 313 | 
            +
                    waves = self.segment(wave, False, chunk)
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    # Create a queue to hold the processed wave segments
         | 
| 316 | 
            +
                    q = queue.Queue()
         | 
| 317 | 
            +
                    threads = []
         | 
| 318 | 
            +
                    for c, batch in enumerate(waves):
         | 
| 319 | 
            +
                        mix_waves, pad, trim = self.pad_wave(batch)
         | 
| 320 | 
            +
                        self.prog.total = len(mix_waves) * mt_threads
         | 
| 321 | 
            +
                        thread = threading.Thread(
         | 
| 322 | 
            +
                            target=self._process_wave, args=(mix_waves, trim, pad, q, c)
         | 
| 323 | 
            +
                        )
         | 
| 324 | 
            +
                        thread.start()
         | 
| 325 | 
            +
                        threads.append(thread)
         | 
| 326 | 
            +
                    for thread in threads:
         | 
| 327 | 
            +
                        thread.join()
         | 
| 328 | 
            +
                    self.prog.close()
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    processed_batches = []
         | 
| 331 | 
            +
                    while not q.empty():
         | 
| 332 | 
            +
                        processed_batches.append(q.get())
         | 
| 333 | 
            +
                    processed_batches = [
         | 
| 334 | 
            +
                        list(wave.values())[0]
         | 
| 335 | 
            +
                        for wave in sorted(
         | 
| 336 | 
            +
                            processed_batches, key=lambda d: list(d.keys())[0]
         | 
| 337 | 
            +
                        )
         | 
| 338 | 
            +
                    ]
         | 
| 339 | 
            +
                    assert len(processed_batches) == len(
         | 
| 340 | 
            +
                        waves
         | 
| 341 | 
            +
                    ), "Incomplete processed batches, please reduce batch size!"
         | 
| 342 | 
            +
                    return self.segment(processed_batches, True, chunk)
         | 
| 343 | 
            +
             | 
| 344 | 
            +
             | 
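In run_mdx below, denoise=True runs the separator on both signal polarities and
averages, i.e. out = 0.5 * (f(x) - f(-x)); for a sign-equivariant model this
keeps the stem while cancelling artifacts that do not flip sign with the input.
The two-pass lines below are equivalent to:

    wave_processed = 0.5 * (
        mdx_sess.process_wave(wave, m_threads)
        - mdx_sess.process_wave(-wave, m_threads)
    )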
| 345 | 
            +
def run_mdx(
    model_params,
    output_dir,
    model_path,
    filename,
    exclude_main=False,
    exclude_inversion=False,
    suffix=None,
    invert_suffix=None,
    denoise=False,
    keep_orig=True,
    m_threads=2,
    device_base="cuda",
):
    if device_base == "cuda":
        device = torch.device("cuda:0")
        processor_num = 0
        # Use a single worker thread on cards with less than 8 GB of VRAM.
        device_properties = torch.cuda.get_device_properties(device)
        vram_gb = device_properties.total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else 2
    else:
        device = torch.device("cpu")
        processor_num = -1
        m_threads = 1

    model_hash = MDX.get_hash(model_path)
    mp = model_params.get(model_hash)
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )

    mdx_sess = MDX(model_path, model, processor=processor_num)
    wave, sr = librosa.load(filename, mono=False, sr=44100)
    # normalizing input wave gives better output
    peak = max(np.max(wave), abs(np.min(wave)))
    wave /= peak
    if denoise:
        # Process the waveform twice, once inverted, and average the two
        # passes to cancel the model's additive noise.
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (
            mdx_sess.process_wave(wave, m_threads)
        )
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # return to previous peak
    wave_processed *= peak
    stem_name = model.stem_name if suffix is None else suffix

    main_filepath = None
    if not exclude_main:
        main_filepath = os.path.join(
            output_dir,
            f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
        )
        sf.write(main_filepath, wave_processed.T, sr)

    invert_filepath = None
    if not exclude_inversion:
        diff_stem_name = (
            stem_naming.get(stem_name)
            if invert_suffix is None
            else invert_suffix
        )
        stem_name = (
            f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
        )
        invert_filepath = os.path.join(
            output_dir,
            f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
        )
        # The inverted stem is the original mix minus the (compensated)
        # primary stem.
        sf.write(
            invert_filepath,
            (-wave_processed.T * model.compensation) + wave.T,
            sr,
        )

    if not keep_orig:
        os.remove(filename)

    del mdx_sess, wave_processed, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath

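For orientation, a minimal call sketch for run_mdx — assuming the ONNX model has already been downloaded into mdx_models/ and its parameters are present in data.json; aud_test.mp3 is a placeholder input:

import json
import os

# Hypothetical usage sketch; not part of the module itself.
with open(os.path.join("mdx_models", "data.json")) as f:
    params = json.load(f)

main_stem, inverted_stem = run_mdx(
    params,
    "clean_song_output",
    os.path.join("mdx_models", "UVR-MDX-NET-Voc_FT.onnx"),
    "aud_test.mp3",  # placeholder input file
    denoise=True,
)
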
MDX_DOWNLOAD_LINK = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
UVR_MODELS = [
    "UVR-MDX-NET-Voc_FT.onnx",  # vocal isolation
    "UVR_MDXNET_KARA_2.onnx",  # main/backup vocal split
    "Reverb_HQ_By_FoxJoy.onnx",  # de-reverberation
    "UVR-MDX-NET-Inst_HQ_4.onnx",  # voiceless (instrumental) track
]
BASE_DIR = "."  # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
mdxnet_models_dir = os.path.join(BASE_DIR, "mdx_models")
output_dir = os.path.join(BASE_DIR, "clean_song_output")

def convert_to_stereo_and_wav(audio_path):
    wave, sr = librosa.load(audio_path, mono=False, sr=44100)

    # Convert if the source is mono or not already a wav file.
    if not isinstance(wave[0], np.ndarray) or not audio_path.lower().endswith(".wav"):  # noqa
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        stereo_path = os.path.join(output_dir, f"{base_name}_stereo.wav")

        command = shlex.split(
            f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 2 -f wav "{stereo_path}"'
        )
        sub_params = {
            "stdout": subprocess.PIPE,
            "stderr": subprocess.PIPE,
            "creationflags": subprocess.CREATE_NO_WINDOW
            if sys.platform == "win32"
            else 0,
        }
        process_wav = subprocess.Popen(command, **sub_params)
        output, errors = process_wav.communicate()
        if process_wav.returncode != 0 or not os.path.exists(stereo_path):
            raise Exception("Error processing audio to stereo wav")

        return stereo_path
    else:
        return audio_path

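A quick illustration (hypothetical input file): a mono or non-WAV source comes back as a new stereo 44.1 kHz WAV inside clean_song_output/, while a stereo WAV is returned unchanged.

stereo_ready = convert_to_stereo_and_wav("aud_test.mp3")
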
def process_uvr_task(
    orig_song_path: str = "aud_test.mp3",
    main_vocals: bool = False,
    dereverb: bool = True,
    song_id: str = "mdx",  # folder output name
    only_voiceless: bool = False,
    remove_files_output_dir: bool = False,
):
    # The SONITR_DEVICE environment variable can force CPU processing.
    if os.environ.get("SONITR_DEVICE") == "cpu":
        device_base = "cpu"
    else:
        device_base = "cuda" if torch.cuda.is_available() else "cpu"

    if remove_files_output_dir:
        remove_directory_contents(output_dir)

    with open(os.path.join(mdxnet_models_dir, "data.json")) as infile:
        mdx_model_params = json.load(infile)

    song_output_dir = os.path.join(output_dir, song_id)
    create_directories(song_output_dir)
    orig_song_path = convert_to_stereo_and_wav(orig_song_path)

    logger.debug(f"onnxruntime device >> {ort.get_device()}")

    if only_voiceless:
        logger.info("Voiceless Track Separation...")
        return run_mdx(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Inst_HQ_4.onnx"),
            orig_song_path,
            suffix="Voiceless",
            denoise=False,
            keep_orig=True,
            exclude_inversion=True,
            device_base=device_base,
        )

    logger.info("Vocal Track Isolation and Voiceless Track Separation...")
    vocals_path, instrumentals_path = run_mdx(
        mdx_model_params,
        song_output_dir,
        os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Voc_FT.onnx"),
        orig_song_path,
        denoise=True,
        keep_orig=True,
        device_base=device_base,
    )

    if main_vocals:
        logger.info("Main Voice Separation from Supporting Vocals...")
        backup_vocals_path, main_vocals_path = run_mdx(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "UVR_MDXNET_KARA_2.onnx"),
            vocals_path,
            suffix="Backup",
            invert_suffix="Main",
            denoise=True,
            device_base=device_base,
        )
    else:
        backup_vocals_path, main_vocals_path = None, vocals_path

    if dereverb:
        logger.info("Vocal Clarity Enhancement through De-Reverberation...")
        _, vocals_dereverb_path = run_mdx(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "Reverb_HQ_By_FoxJoy.onnx"),
            main_vocals_path,
            invert_suffix="DeReverb",
            exclude_main=True,
            denoise=True,
            device_base=device_base,
        )
    else:
        vocals_dereverb_path = main_vocals_path

    return (
        vocals_path,
        instrumentals_path,
        backup_vocals_path,
        main_vocals_path,
        vocals_dereverb_path,
    )

if __name__ == "__main__":
    from utils import download_manager

    for id_model in UVR_MODELS:
        download_manager(
            os.path.join(MDX_DOWNLOAD_LINK, id_model), mdxnet_models_dir
        )
    (
        vocals_path_,
        instrumentals_path_,
        backup_vocals_path_,
        main_vocals_path_,
        vocals_dereverb_path_,
    ) = process_uvr_task(
        orig_song_path="aud.mp3",
        main_vocals=True,
        dereverb=True,
        song_id="mdx",
        remove_files_output_dir=True,
    )
    	
        soni_translate/postprocessor.py
    ADDED
    @@ -0,0 +1,229 @@
from .utils import remove_files, run_command
from .text_multiformat_processor import get_subtitle
from .logging_setup import logger
import unicodedata
import shutil
import copy
import os
import re

OUTPUT_TYPE_OPTIONS = [
    "video (mp4)",
    "video (mkv)",
    "audio (mp3)",
    "audio (ogg)",
    "audio (wav)",
    "subtitle",
    "subtitle [by speaker]",
    "video [subtitled] (mp4)",
    "video [subtitled] (mkv)",
    "audio [original vocal sound]",
    "audio [original background sound]",
    "audio [original vocal and background sound]",
    "audio [original vocal-dereverb sound]",
    "audio [original vocal-dereverb and background sound]",
    "raw media",
]

DOCS_OUTPUT_TYPE_OPTIONS = [
    "videobook (mp4)",
    "videobook (mkv)",
    "audiobook (wav)",
    "audiobook (mp3)",
    "audiobook (ogg)",
    "book (txt)",
]  # TODO: add DOCX and other formats.

def get_no_ext_filename(file_path):
    file_name_with_extension = os.path.basename(file_path)
    filename_without_extension, _ = os.path.splitext(file_name_with_extension)
    return filename_without_extension


def get_video_info(link):
    aux_name = f"video_url_{link}"
    params_dlp = {"quiet": True, "no_warnings": True, "noplaylist": True}
    try:
        from yt_dlp import YoutubeDL

        with YoutubeDL(params_dlp) as ydl:
            if link.startswith(("www.youtube.com/", "m.youtube.com/")):
                link = "https://" + link
            info_dict = ydl.extract_info(link, download=False, process=False)
            video_id = info_dict.get("id", aux_name)
            video_title = info_dict.get("title", video_id)
            # For playlist links, fetch the title of the single video instead.
            if "youtube.com" in link and "&list=" in link:
                video_title = ydl.extract_info(
                    "https://m.youtube.com/watch?v=" + video_id,
                    download=False,
                    process=False,
                ).get("title", video_title)
    except Exception as error:
        logger.error(str(error))
        video_title, video_id = aux_name, "NO_ID"
    return video_title, video_id


def sanitize_file_name(file_name):
    # Normalize the string to NFKD form to separate combined
    # characters into base characters and diacritics
    normalized_name = unicodedata.normalize("NFKD", file_name)
    # Replace anything other than word characters, whitespace,
    # dots, and hyphens with an underscore
    sanitized_name = re.sub(r"[^\w\s.-]", "_", normalized_name)
    return sanitized_name

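A quick check of the sanitizer with an illustrative name: everything outside word characters, whitespace, dots, and hyphens becomes an underscore.

print(sanitize_file_name("my video: part#1.mp4"))  # -> my video_ part_1.mp4
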
def get_output_file(
        original_file,
        new_file_name,
        soft_subtitles,
        output_directory="",
):
    directory_base = "."  # default directory

    if output_directory and os.path.isdir(output_directory):
        new_file_path = os.path.join(output_directory, new_file_name)
    else:
        new_file_path = os.path.join(directory_base, "outputs", new_file_name)
    remove_files(new_file_path)

    # Pick the ffmpeg command for the requested container; fall back to a
    # plain file copy when no conversion is needed or the conversion fails.
    cm = None
    if soft_subtitles and original_file.endswith(".mp4"):
        if new_file_path.endswith(".mp4"):
            cm = f'ffmpeg -y -i "{original_file}" -i sub_tra.srt -i sub_ori.srt -map 0:v -map 0:a -map 1 -map 2 -c:v copy -c:a copy -c:s mov_text "{new_file_path}"'
        else:
            cm = f'ffmpeg -y -i "{original_file}" -i sub_tra.srt -i sub_ori.srt -map 0:v -map 0:a -map 1 -map 2 -c:v copy -c:a copy -c:s srt -movflags use_metadata_tags -map_metadata 0 "{new_file_path}"'
    elif new_file_path.endswith(".mkv"):
        cm = f'ffmpeg -i "{original_file}" -c:v copy -c:a copy "{new_file_path}"'
    elif new_file_path.endswith(".wav") and not original_file.endswith(".wav"):
        cm = f'ffmpeg -y -i "{original_file}" -acodec pcm_s16le -ar 44100 -ac 2 "{new_file_path}"'
    elif new_file_path.endswith(".ogg"):
        cm = f'ffmpeg -i "{original_file}" -c:a libvorbis "{new_file_path}"'
    elif new_file_path.endswith(".mp3") and not original_file.endswith(".mp3"):
        cm = f'ffmpeg -y -i "{original_file}" -codec:a libmp3lame -qscale:a 2 "{new_file_path}"'

    if cm:
        try:
            run_command(cm)
        except Exception as error:
            logger.error(str(error))
            remove_files(new_file_path)
            shutil.copy2(original_file, new_file_path)
    else:
        shutil.copy2(original_file, new_file_path)

    return os.path.abspath(new_file_path)

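A minimal sketch of a direct call (hypothetical filenames; the soft-subtitle branch expects sub_tra.srt and sub_ori.srt in the working directory):

final_path = get_output_file(
    "video_dub.mp4",     # source produced by the dubbing step
    "my_video__en.mp4",  # sanitized output name
    soft_subtitles=True,
)
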
def media_out(
    media_file,
    lang_code,
    media_out_name="",
    extension="mp4",
    file_obj="video_dub.mp4",
    soft_subtitles=False,
    subtitle_files="disable",
):
    # Derive the base name up front; it is also needed for the original
    # subtitle file below even when media_out_name is given.
    if os.path.exists(media_file):
        base_name = get_no_ext_filename(media_file)
    else:
        base_name, _ = get_video_info(media_file)

    if not media_out_name:
        media_out_name = f"{base_name}__{lang_code}"

    f_name = f"{sanitize_file_name(media_out_name)}.{extension}"

    if subtitle_files != "disable":
        final_media = [get_output_file(file_obj, f_name, soft_subtitles)]
        name_tra = f"{sanitize_file_name(media_out_name)}.{subtitle_files}"
        name_ori = f"{sanitize_file_name(base_name)}.{subtitle_files}"
        tgt_subs = f"sub_tra.{subtitle_files}"
        ori_subs = f"sub_ori.{subtitle_files}"
        final_subtitles = [
            get_output_file(tgt_subs, name_tra, False),
            get_output_file(ori_subs, name_ori, False),
        ]
        return final_media + final_subtitles
    else:
        return get_output_file(file_obj, f_name, soft_subtitles)

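And a hedged usage sketch for media_out (placeholder names): with subtitle_files set to an extension such as "srt", it returns the media file plus the translated and original subtitle files.

outputs = media_out(
    "video.mp4",           # local file or URL, used only to derive the name
    "en",
    extension="mp4",
    file_obj="video_dub.mp4",
    soft_subtitles=True,
    subtitle_files="srt",
)
# -> [dubbed mp4, translated .srt, original .srt]
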
def get_subtitle_speaker(media_file, result, language, extension, base_name):

    segments_base = copy.deepcopy(result)

    # Sub segments by speaker
    segments_by_speaker = {}
    for segment in segments_base["segments"]:
        if segment["speaker"] not in segments_by_speaker.keys():
            segments_by_speaker[segment["speaker"]] = [segment]
        else:
            segments_by_speaker[segment["speaker"]].append(segment)

    if not base_name:
        if os.path.exists(media_file):
            base_name = get_no_ext_filename(media_file)
        else:
            base_name, _ = get_video_info(media_file)

    files_subs = []
    for name_sk, segments in segments_by_speaker.items():

        subtitle_speaker = get_subtitle(
            language,
            {"segments": segments},
            extension,
            filename=name_sk,
        )

        media_out_name = f"{base_name}_{language}_{name_sk}"

        output = media_out(
            media_file,  # no need
            language,
            media_out_name,
            extension,
            file_obj=subtitle_speaker,
        )

        files_subs.append(output)

    return files_subs

def sound_separate(media_file, task_uvr):
    from .mdx_net import process_uvr_task

    outputs = []

    if "vocal" in task_uvr:
        try:
            _, _, _, _, vocal_audio = process_uvr_task(
                orig_song_path=media_file,
                main_vocals=False,
                dereverb="dereverb" in task_uvr,
                remove_files_output_dir=True,
            )
            outputs.append(vocal_audio)
        except Exception as error:
            logger.error(str(error))

    if "background" in task_uvr:
        try:
            background_audio, _ = process_uvr_task(
                orig_song_path=media_file,
                song_id="voiceless",
                only_voiceless=True,
                # keep the vocal output from the previous step, if any
                remove_files_output_dir="vocal" not in task_uvr,
            )
            # copy_files(background_audio, ".")
            outputs.append(background_audio)
        except Exception as error:
            logger.error(str(error))

    if not outputs:
        raise Exception("Error in uvr process")

    return outputs
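A small usage sketch for sound_separate (hypothetical path): task_uvr is matched by substring, so a value like "vocal dereverb background" requests de-reverbed vocals and the background track.

stems = sound_separate("video_dub.mp4", "vocal dereverb background")
# -> [path to de-reverbed vocals, path to background audio]
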
    	
        soni_translate/preprocessor.py
    ADDED
    @@ -0,0 +1,308 @@
from .utils import remove_files
import os, shutil, subprocess, time, shlex, sys  # noqa
from .logging_setup import logger
import json

ERROR_INCORRECT_CODEC_PARAMETERS = [
    "prores",  # mov
    "ffv1",  # mkv
    "msmpeg4v3",  # avi
    "wmv2",  # wmv
    "theora",  # ogv
]  # TODO: fix the final merge for these codecs

TESTED_CODECS = [
    "h264",  # mp4
    "h265",  # mp4
    "vp9",  # webm
    "mpeg4",  # mp4
    "mpeg2video",  # mpg
    "mjpeg",  # avi
]


class OperationFailedError(Exception):
    def __init__(self, message="The operation did not complete successfully."):
        self.message = message
        super().__init__(self.message)

def get_video_codec(video_file):
    command_base = rf'ffprobe -v error -select_streams v:0 -show_entries stream=codec_name -of json "{video_file}"'
    command = shlex.split(command_base)
    try:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            creationflags=subprocess.CREATE_NO_WINDOW if sys.platform == "win32" else 0,
        )
        output, _ = process.communicate()
        codec_info = json.loads(output.decode("utf-8"))
        codec_name = codec_info["streams"][0]["codec_name"]
        return codec_name
    except Exception as error:
        logger.debug(str(error))
        return None


def audio_preprocessor(preview, base_audio, audio_wav, use_cuda=False):
    base_audio = base_audio.strip()
    previous_files_to_remove = [audio_wav]
    remove_files(previous_files_to_remove)

    if preview:
        logger.warning(
            "Creating a preview video of 10 seconds, to disable "
            "this option, go to advanced settings and turn off preview."
        )
        wav_ = f'ffmpeg -y -i "{base_audio}" -ss 00:00:20 -t 00:00:10 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav'
    else:
        wav_ = f'ffmpeg -y -i "{base_audio}" -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav'

    # Run cmd process
    sub_params = {
        "stdout": subprocess.PIPE,
        "stderr": subprocess.PIPE,
        "creationflags": subprocess.CREATE_NO_WINDOW
        if sys.platform == "win32"
        else 0,
    }
    wav_ = shlex.split(wav_)
    result_convert_audio = subprocess.Popen(wav_, **sub_params)
    output, errors = result_convert_audio.communicate()
    time.sleep(1)
    if result_convert_audio.returncode in [1, 2] or not os.path.exists(
        audio_wav
    ):
        raise OperationFailedError(
            f"Error can't create the audio file:\n{errors.decode('utf-8')}"
        )

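As a brief sketch (hypothetical file), get_video_codec feeds the copy-vs-re-encode decision below: a codec found in TESTED_CODECS can be stream-copied, anything else is re-encoded to h264.

codec = get_video_codec("Video.mp4")  # e.g. "h264", or None if probing fails
if codec in TESTED_CODECS:
    print(f"{codec}: safe to stream-copy")
else:
    print(f"{codec}: will be re-encoded with libx264")
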
            def audio_video_preprocessor(
         | 
| 81 | 
            +
                preview, video, OutputFile, audio_wav, use_cuda=False
         | 
| 82 | 
            +
            ):
         | 
| 83 | 
            +
                video = video.strip()
         | 
| 84 | 
            +
                previous_files_to_remove = [OutputFile, "audio.webm", audio_wav]
         | 
| 85 | 
            +
                remove_files(previous_files_to_remove)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                if os.path.exists(video):
         | 
| 88 | 
            +
                    if preview:
         | 
| 89 | 
            +
                        logger.warning(
         | 
| 90 | 
            +
                            "Creating a preview video of 10 seconds, "
         | 
| 91 | 
            +
                            "to disable this option, go to advanced "
         | 
| 92 | 
            +
                            "settings and turn off preview."
         | 
| 93 | 
            +
                        )
         | 
| 94 | 
            +
                        mp4_ = f'ffmpeg -y -i "{video}" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
         | 
| 95 | 
            +
                    else:
         | 
| 96 | 
            +
                        video_codec = get_video_codec(video)
         | 
| 97 | 
            +
                        if not video_codec:
         | 
| 98 | 
            +
                            logger.debug("No video codec found in video")
         | 
| 99 | 
            +
                        else:
         | 
| 100 | 
            +
                            logger.info(f"Video codec: {video_codec}")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                        # Check if the file ends with ".mp4" extension or is valid codec
         | 
| 103 | 
            +
                        if video.endswith(".mp4") or video_codec in TESTED_CODECS:
         | 
| 104 | 
            +
                            destination_path = os.path.join(os.getcwd(), "Video.mp4")
         | 
| 105 | 
            +
                            shutil.copy(video, destination_path)
         | 
| 106 | 
            +
                            time.sleep(0.5)
         | 
| 107 | 
            +
                            if os.path.exists(OutputFile):
         | 
| 108 | 
            +
                                mp4_ = "ffmpeg -h"
         | 
| 109 | 
            +
                            else:
         | 
| 110 | 
            +
                                mp4_ = f'ffmpeg -y -i "{video}" -c copy Video.mp4'
         | 
| 111 | 
            +
                        else:
         | 
| 112 | 
            +
                            logger.warning(
         | 
| 113 | 
            +
                                "File does not have the '.mp4' extension  or a "
         | 
| 114 | 
            +
                                "supported codec. Converting video to mp4 (codec: h264)."
         | 
| 115 | 
            +
                            )
         | 
| 116 | 
            +
                            mp4_ = f'ffmpeg -y -i "{video}" -c:v libx264 -c:a aac -strict experimental Video.mp4'
         | 
| 117 | 
            +
                else:
         | 
| 118 | 
            +
                    if preview:
         | 
| 119 | 
            +
                        logger.warning(
         | 
| 120 | 
            +
                            "Creating a preview from the link, 10 seconds "
         | 
| 121 | 
            +
                            "to disable this option, go to advanced "
         | 
| 122 | 
            +
                            "settings and turn off preview."
         | 
| 123 | 
            +
                        )
         | 
| 124 | 
            +
                        # https://github.com/yt-dlp/yt-dlp/issues/2220
         | 
| 125 | 
            +
                        mp4_ = f'yt-dlp -f "mp4" --downloader ffmpeg --downloader-args "ffmpeg_i: -ss 00:00:20 -t 00:00:10" --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
         | 
| 126 | 
            +
                        wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
         | 
| 127 | 
            +
                    else:
         | 
| 128 | 
            +
                        mp4_ = f'yt-dlp -f "mp4" --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
         | 
| 129 | 
            +
                        wav_ = f"python -m yt_dlp --output {audio_wav} --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --extract-audio --audio-format wav {video}"
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                # Run cmd process
         | 
| 132 | 
            +
                mp4_ = shlex.split(mp4_)
         | 
| 133 | 
            +
                sub_params = {
         | 
| 134 | 
            +
                    "stdout": subprocess.PIPE,
         | 
| 135 | 
            +
                    "stderr": subprocess.PIPE,
         | 
| 136 | 
            +
                    "creationflags": subprocess.CREATE_NO_WINDOW
         | 
| 137 | 
            +
                    if sys.platform == "win32"
         | 
| 138 | 
            +
                    else 0,
         | 
| 139 | 
            +
                }
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                if os.path.exists(video):
         | 
| 142 | 
            +
                    logger.info("Process video...")
         | 
| 143 | 
            +
                    result_convert_video = subprocess.Popen(mp4_, **sub_params)
         | 
| 144 | 
            +
                    # result_convert_video.wait()
         | 
| 145 | 
            +
                    output, errors = result_convert_video.communicate()
         | 
| 146 | 
            +
                    time.sleep(1)
         | 
| 147 | 
            +
                    if result_convert_video.returncode in [1, 2] or not os.path.exists(
         | 
| 148 | 
            +
                        OutputFile
         | 
| 149 | 
            +
                    ):
         | 
| 150 | 
            +
                        raise OperationFailedError(f"Error processing video:\n{errors.decode('utf-8')}")
         | 
| 151 | 
            +
                    logger.info("Process audio...")
         | 
| 152 | 
            +
                    wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
         | 
| 153 | 
            +
                    wav_ = shlex.split(wav_)
         | 
| 154 | 
            +
                    result_convert_audio = subprocess.Popen(wav_, **sub_params)
         | 
| 155 | 
            +
                    output, errors = result_convert_audio.communicate()
         | 
| 156 | 
            +
                    time.sleep(1)
         | 
| 157 | 
            +
                    if result_convert_audio.returncode in [1, 2] or not os.path.exists(
         | 
| 158 | 
            +
                        audio_wav
         | 
| 159 | 
            +
                    ):
         | 
| 160 | 
            +
                        raise OperationFailedError(f"Error can't create the audio file:\n{errors.decode('utf-8')}")
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                else:
         | 
| 163 | 
            +
                    wav_ = shlex.split(wav_)
         | 
| 164 | 
            +
                    if preview:
         | 
| 165 | 
            +
                        result_convert_video = subprocess.Popen(mp4_, **sub_params)
         | 
| 166 | 
            +
                        output, errors = result_convert_video.communicate()
         | 
| 167 | 
            +
                        time.sleep(0.5)
         | 
| 168 | 
            +
                        result_convert_audio = subprocess.Popen(wav_, **sub_params)
         | 
| 169 | 
            +
                        output, errors = result_convert_audio.communicate()
         | 
| 170 | 
            +
                        time.sleep(0.5)
         | 
| 171 | 
            +
                        if result_convert_audio.returncode in [1, 2] or not os.path.exists(
         | 
| 172 | 
            +
                            audio_wav
         | 
| 173 | 
            +
                        ):
         | 
| 174 | 
            +
                            raise OperationFailedError(
         | 
| 175 | 
            +
                                f"Error can't create the preview file:\n{errors.decode('utf-8')}"
         | 
| 176 | 
            +
                            )
         | 
| 177 | 
            +
                    else:
         | 
| 178 | 
            +
                        logger.info("Process audio...")
         | 
| 179 | 
            +
                        result_convert_audio = subprocess.Popen(wav_, **sub_params)
         | 
| 180 | 
            +
                        output, errors = result_convert_audio.communicate()
         | 
| 181 | 
            +
                        time.sleep(1)
         | 
| 182 | 
            +
                        if result_convert_audio.returncode in [1, 2] or not os.path.exists(
         | 
| 183 | 
            +
                            audio_wav
         | 
| 184 | 
            +
                        ):
         | 
| 185 | 
            +
                            raise OperationFailedError(f"Error can't download the audio:\n{errors.decode('utf-8')}")
         | 
| 186 | 
            +
                        logger.info("Process video...")
         | 
| 187 | 
            +
                        result_convert_video = subprocess.Popen(mp4_, **sub_params)
         | 
| 188 | 
            +
                        output, errors = result_convert_video.communicate()
         | 
| 189 | 
            +
                        time.sleep(1)
         | 
| 190 | 
            +
                        if result_convert_video.returncode in [1, 2] or not os.path.exists(
         | 
| 191 | 
            +
                            OutputFile
         | 
| 192 | 
            +
                        ):
         | 
| 193 | 
            +
                            raise OperationFailedError(f"Error can't download the video:\n{errors.decode('utf-8')}")
         | 
| 194 | 
            +
             | 
| 195 | 
            +
             | 
| 196 | 
            +
            def old_audio_video_preprocessor(preview, video, OutputFile, audio_wav):
         | 
| 197 | 
            +
                previous_files_to_remove = [OutputFile, "audio.webm", audio_wav]
         | 
| 198 | 
            +
                remove_files(previous_files_to_remove)
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                if os.path.exists(video):
         | 
| 201 | 
            +
                    if preview:
         | 
| 202 | 
            +
                        logger.warning(
         | 
| 203 | 
            +
                            "Creating a preview video of 10 seconds, "
         | 
| 204 | 
            +
                            "to disable this option, go to advanced "
         | 
| 205 | 
            +
                            "settings and turn off preview."
         | 
| 206 | 
            +
                        )
         | 
| 207 | 
            +
                        command = f'ffmpeg -y -i "{video}" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
         | 
| 208 | 
            +
                        result_convert_video = subprocess.run(
         | 
| 209 | 
            +
                            command, capture_output=True, text=True, shell=True
         | 
| 210 | 
            +
                        )
         | 
| 211 | 
            +
                    else:
         | 
| 212 | 
            +
                        # Check if the file ends with ".mp4" extension
         | 
| 213 | 
            +
                        if video.endswith(".mp4"):
         | 
| 214 | 
            +
                            destination_path = os.path.join(os.getcwd(), "Video.mp4")
         | 
| 215 | 
            +
                            shutil.copy(video, destination_path)
         | 
| 216 | 
            +
                            result_convert_video = {}
         | 
| 217 | 
            +
                            result_convert_video = subprocess.run(
         | 
| 218 | 
            +
                                "echo Video copied",
         | 
| 219 | 
            +
                                capture_output=True,
         | 
| 220 | 
            +
                                text=True,
         | 
| 221 | 
            +
                                shell=True,
         | 
| 222 | 
            +
                            )
         | 
| 223 | 
            +
                        else:
         | 
| 224 | 
            +
                            logger.warning(
         | 
| 225 | 
            +
                                "File does not have the '.mp4' extension. Converting video."
         | 
| 226 | 
            +
                            )
         | 
| 227 | 
            +
                            command = f'ffmpeg -y -i "{video}" -c:v libx264 -c:a aac -strict experimental Video.mp4'
         | 
| 228 | 
            +
                            result_convert_video = subprocess.run(
         | 
| 229 | 
            +
                                command, capture_output=True, text=True, shell=True
         | 
| 230 | 
            +
                            )
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    if result_convert_video.returncode in [1, 2]:
         | 
| 233 | 
            +
                        raise OperationFailedError("Error can't convert the video")
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                    for i in range(120):
         | 
| 236 | 
            +
                        time.sleep(1)
         | 
| 237 | 
            +
                        logger.info("Process video...")
         | 
| 238 | 
            +
                        if os.path.exists(OutputFile):
         | 
| 239 | 
            +
                            time.sleep(1)
         | 
| 240 | 
            +
                            command = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
         | 
| 241 | 
            +
                            result_convert_audio = subprocess.run(
         | 
| 242 | 
            +
                                command, capture_output=True, text=True, shell=True
         | 
| 243 | 
            +
                            )
         | 
| 244 | 
            +
                            time.sleep(1)
         | 
| 245 | 
            +
                            break
         | 
| 246 | 
            +
                        if i == 119:
         | 
| 247 | 
            +
                            # if not os.path.exists(OutputFile):
         | 
| 248 | 
            +
                            raise OperationFailedError("Error processing video")
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                    if result_convert_audio.returncode in [1, 2]:
         | 
| 251 | 
            +
                        raise OperationFailedError(
         | 
| 252 | 
            +
                            f"Error can't create the audio file: {result_convert_audio.stderr}"
         | 
| 253 | 
            +
                        )
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                    for i in range(120):
         | 
| 256 | 
            +
                        time.sleep(1)
         | 
| 257 | 
            +
                        logger.info("Process audio...")
         | 
| 258 | 
            +
                        if os.path.exists(audio_wav):
         | 
| 259 | 
            +
                            break
         | 
| 260 | 
            +
                        if i == 119:
         | 
| 261 | 
            +
                            raise OperationFailedError("Error can't create the audio file")
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                else:
         | 
| 264 | 
            +
                    video = video.strip()
         | 
| 265 | 
            +
                    if preview:
         | 
| 266 | 
            +
                        logger.warning(
         | 
| 267 | 
            +
                            "Creating a preview from the link, 10 "
         | 
| 268 | 
            +
                            "seconds to disable this option, go to "
         | 
| 269 | 
            +
                            "advanced settings and turn off preview."
         | 
| 270 | 
            +
                        )
         | 
| 271 | 
            +
                        # https://github.com/yt-dlp/yt-dlp/issues/2220
         | 
| 272 | 
            +
                        mp4_ = f'yt-dlp -f "mp4" --downloader ffmpeg --downloader-args "ffmpeg_i: -ss 00:00:20 -t 00:00:10" --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
         | 
| 273 | 
            +
                        wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
         | 
| 274 | 
            +
                        result_convert_video = subprocess.run(
         | 
| 275 | 
            +
                            mp4_, capture_output=True, text=True, shell=True
         | 
| 276 | 
            +
                        )
         | 
| 277 | 
            +
                        result_convert_audio = subprocess.run(
         | 
| 278 | 
            +
                            wav_, capture_output=True, text=True, shell=True
         | 
| 279 | 
            +
                        )
         | 
| 280 | 
            +
                        if result_convert_audio.returncode in [1, 2]:
         | 
| 281 | 
            +
                            raise OperationFailedError("Error can't download a preview")
         | 
| 282 | 
            +
                    else:
         | 
| 283 | 
            +
                        mp4_ = f'yt-dlp -f "mp4" --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
         | 
| 284 | 
            +
                        wav_ = f"python -m yt_dlp --output {audio_wav} --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --extract-audio --audio-format wav {video}"
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                        result_convert_audio = subprocess.run(
         | 
| 287 | 
            +
                            wav_, capture_output=True, text=True, shell=True
         | 
| 288 | 
            +
                        )
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                        if result_convert_audio.returncode in [1, 2]:
         | 
| 291 | 
            +
                            raise OperationFailedError("Error can't download the audio")
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                        for i in range(120):
         | 
| 294 | 
            +
                            time.sleep(1)
         | 
| 295 | 
            +
                            logger.info("Process audio...")
         | 
| 296 | 
            +
                            if os.path.exists(audio_wav) and not os.path.exists(
         | 
| 297 | 
            +
                                "audio.webm"
         | 
| 298 | 
            +
                            ):
         | 
| 299 | 
            +
                                time.sleep(1)
         | 
| 300 | 
            +
                                result_convert_video = subprocess.run(
         | 
| 301 | 
            +
                                    mp4_, capture_output=True, text=True, shell=True
         | 
| 302 | 
            +
                                )
         | 
| 303 | 
            +
                                break
         | 
| 304 | 
            +
                            if i == 119:
         | 
| 305 | 
            +
                                raise OperationFailedError("Error downloading the audio")
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                        if result_convert_video.returncode in [1, 2]:
         | 
| 308 | 
            +
                            raise OperationFailedError("Error can't download the video")
         | 
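
A standalone sketch of the preview-clipping branch above, using the same ffmpeg flags; "input.mp4" is a hypothetical local file, not a path from this repository:

    import subprocess

    # Clip 10 seconds starting at 00:00:20, as in the preview branch above.
    cmd = 'ffmpeg -y -i "input.mp4" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
    subprocess.run(cmd, capture_output=True, text=True, shell=True)
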
    	
        soni_translate/speech_segmentation.py
    ADDED
    
    | @@ -0,0 +1,499 @@ | |
| 1 | 
            +
            from whisperx.alignment import (
         | 
| 2 | 
            +
                DEFAULT_ALIGN_MODELS_TORCH as DAMT,
         | 
| 3 | 
            +
                DEFAULT_ALIGN_MODELS_HF as DAMHF,
         | 
| 4 | 
            +
            )
         | 
| 5 | 
            +
            from whisperx.utils import TO_LANGUAGE_CODE
         | 
| 6 | 
            +
            import whisperx
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import gc
         | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
            import soundfile as sf
         | 
| 11 | 
            +
            from IPython.utils import capture # noqa
         | 
| 12 | 
            +
            from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
         | 
| 13 | 
            +
            from .logging_setup import logger
         | 
| 14 | 
            +
            from .postprocessor import sanitize_file_name
         | 
| 15 | 
            +
            from .utils import remove_directory_contents, run_command
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # ZERO GPU CONFIG
         | 
| 18 | 
            +
            import spaces
         | 
| 19 | 
            +
            import copy
         | 
| 20 | 
            +
            import random
         | 
| 21 | 
            +
            import time
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            def random_sleep():
         | 
| 24 | 
            +
                if os.environ.get("ZERO_GPU") == "TRUE":
         | 
| 25 | 
            +
                    print("Random sleep")
         | 
| 26 | 
            +
                    sleep_time = round(random.uniform(7.2, 9.9), 1)
         | 
| 27 | 
            +
                    time.sleep(sleep_time)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            @spaces.GPU(duration=120)
         | 
| 31 | 
            +
            def load_and_transcribe_audio(asr_model, audio, compute_type, language, asr_options, batch_size, segment_duration_limit):
         | 
| 32 | 
            +
                # Load model
         | 
| 33 | 
            +
                model = whisperx.load_model(
         | 
| 34 | 
            +
                    asr_model,
         | 
| 35 | 
            +
                    os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
         | 
| 36 | 
            +
                    compute_type=compute_type,
         | 
| 37 | 
            +
                    language=language,
         | 
| 38 | 
            +
                    asr_options=asr_options,
         | 
| 39 | 
            +
                )
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                # Transcribe audio
         | 
| 42 | 
            +
                result = model.transcribe(
         | 
| 43 | 
            +
                    audio,
         | 
| 44 | 
            +
                    batch_size=batch_size,
         | 
| 45 | 
            +
                    chunk_size=segment_duration_limit,
         | 
| 46 | 
            +
                    print_progress=True,
         | 
| 47 | 
            +
                )
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                del model
         | 
| 50 | 
            +
                gc.collect()
         | 
| 51 | 
            +
                torch.cuda.empty_cache()  # noqa
         | 
| 52 | 
            +
                
         | 
| 53 | 
            +
                return result
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            def load_align_and_align_segments(result, audio, DAMHF):
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                # Load alignment model
         | 
| 58 | 
            +
                model_a, metadata = whisperx.load_align_model(
         | 
| 59 | 
            +
                    language_code=result["language"],
         | 
| 60 | 
            +
                    device=os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
         | 
| 61 | 
            +
                    model_name=None
         | 
| 62 | 
            +
                    if result["language"] in DAMHF.keys()
         | 
| 63 | 
            +
                    else EXTRA_ALIGN[result["language"]],
         | 
| 64 | 
            +
                )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                # Align segments
         | 
| 67 | 
            +
                alignment_result = whisperx.align(
         | 
| 68 | 
            +
                    result["segments"],
         | 
| 69 | 
            +
                    model_a,
         | 
| 70 | 
            +
                    metadata,
         | 
| 71 | 
            +
                    audio,
         | 
| 72 | 
            +
                    os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
         | 
| 73 | 
            +
                    return_char_alignments=True,
         | 
| 74 | 
            +
                    print_progress=False,
         | 
| 75 | 
            +
                )
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                # Clean up
         | 
| 78 | 
            +
                del model_a
         | 
| 79 | 
            +
                gc.collect()
         | 
| 80 | 
            +
                torch.cuda.empty_cache()  # noqa
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                return alignment_result
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            @spaces.GPU(duration=120)
         | 
| 85 | 
            +
            def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                if os.environ.get("ZERO_GPU") == "TRUE":
         | 
| 88 | 
            +
                    diarize_model.model.to(torch.device("cuda"))
         | 
| 89 | 
            +
                diarize_segments = diarize_model(
         | 
| 90 | 
            +
                    audio_wav, 
         | 
| 91 | 
            +
                    min_speakers=min_speakers, 
         | 
| 92 | 
            +
                    max_speakers=max_speakers
         | 
| 93 | 
            +
                )
         | 
| 94 | 
            +
                return diarize_segments
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            # ZERO GPU CONFIG
         | 
| 97 | 
            +
             | 
| 98 | 
            +
            ASR_MODEL_OPTIONS = [
         | 
| 99 | 
            +
                "tiny",
         | 
| 100 | 
            +
                "base",
         | 
| 101 | 
            +
                "small",
         | 
| 102 | 
            +
                "medium",
         | 
| 103 | 
            +
                "large",
         | 
| 104 | 
            +
                "large-v1",
         | 
| 105 | 
            +
                "large-v2",
         | 
| 106 | 
            +
                "large-v3",
         | 
| 107 | 
            +
                "distil-large-v2",
         | 
| 108 | 
            +
                "Systran/faster-distil-whisper-large-v3",
         | 
| 109 | 
            +
                "tiny.en",
         | 
| 110 | 
            +
                "base.en",
         | 
| 111 | 
            +
                "small.en",
         | 
| 112 | 
            +
                "medium.en",
         | 
| 113 | 
            +
                "distil-small.en",
         | 
| 114 | 
            +
                "distil-medium.en",
         | 
| 115 | 
            +
                "OpenAI_API_Whisper",
         | 
| 116 | 
            +
            ]
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            COMPUTE_TYPE_GPU = [
         | 
| 119 | 
            +
                "default",
         | 
| 120 | 
            +
                "auto",
         | 
| 121 | 
            +
                "int8",
         | 
| 122 | 
            +
                "int8_float32",
         | 
| 123 | 
            +
                "int8_float16",
         | 
| 124 | 
            +
                "int8_bfloat16",
         | 
| 125 | 
            +
                "float16",
         | 
| 126 | 
            +
                "bfloat16",
         | 
| 127 | 
            +
                "float32"
         | 
| 128 | 
            +
            ]
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            COMPUTE_TYPE_CPU = [
         | 
| 131 | 
            +
                "default",
         | 
| 132 | 
            +
                "auto",
         | 
| 133 | 
            +
                "int8",
         | 
| 134 | 
            +
                "int8_float32",
         | 
| 135 | 
            +
                "int16",
         | 
| 136 | 
            +
                "float32",
         | 
| 137 | 
            +
            ]
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            WHISPER_MODELS_PATH = './WHISPER_MODELS'
         | 
| 140 | 
            +
             | 
| 141 | 
            +
             | 
| 142 | 
            +
            def openai_api_whisper(
         | 
| 143 | 
            +
                input_audio_file,
         | 
| 144 | 
            +
                source_lang=None,
         | 
| 145 | 
            +
                chunk_duration=1800
         | 
| 146 | 
            +
            ):
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                info = sf.info(input_audio_file)
         | 
| 149 | 
            +
                duration = info.duration
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                output_directory = "./whisper_api_audio_parts"
         | 
| 152 | 
            +
                os.makedirs(output_directory, exist_ok=True)
         | 
| 153 | 
            +
                remove_directory_contents(output_directory)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                if duration > chunk_duration:
         | 
| 156 | 
            +
        # Split the audio into chunks of chunk_duration seconds (default: 30 minutes)
         | 
| 157 | 
            +
                    cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
         | 
| 158 | 
            +
                    run_command(cm)
         | 
| 159 | 
            +
                    # Get list of generated chunk files
         | 
| 160 | 
            +
                    chunk_files = sorted(
         | 
| 161 | 
            +
                        [f"{output_directory}/{f}" for f in os.listdir(output_directory) if f.endswith('.ogg')]
         | 
| 162 | 
            +
                    )
         | 
| 163 | 
            +
                else:
         | 
| 164 | 
            +
                    one_file = f"{output_directory}/output000.ogg"
         | 
| 165 | 
            +
                    cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
         | 
| 166 | 
            +
                    run_command(cm)
         | 
| 167 | 
            +
                    chunk_files = [one_file]
         | 
| 168 | 
            +
             | 
| 169 | 
            +
    # Transcribe each chunk
         | 
| 170 | 
            +
                segments = []
         | 
| 171 | 
            +
                language = source_lang if source_lang else None
         | 
| 172 | 
            +
                for i, chunk in enumerate(chunk_files):
         | 
| 173 | 
            +
                    from openai import OpenAI
         | 
| 174 | 
            +
                    client = OpenAI()
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    audio_file = open(chunk, "rb")
         | 
| 177 | 
            +
                    transcription = client.audio.transcriptions.create(
         | 
| 178 | 
            +
                      model="whisper-1",
         | 
| 179 | 
            +
                      file=audio_file,
         | 
| 180 | 
            +
                      language=language,
         | 
| 181 | 
            +
                      response_format="verbose_json",
         | 
| 182 | 
            +
                      timestamp_granularities=["segment"],
         | 
| 183 | 
            +
                    )
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    try:
         | 
| 186 | 
            +
                        transcript_dict = transcription.model_dump()
         | 
| 187 | 
            +
                    except: # noqa
         | 
| 188 | 
            +
                        transcript_dict = transcription.to_dict()
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                    if language is None:
         | 
| 191 | 
            +
                        logger.info(f'Language detected: {transcript_dict["language"]}')
         | 
| 192 | 
            +
                        language = TO_LANGUAGE_CODE[transcript_dict["language"]]
         | 
| 193 | 
            +
             | 
| 194 | 
            +
        chunk_time = chunk_duration * i
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    for seg in transcript_dict["segments"]:
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                        if "start" in seg.keys():
         | 
| 199 | 
            +
                            segments.append(
         | 
| 200 | 
            +
                                {
         | 
| 201 | 
            +
                                    "text": seg["text"],
         | 
| 202 | 
            +
                                    "start": seg["start"] + chunk_time,
         | 
| 203 | 
            +
                                    "end": seg["end"] + chunk_time,
         | 
| 204 | 
            +
                                }
         | 
| 205 | 
            +
                            )
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                audio = whisperx.load_audio(input_audio_file)
         | 
| 208 | 
            +
                result = {"segments": segments, "language": language}
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                return audio, result
         | 
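
A quick check of the timestamp offset applied above: segments from chunk i are shifted by chunk_duration * i, so with the default 1800-second chunks a segment starting at 12.0 s inside the third chunk (i == 2) lands at 3612.0 s overall; the values here are illustrative:

    chunk_duration, i = 1800, 2
    assert 12.0 + chunk_duration * i == 3612.0
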
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
            def find_whisper_models():
         | 
| 214 | 
            +
                path = WHISPER_MODELS_PATH
         | 
| 215 | 
            +
                folders = []
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                if os.path.exists(path):
         | 
| 218 | 
            +
                    for folder in os.listdir(path):
         | 
| 219 | 
            +
                        folder_path = os.path.join(path, folder)
         | 
| 220 | 
            +
                        if (
         | 
| 221 | 
            +
                            os.path.isdir(folder_path)
         | 
| 222 | 
            +
                            and 'model.bin' in os.listdir(folder_path)
         | 
| 223 | 
            +
                        ):
         | 
| 224 | 
            +
                            folders.append(folder)
         | 
| 225 | 
            +
                return folders
         | 
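
For reference, the function above lists only converted model folders that contain a model.bin; with a hypothetical layout such as ./WHISPER_MODELS/my-finetuned-whisper/model.bin it would return:

    print(find_whisper_models())  # -> ['my-finetuned-whisper'] (hypothetical folder name)
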
| 226 | 
            +
             | 
| 227 | 
            +
            def transcribe_speech(
         | 
| 228 | 
            +
                audio_wav,
         | 
| 229 | 
            +
                asr_model,
         | 
| 230 | 
            +
                compute_type,
         | 
| 231 | 
            +
                batch_size,
         | 
| 232 | 
            +
                SOURCE_LANGUAGE,
         | 
| 233 | 
            +
                literalize_numbers=True,
         | 
| 234 | 
            +
                segment_duration_limit=15,
         | 
| 235 | 
            +
            ):
         | 
| 236 | 
            +
                """
         | 
| 237 | 
            +
                Transcribe speech using a whisper model.
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                Parameters:
         | 
| 240 | 
            +
                - audio_wav (str): Path to the audio file in WAV format.
         | 
| 241 | 
            +
                - asr_model (str): The whisper model to be loaded.
         | 
| 242 | 
            +
                - compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
         | 
| 243 | 
            +
                - batch_size (int): Batch size for transcription.
         | 
| 244 | 
            +
                - SOURCE_LANGUAGE (str): Source language for transcription.
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                Returns:
         | 
| 247 | 
            +
                - Tuple containing:
         | 
| 248 | 
            +
                    - audio: Loaded audio file.
         | 
| 249 | 
            +
                    - result: Transcription result as a dictionary.
         | 
| 250 | 
            +
                """
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                if asr_model == "OpenAI_API_Whisper":
         | 
| 253 | 
            +
                    if literalize_numbers:
         | 
| 254 | 
            +
                        logger.info(
         | 
| 255 | 
            +
                            "OpenAI's API Whisper does not support "
         | 
| 256 | 
            +
                            "the literalization of numbers."
         | 
| 257 | 
            +
                        )
         | 
| 258 | 
            +
                    return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                # https://github.com/openai/whisper/discussions/277
         | 
| 261 | 
            +
                prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
         | 
| 262 | 
            +
                SOURCE_LANGUAGE = (
         | 
| 263 | 
            +
                    SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
         | 
| 264 | 
            +
                )
         | 
| 265 | 
            +
                asr_options = {
         | 
| 266 | 
            +
                    "initial_prompt": prompt,
         | 
| 267 | 
            +
                    "suppress_numerals": literalize_numbers
         | 
| 268 | 
            +
                }
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                if asr_model not in ASR_MODEL_OPTIONS:
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    base_dir = WHISPER_MODELS_PATH
         | 
| 273 | 
            +
                    if not os.path.exists(base_dir):
         | 
| 274 | 
            +
                        os.makedirs(base_dir)
         | 
| 275 | 
            +
                    model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    if not os.path.exists(model_dir):
         | 
| 278 | 
            +
                        from ctranslate2.converters import TransformersConverter
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                        quantization = "float32"
         | 
| 281 | 
            +
                        # Download new model
         | 
| 282 | 
            +
                        try:
         | 
| 283 | 
            +
                            converter = TransformersConverter(
         | 
| 284 | 
            +
                                asr_model,
         | 
| 285 | 
            +
                                low_cpu_mem_usage=True,
         | 
| 286 | 
            +
                                copy_files=[
         | 
| 287 | 
            +
                                    "tokenizer_config.json", "preprocessor_config.json"
         | 
| 288 | 
            +
                                ]
         | 
| 289 | 
            +
                            )
         | 
| 290 | 
            +
                            converter.convert(
         | 
| 291 | 
            +
                                model_dir,
         | 
| 292 | 
            +
                                quantization=quantization,
         | 
| 293 | 
            +
                                force=False
         | 
| 294 | 
            +
                            )
         | 
| 295 | 
            +
                        except Exception as error:
         | 
| 296 | 
            +
                            if "File tokenizer_config.json does not exist" in str(error):
         | 
| 297 | 
            +
                                converter._copy_files = [
         | 
| 298 | 
            +
                                    "tokenizer.json", "preprocessor_config.json"
         | 
| 299 | 
            +
                                ]
         | 
| 300 | 
            +
                                converter.convert(
         | 
| 301 | 
            +
                                    model_dir,
         | 
| 302 | 
            +
                                    quantization=quantization,
         | 
| 303 | 
            +
                                    force=True
         | 
| 304 | 
            +
                                )
         | 
| 305 | 
            +
                            else:
         | 
| 306 | 
            +
                                raise error
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    asr_model = model_dir
         | 
| 309 | 
            +
                    logger.info(f"ASR Model: {str(model_dir)}")
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                audio = whisperx.load_audio(audio_wav)
         | 
| 312 | 
            +
                
         | 
| 313 | 
            +
                result = load_and_transcribe_audio(
         | 
| 314 | 
            +
                    asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options, batch_size, segment_duration_limit
         | 
| 315 | 
            +
                )
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                if result["language"] == "zh" and not prompt:
         | 
| 318 | 
            +
                    result["language"] = "zh-TW"
         | 
| 319 | 
            +
                    logger.info("Chinese - Traditional (zh-TW)")
         | 
| 320 | 
            +
             | 
| 321 | 
            +
             | 
| 322 | 
            +
                return audio, result
         | 
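
A minimal usage sketch for the function above; the file name, model, and batch size are illustrative assumptions, not values taken from this repository:

    # Hypothetical inputs: a local WAV file and a CUDA-capable device.
    audio, result = transcribe_speech(
        "audio.wav",
        asr_model="large-v3",
        compute_type="float16",
        batch_size=8,
        SOURCE_LANGUAGE="en",
    )
    print(result["segments"][0]["text"])
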
| 323 | 
            +
             | 
| 324 | 
            +
             | 
| 325 | 
            +
            def align_speech(audio, result):
         | 
| 326 | 
            +
                """
         | 
| 327 | 
            +
                Aligns speech segments based on the provided audio and result metadata.
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                Parameters:
         | 
| 330 | 
            +
                - audio (array): The audio data in a suitable format for alignment.
         | 
| 331 | 
            +
                - result (dict): Metadata containing information about the segments
         | 
| 332 | 
            +
                     and language.
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                Returns:
         | 
| 335 | 
            +
                - result (dict): Updated metadata after aligning the segments with
         | 
| 336 | 
            +
                    the audio. This includes character-level alignments if
         | 
| 337 | 
            +
                    'return_char_alignments' is set to True.
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                Notes:
         | 
| 340 | 
            +
                - This function uses language-specific models to align speech segments.
         | 
| 341 | 
            +
                - It performs language compatibility checks and selects the
         | 
| 342 | 
            +
                    appropriate alignment model.
         | 
| 343 | 
            +
                - Cleans up memory by releasing resources after alignment.
         | 
| 344 | 
            +
                """
         | 
| 345 | 
            +
                DAMHF.update(DAMT)  # lang align
         | 
| 346 | 
            +
                if (
         | 
| 347 | 
            +
                    not result["language"] in DAMHF.keys()
         | 
| 348 | 
            +
                    and not result["language"] in EXTRA_ALIGN.keys()
         | 
| 349 | 
            +
                ):
         | 
| 350 | 
            +
                    logger.warning(
         | 
| 351 | 
            +
                        "Automatic detection: Source language not compatible with align"
         | 
| 352 | 
            +
                    )
         | 
| 353 | 
            +
                    raise ValueError(
         | 
| 354 | 
            +
                        f"Detected language {result['language']}  incompatible, "
         | 
| 355 | 
            +
                        "you can select the source language to avoid this error."
         | 
| 356 | 
            +
                    )
         | 
| 357 | 
            +
                if (
         | 
| 358 | 
            +
                    result["language"] in EXTRA_ALIGN.keys()
         | 
| 359 | 
            +
                    and EXTRA_ALIGN[result["language"]] == ""
         | 
| 360 | 
            +
                ):
         | 
| 361 | 
            +
                    lang_name = (
         | 
| 362 | 
            +
                        INVERTED_LANGUAGES[result["language"]]
         | 
| 363 | 
            +
                        if result["language"] in INVERTED_LANGUAGES.keys()
         | 
| 364 | 
            +
                        else result["language"]
         | 
| 365 | 
            +
                    )
         | 
| 366 | 
            +
                    logger.warning(
         | 
| 367 | 
            +
                        "No compatible wav2vec2 model found "
         | 
| 368 | 
            +
                        f"for the language '{lang_name}', skipping alignment."
         | 
| 369 | 
            +
                    )
         | 
| 370 | 
            +
                    return result
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                random_sleep()
         | 
| 373 | 
            +
                result = load_align_and_align_segments(result, audio, DAMHF)
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                return result
         | 
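
Continuing the hypothetical sketch above, alignment consumes the transcription output and returns segments enriched with word- and character-level timestamps:

    aligned = align_speech(audio, result)  # audio and result from transcribe_speech above
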
| 376 | 
            +
             | 
| 377 | 
            +
             | 
| 378 | 
            +
            diarization_models = {
         | 
| 379 | 
            +
                "pyannote_3.1": "pyannote/speaker-diarization-3.1",
         | 
| 380 | 
            +
                "pyannote_2.1": "pyannote/[email protected]",
         | 
| 381 | 
            +
                "disable": "",
         | 
| 382 | 
            +
            }
         | 
| 383 | 
            +
             | 
| 384 | 
            +
             | 
| 385 | 
            +
            def reencode_speakers(result):
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                if result["segments"][0]["speaker"] == "SPEAKER_00":
         | 
| 388 | 
            +
                    return result
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                speaker_mapping = {}
         | 
| 391 | 
            +
                counter = 0
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                logger.debug("Reencode speakers")
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                for segment in result["segments"]:
         | 
| 396 | 
            +
                    old_speaker = segment["speaker"]
         | 
| 397 | 
            +
                    if old_speaker not in speaker_mapping:
         | 
| 398 | 
            +
                        speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
         | 
| 399 | 
            +
                        counter += 1
         | 
| 400 | 
            +
                    segment["speaker"] = speaker_mapping[old_speaker]
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                return result
         | 
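
A small illustration of the remapping above, assuming diarization produced non-contiguous speaker labels (hypothetical input):

    r = {"segments": [{"speaker": "SPEAKER_03"}, {"speaker": "SPEAKER_07"}, {"speaker": "SPEAKER_03"}]}
    print([s["speaker"] for s in reencode_speakers(r)["segments"]])
    # -> ['SPEAKER_00', 'SPEAKER_01', 'SPEAKER_00']
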
| 403 | 
            +
             | 
| 404 | 
            +
             | 
| 405 | 
            +
            def diarize_speech(
         | 
| 406 | 
            +
                audio_wav,
         | 
| 407 | 
            +
                result,
         | 
| 408 | 
            +
                min_speakers,
         | 
| 409 | 
            +
                max_speakers,
         | 
| 410 | 
            +
                YOUR_HF_TOKEN,
         | 
| 411 | 
            +
                model_name="pyannote/[email protected]",
         | 
| 412 | 
            +
            ):
         | 
| 413 | 
            +
                """
         | 
| 414 | 
            +
                Performs speaker diarization on speech segments.
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                Parameters:
         | 
| 417 | 
            +
                - audio_wav (array): Audio data in WAV format to perform speaker
         | 
| 418 | 
            +
                    diarization.
         | 
| 419 | 
            +
                - result (dict): Metadata containing information about speech segments
         | 
| 420 | 
            +
                    and alignments.
         | 
| 421 | 
            +
                - min_speakers (int): Minimum number of speakers expected in the audio.
         | 
| 422 | 
            +
                - max_speakers (int): Maximum number of speakers expected in the audio.
         | 
| 423 | 
            +
                - YOUR_HF_TOKEN (str): Your Hugging Face API token for model
         | 
| 424 | 
            +
                    authentication.
         | 
| 425 | 
            +
                - model_name (str): Name of the speaker diarization model to be used
         | 
| 426 | 
            +
                    (default: "pyannote/[email protected]").
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                Returns:
         | 
| 429 | 
            +
                - result_diarize (dict): Updated metadata after assigning speaker
         | 
| 430 | 
            +
                    labels to segments.
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                Notes:
         | 
| 433 | 
            +
                - This function utilizes a speaker diarization model to label speaker
         | 
| 434 | 
            +
                    segments in the audio.
         | 
| 435 | 
            +
                - It assigns speakers to word-level segments based on diarization results.
         | 
| 436 | 
            +
                - Cleans up memory by releasing resources after diarization.
         | 
| 437 | 
            +
                - If only one speaker is specified, each segment is automatically assigned
         | 
| 438 | 
            +
                    as the first speaker, eliminating the need for diarization inference.
         | 
| 439 | 
            +
                """
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                if max(min_speakers, max_speakers) > 1 and model_name:
         | 
| 442 | 
            +
                    try:
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                        diarize_model = whisperx.DiarizationPipeline(
         | 
| 445 | 
            +
                            model_name=model_name,
         | 
| 446 | 
            +
                            use_auth_token=YOUR_HF_TOKEN,
         | 
| 447 | 
            +
                            device=os.environ.get("SONITR_DEVICE"),
         | 
| 448 | 
            +
                        )
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                    except Exception as error:
         | 
| 451 | 
            +
                        error_str = str(error)
         | 
| 452 | 
            +
                        gc.collect()
         | 
| 453 | 
            +
                        torch.cuda.empty_cache()  # noqa
         | 
| 454 | 
            +
                        if "'NoneType' object has no attribute 'to'" in error_str:
         | 
| 455 | 
            +
                            if model_name == diarization_models["pyannote_2.1"]:
         | 
| 456 | 
            +
                                raise ValueError(
         | 
| 457 | 
            +
                                    "Accept the license agreement for using Pyannote 2.1."
         | 
| 458 | 
            +
                                    " You need to have an account on Hugging Face and "
         | 
| 459 | 
            +
                                    "accept the license to use the models: "
         | 
| 460 | 
            +
                                    "https://huggingface.co/pyannote/speaker-diarization "
         | 
| 461 | 
            +
                                    "and https://huggingface.co/pyannote/segmentation "
         | 
| 462 | 
            +
                                    "Get your KEY TOKEN here: "
         | 
| 463 | 
            +
                                    "https://hf.co/settings/tokens "
         | 
| 464 | 
            +
                                )
         | 
| 465 | 
            +
                            elif model_name == diarization_models["pyannote_3.1"]:
         | 
| 466 | 
            +
                                raise ValueError(
         | 
| 467 | 
            +
                                    "New Licence Pyannote 3.1: You need to have an account"
         | 
| 468 | 
            +
                                    " on Hugging Face and accept the license to use the "
         | 
| 469 | 
            +
                                    "models: https://huggingface.co/pyannote/speaker-diarization-3.1 " # noqa
         | 
| 470 | 
            +
                                    "and https://huggingface.co/pyannote/segmentation-3.0 "
         | 
| 471 | 
            +
                                )
         | 
| 472 | 
            +
                        else:
         | 
| 473 | 
            +
                            raise error
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                    random_sleep()
         | 
| 476 | 
            +
                    diarize_segments = diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers)
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                    result_diarize = whisperx.assign_word_speakers(
         | 
| 479 | 
            +
                        diarize_segments, result
         | 
| 480 | 
            +
                    )
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    for segment in result_diarize["segments"]:
         | 
| 483 | 
            +
                        if "speaker" not in segment:
         | 
| 484 | 
            +
                            segment["speaker"] = "SPEAKER_00"
         | 
| 485 | 
            +
                            logger.warning(
         | 
| 486 | 
            +
                                f"No speaker detected in {segment['start']}. First TTS "
         | 
| 487 | 
            +
                                f"will be used for the segment text: {segment['text']} "
         | 
| 488 | 
            +
                            )
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    del diarize_model
         | 
| 491 | 
            +
                    gc.collect()
         | 
| 492 | 
            +
                    torch.cuda.empty_cache()  # noqa
         | 
| 493 | 
            +
                else:
         | 
| 494 | 
            +
                    result_diarize = result
         | 
| 495 | 
            +
                    result_diarize["segments"] = [
         | 
| 496 | 
            +
                        {**item, "speaker": "SPEAKER_00"}
         | 
| 497 | 
            +
                        for item in result_diarize["segments"]
         | 
| 498 | 
            +
                    ]
         | 
| 499 | 
            +
                return reencode_speakers(result_diarize)
         | 
    	
        soni_translate/text_multiformat_processor.py
    ADDED
    
    | @@ -0,0 +1,987 @@ | |
soni_translate/text_multiformat_processor.py
ADDED
@@ -0,0 +1,987 @@

from .logging_setup import logger
from whisperx.utils import get_writer
from .utils import remove_files, run_command, remove_directory_contents
from typing import List
import srt
import re
import os
import copy
import string
import soundfile as sf
from PIL import Image, ImageOps, ImageDraw, ImageFont

punctuation_list = list(
    string.punctuation + "¡¿«»„”“”‚‘’「」『』《》()【】〈〉〔〕〖〗〘〙〚〛⸤⸥⸨⸩"
)
symbol_list = punctuation_list + ["", "..", "..."]


def extract_from_srt(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        srt_content = file.read()

    subtitle_generator = srt.parse(srt_content)
    srt_content_list = list(subtitle_generator)

    return srt_content_list


def clean_text(text):

    # Remove content within square brackets
    text = re.sub(r'\[.*?\]', '', text)
    # Remove content within <comment> tags
    text = re.sub(r'<comment>.*?</comment>', '', text)
    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)
    # Remove content enclosed in "♫" / "♪" pairs (song lyrics)
    text = re.sub(r'♫.*?♫', '', text)
    text = re.sub(r'♪.*?♪', '', text)
    # Replace newline characters with a sentence break
    text = text.replace("\n", ". ")
    # Remove double quotation marks
    text = text.replace('"', '')
    # Collapse multiple spaces into a single space
    text = re.sub(r"\s+", " ", text)
    # Normalize spaces around periods
    text = re.sub(r"[\s\.]+(?=\s)", ". ", text)
    # Discard the segment if unpaired ♫ or ♪ symbols remain
    if '♫' in text or '♪' in text:
        return ""

    text = text.strip()

    # Valid text
    return text if text not in symbol_list else ""
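
# A minimal sketch of clean_text's behavior on typical noisy subtitle
# content (inputs below are illustrative):
#   clean_text("[applause] Hello <i>world</i>")  -> "Hello world"
#   clean_text("♪ la la la ♪")                   -> ""  (lyrics removed)
#   clean_text("...")                            -> ""  (punctuation-only)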


def srt_file_to_segments(file_path, speaker=False):
    try:
        srt_content_list = extract_from_srt(file_path)
    except Exception as error:
        logger.error(str(error))
        fixed_file = "fixed_sub.srt"
        remove_files(fixed_file)
        fix_sub = f'ffmpeg -i "{file_path}" "{fixed_file}" -y'
        run_command(fix_sub)
        srt_content_list = extract_from_srt(fixed_file)

    segments = []
    for segment in srt_content_list:

        text = clean_text(str(segment.content))

        if text:
            segments.append(
                {
                    "text": text,
                    "start": float(segment.start.total_seconds()),
                    "end": float(segment.end.total_seconds()),
                }
            )

    if not segments:
        raise Exception("No data found in srt subtitle file")

    if speaker:
        segments = [{**seg, "speaker": "SPEAKER_00"} for seg in segments]

    return {"segments": segments}
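
# Hedged usage sketch; "subtitles.srt" is a hypothetical path. Malformed
# files are first rewritten through ffmpeg before a second parse attempt:
#   result = srt_file_to_segments("subtitles.srt", speaker=True)
#   result -> {"segments": [{"text": ..., "start": 0.5, "end": 2.1,
#                            "speaker": "SPEAKER_00"}, ...]}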


# documents


def dehyphenate(lines: List[str], line_no: int) -> List[str]:
    # The first word of the following line completes the hyphenated word
    next_line = lines[line_no + 1]
    word_suffix = next_line.split(" ")[0]

    # Drop the trailing hyphen, join the suffix, and remove it from the
    # next line
    lines[line_no] = lines[line_no][:-1] + word_suffix
    lines[line_no + 1] = lines[line_no + 1][len(word_suffix):]
    return lines


def remove_hyphens(text: str) -> str:
    """
    Rejoin words split across lines by end-of-line hyphens.

    This fails for:
    * Natural dashes: well-known, self-replication, use-cases, non-semantic,
                      Post-processing, Window-wise, viewpoint-dependent
    * Trailing math operands: 2 - 4
    * Names: Lopez-Ferreras, VGG-19, CIFAR-100
    """
    lines = [line.rstrip() for line in text.split("\n")]

    # Find dashes
    line_numbers = []
    for line_no, line in enumerate(lines[:-1]):
        if line.endswith("-"):
            line_numbers.append(line_no)

    # Replace
    for line_no in line_numbers:
        lines = dehyphenate(lines, line_no)

    return "\n".join(lines)
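
# A small worked example (illustrative string). Note the leading space
# left behind on the continuation line:
#   remove_hyphens("This is an exam-\nple of PDF text ex-\ntraction.")
#   -> "This is an example\n of PDF text extraction.\n"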


def pdf_to_txt(pdf_file, start_page, end_page):
    from pypdf import PdfReader

    with open(pdf_file, "rb") as file:
        reader = PdfReader(file)
        logger.debug(f"Total pages: {reader.get_num_pages()}")
        text = ""

        # Clamp the 1-based page range to the document bounds
        start_page_idx = max((start_page - 1), 0)
        end_page_idx = min(end_page, reader.get_num_pages())
        document_pages = reader.pages[start_page_idx:end_page_idx]
        logger.info(
            f"Selected pages from {start_page_idx} to {end_page_idx}: "
            f"{len(document_pages)}"
        )

        for page in document_pages:
            text += remove_hyphens(page.extract_text())
    return text


def docx_to_txt(docx_file):
    # https://github.com/AlJohri/docx2pdf update
    from docx import Document

    doc = Document(docx_file)
    text = ""
    for paragraph in doc.paragraphs:
        text += paragraph.text + "\n"
    return text


def replace_multiple_elements(text, replacements):
    pattern = re.compile("|".join(map(re.escape, replacements.keys())))
    replaced_text = pattern.sub(
        lambda match: replacements[match.group(0)], text
    )

    # Remove multiple spaces
    replaced_text = re.sub(r"\s+", " ", replaced_text)

    return replaced_text
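
# Sketch of the single-pass replacement; the mapping mirrors the one used
# in document_preprocessor below (note the trailing space in the output):
#   replace_multiple_elements("今日は晴れ。明日は雨。", {"。": "。 "})
#   -> "今日は晴れ。 明日は雨。 "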


def document_preprocessor(file_path, is_string, start_page, end_page):
    if not is_string:
        file_ext = os.path.splitext(file_path)[1].lower()

    if is_string:
        text = file_path
    elif file_ext == ".pdf":
        text = pdf_to_txt(file_path, start_page, end_page)
    elif file_ext == ".docx":
        text = docx_to_txt(file_path)
    elif file_ext == ".txt":
        with open(
            file_path, "r", encoding='utf-8', errors='replace'
        ) as file:
            text = file.read()
    else:
        raise Exception("Unsupported file format")

    # Add space to break segments more easily later
    replacements = {
        "、": "、 ",
        "。": "。 ",
        # "\n": " ",
    }
    text = replace_multiple_elements(text, replacements)

    # Save text to a .txt file
    # file_name = os.path.splitext(os.path.basename(file_path))[0]
    txt_file_path = "./text_preprocessor.txt"

    with open(
        txt_file_path, "w", encoding='utf-8', errors='replace'
    ) as txt_file:
        txt_file.write(text)

    return txt_file_path, text
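
# Hedged usage sketch; "book.pdf" is a hypothetical input. With
# is_string=True, file_path is treated as the raw text itself:
#   txt_path, text = document_preprocessor("book.pdf", False, 1, 5)
#   txt_path -> "./text_preprocessor.txt"  (pages 1-5, dehyphenated)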


def split_text_into_chunks(text, chunk_size):
    words = re.findall(r"\b\w+\b", text)
    chunks = []
    current_chunk = ""
    for word in words:
        if (
            len(current_chunk) + len(word) + 1 <= chunk_size
        ):  # Adding 1 for the space between words
            if current_chunk:
                current_chunk += " "
            current_chunk += word
        else:
            chunks.append(current_chunk)
            current_chunk = word
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
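
# The \b\w+\b tokenization keeps word characters only, so punctuation is
# dropped from the chunks. Worked example with chunk_size=9:
#   split_text_into_chunks("one two three four five", 9)
#   -> ["one two", "three", "four five"]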


def determine_chunk_size(file_name):
    patterns = {
        re.compile(r".*-(Male|Female)$"): 1024,  # by character
        re.compile(r".* BARK$"): 100,  # t 64 256
        re.compile(r".* VITS$"): 500,
        re.compile(
            r".+\.(wav|mp3|ogg|m4a)$"
        ): 150,  # t 250 400 api automatic split
        re.compile(r".* VITS-onnx$"): 250,  # automatic sentence split
        re.compile(r".* OpenAI-TTS$"): 1024  # max characters 4096
    }

    for pattern, chunk_size in patterns.items():
        if pattern.match(file_name):
            return chunk_size

    # Default chunk size if the file doesn't match any pattern; max 1800
    return 100
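
# Illustrative lookups (voice names below are hypothetical but follow the
# patterns above):
#   determine_chunk_size("en-US-JennyNeural-Female")  -> 1024
#   determine_chunk_size("voice BARK")                -> 100
#   determine_chunk_size("sample.wav")                -> 150
#   determine_chunk_size("unknown engine")            -> 100  (default)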


def plain_text_to_segments(result_text=None, chunk_size=None):
    if not chunk_size:
        chunk_size = 100
    text_chunks = split_text_into_chunks(result_text, chunk_size)

    segments_chunks = []
    for num, chunk in enumerate(text_chunks):
        chunk_dict = {
            "text": chunk,
            "start": (1.0 + num),
            "end": (2.0 + num),
            "speaker": "SPEAKER_00",
        }
        segments_chunks.append(chunk_dict)

    result_diarize = {"segments": segments_chunks}

    return result_diarize


def segments_to_plain_text(result_diarize):
    complete_text = ""
    for seg in result_diarize["segments"]:
        complete_text += seg["text"] + " "  # issue

    # Save text to a .txt file
    # file_name = os.path.splitext(os.path.basename(file_path))[0]
    txt_file_path = "./text_translation.txt"

    with open(
        txt_file_path, "w", encoding='utf-8', errors='replace'
    ) as txt_file:
        txt_file.write(complete_text)

    return txt_file_path, complete_text
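
# The two helpers above are near-inverses: plain text is chunked into
# pseudo-timed segments (one second apart, single speaker) for the
# translation/TTS pipeline, then flattened back afterwards:
#   segs = plain_text_to_segments(long_text, chunk_size=100)
#   path, flat = segments_to_plain_text(segs)  # ./text_translation.txt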


# doc to video

COLORS = {
    "black": (0, 0, 0),
    "white": (255, 255, 255),
    "red": (255, 0, 0),
    "green": (0, 255, 0),
    "blue": (0, 0, 255),
    "yellow": (255, 255, 0),
    "light_gray": (200, 200, 200),
    "light_blue": (173, 216, 230),
    "light_green": (144, 238, 144),
    "light_yellow": (255, 255, 224),
    "light_pink": (255, 182, 193),
    "lavender": (230, 230, 250),
    "peach": (255, 218, 185),
    "light_cyan": (224, 255, 255),
    "light_salmon": (255, 160, 122),
    "light_green_yellow": (173, 255, 47),
}

BORDER_COLORS = ["dynamic"] + list(COLORS.keys())


def calculate_average_color(img):
    # Resize the image to a small size for faster processing
    img_small = img.resize((50, 50))
    # Downscale to a single pixel; its value is the average color
    average_color = img_small.convert("RGB").resize((1, 1)).getpixel((0, 0))
    return average_color


def add_border_to_image(
    image_path,
    target_width,
    target_height,
    border_color=None
):

    img = Image.open(image_path)

    # Calculate the width and height for the new image with borders
    original_width, original_height = img.size
    original_aspect_ratio = original_width / original_height
    target_aspect_ratio = target_width / target_height

    # Resize the image to fit the target resolution retaining aspect ratio
    if original_aspect_ratio > target_aspect_ratio:
        # Image is wider, calculate new height
        new_height = int(target_width / original_aspect_ratio)
        resized_img = img.resize((target_width, new_height))
    else:
        # Image is taller, calculate new width
        new_width = int(target_height * original_aspect_ratio)
        resized_img = img.resize((new_width, target_height))

    # Calculate padding for borders
    padding = (0, 0, 0, 0)
    if resized_img.size[0] != target_width or resized_img.size[1] != target_height:
        if original_aspect_ratio > target_aspect_ratio:
            # Add borders vertically
            padding = (
                0,
                (target_height - resized_img.size[1]) // 2,
                0,
                (target_height - resized_img.size[1]) // 2,
            )
        else:
            # Add borders horizontally
            padding = (
                (target_width - resized_img.size[0]) // 2,
                0,
                (target_width - resized_img.size[0]) // 2,
                0,
            )

    # Add borders with specified color
    if not border_color or border_color == "dynamic":
        border_color = calculate_average_color(resized_img)
    else:
        border_color = COLORS.get(border_color, (0, 0, 0))

    bordered_img = ImageOps.expand(resized_img, padding, fill=border_color)

    bordered_img.save(image_path)

    return image_path
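
# Worked instance of the letterboxing arithmetic, assuming a 1280x720
# target and an 800x1000 source: 0.8 < 16/9, so the image is resized to
# (int(720 * 0.8), 720) = (576, 720), and (1280 - 576) // 2 = 352 px of
# border is added on each side (average-color fill when "dynamic").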


def resize_and_position_subimage(
    subimage,
    max_width,
    max_height,
    subimage_position,
    main_width,
    main_height
):
    subimage_width, subimage_height = subimage.size

    # Resize subimage if it exceeds maximum dimensions
    if subimage_width > max_width or subimage_height > max_height:
        # Calculate scaling factor
        width_scale = max_width / subimage_width
        height_scale = max_height / subimage_height
        scale = min(width_scale, height_scale)

        # Resize subimage
        subimage = subimage.resize(
            (int(subimage_width * scale), int(subimage_height * scale))
        )

    # Calculate position to place the subimage
    if subimage_position == "top-left":
        subimage_x = 0
        subimage_y = 0
    elif subimage_position == "top-right":
        subimage_x = main_width - subimage.width
        subimage_y = 0
    elif subimage_position == "bottom-left":
        subimage_x = 0
        subimage_y = main_height - subimage.height
    elif subimage_position == "bottom-right":
        subimage_x = main_width - subimage.width
        subimage_y = main_height - subimage.height
    else:
        raise ValueError(
            "Invalid subimage_position. Choose from 'top-left', 'top-right',"
            " 'bottom-left', or 'bottom-right'."
        )

    return subimage, subimage_x, subimage_y


def create_image_with_text_and_subimages(
    text,
    subimages,
    width,
    height,
    text_color,
    background_color,
    output_file
):
    # Create an image with the specified resolution and background color
    image = Image.new('RGB', (width, height), color=background_color)

    # Initialize ImageDraw object
    draw = ImageDraw.Draw(image)

    # Load a font
    font = ImageFont.load_default()  # You can specify your font file here

    # Calculate text size and position
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]
    text_x = (width - text_width) / 2
    text_y = (height - text_height) / 2

    # Draw text on the image
    draw.text((text_x, text_y), text, fill=text_color, font=font)

    # Paste subimages onto the main image
    for subimage_path, subimage_position in subimages:
        # Open the subimage
        subimage = Image.open(subimage_path)

        # Convert subimage to RGBA mode if it doesn't have an alpha channel
        if subimage.mode != 'RGBA':
            subimage = subimage.convert('RGBA')

        # Resize and position the subimage
        subimage, subimage_x, subimage_y = resize_and_position_subimage(
            subimage, width / 4, height / 4, subimage_position, width, height
        )

        # Paste the subimage onto the main image
        image.paste(subimage, (int(subimage_x), int(subimage_y)), subimage)

    image.save(output_file)

    return output_file


def doc_to_txtximg_pages(
    document,
    width,
    height,
    start_page,
    end_page,
    bcolor
):
    from pypdf import PdfReader

    images_folder = "pdf_images/"
    os.makedirs(images_folder, exist_ok=True)
    remove_directory_contents(images_folder)

    # First image: a title card built from the document's file name
    text_image = os.path.basename(document)[:-4]
    subimages = [("./assets/logo.jpeg", "top-left")]
    text_color = (255, 255, 255) if bcolor == "black" else (0, 0, 0)  # w|b
    background_color = COLORS.get(bcolor, (255, 255, 255))  # dynamic white
    first_image = "pdf_images/0000_00_aaa.png"

    create_image_with_text_and_subimages(
        text_image,
        subimages,
        width,
        height,
        text_color,
        background_color,
        first_image
    )

    reader = PdfReader(document)
    logger.debug(f"Total pages: {reader.get_num_pages()}")

    start_page_idx = max((start_page - 1), 0)
    end_page_idx = min(end_page, reader.get_num_pages())
    document_pages = reader.pages[start_page_idx:end_page_idx]

    logger.info(
        f"Selected pages from {start_page_idx} to {end_page_idx}: "
        f"{len(document_pages)}"
    )

    data_doc = {}
    for i, page in enumerate(document_pages):

        count = 0
        images = []
        for image_file_object in page.images:
            img_name = (
                f"{images_folder}{i:04d}_{count:02d}_{image_file_object.name}"
            )
            images.append(img_name)
            with open(img_name, "wb") as fp:
                fp.write(image_file_object.data)
                count += 1
            img_name = add_border_to_image(img_name, width, height, bcolor)

        data_doc[i] = {
            "text": remove_hyphens(page.extract_text()),
            "images": images
        }

    return data_doc


def page_data_to_segments(result_text=None, chunk_size=None):

    if not chunk_size:
        chunk_size = 100

    segments_chunks = []
    time_global = 0
    for page, result_data in result_text.items():
        # result_image = result_data["images"]
        page_text = result_data["text"]
        text_chunks = split_text_into_chunks(page_text, chunk_size)
        if not text_chunks:
            text_chunks = [" "]

        for chunk in text_chunks:
            chunk_dict = {
                "text": chunk,
                "start": (1.0 + time_global),
                "end": (2.0 + time_global),
                "speaker": "SPEAKER_00",
                "page": page,
            }
            segments_chunks.append(chunk_dict)
            time_global += 1

    result_diarize = {"segments": segments_chunks}

    return result_diarize


def update_page_data(result_diarize, doc_data):
    complete_text = ""
    current_page = result_diarize["segments"][0]["page"]
    text_page = ""

    for seg in result_diarize["segments"]:
        text = seg["text"] + " "  # issue
        complete_text += text

        page = seg["page"]

        if page == current_page:
            text_page += text
        else:
            doc_data[current_page]["text"] = text_page

            # Next
            text_page = text
            current_page = page

    if doc_data[current_page]["text"] != text_page:
        doc_data[current_page]["text"] = text_page

    return doc_data


def fix_timestamps_docs(result_diarize, audio_files):
    current_start = 0.0

    for seg, audio in zip(result_diarize["segments"], audio_files):
        duration = round(sf.info(audio).duration, 2)

        seg["start"] = current_start
        current_start += duration
        seg["end"] = current_start

    return result_diarize
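
# The pseudo-timestamps from page_data_to_segments are replaced by the
# cumulative durations of the synthesized clips; e.g. with clips of
# 3.50 s and 2.25 s (illustrative numbers):
#   segment 0: start 0.00, end 3.50
#   segment 1: start 3.50, end 5.75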


def create_video_from_images(
    doc_data,
    result_diarize
):

    # First image path
    first_image = "pdf_images/0000_00_aaa.png"

    # Time segments and images
    max_pages_idx = len(doc_data) - 1
    current_page = result_diarize["segments"][0]["page"]
    duration_page = 0.0
    last_image = None

    for seg in result_diarize["segments"]:
        start = seg["start"]
        end = seg["end"]
        duration_seg = end - start

        page = seg["page"]

        if page == current_page:
            duration_page += duration_seg
        else:

            images = doc_data[current_page]["images"]

            if first_image:
                images = [first_image] + images
                first_image = None
            # If the next page has no text, show its images with this page
            if not doc_data[min(max_pages_idx, (current_page + 1))]["text"].strip():
                images = images + doc_data[
                    min(max_pages_idx, (current_page + 1))
                ]["images"]
            if not images and last_image:
                images = [last_image]

            # Calculate images duration
            time_duration_per_image = round((duration_page / len(images)), 2)
            doc_data[current_page]["time_per_image"] = time_duration_per_image

            # Next values
            doc_data[current_page]["images"] = images
            last_image = images[-1]
            duration_page = duration_seg
            current_page = page

    if "time_per_image" not in doc_data[current_page].keys():
        images = doc_data[current_page]["images"]
        if first_image:
            images = [first_image] + images
        if not images:
            images = [last_image]
        time_duration_per_image = round((duration_page / len(images)), 2)
        doc_data[current_page]["time_per_image"] = time_duration_per_image

    # Timestamped image video.
    with open("list.txt", "w") as file:

        for i, page in enumerate(doc_data.values()):

            duration = page["time_per_image"]
            for img in page["images"]:
                file.write(f"file {img}\n")
                if i == len(doc_data) - 1 and img == page["images"][-1]:
                    # Last item: no trailing newline after the outpoint
                    file.write(f"outpoint {duration}")
                else:
                    file.write(f"outpoint {duration}\n")

    out_video = "video_from_images.mp4"
    remove_files(out_video)

    cm = (
        f"ffmpeg -y -f concat -i list.txt -c:v libx264 -preset veryfast "
        f"-crf 18 -pix_fmt yuv420p {out_video}"
    )
    cm_alt = (
        f"ffmpeg -f concat -i list.txt -c:v libx264 -r 30 "
        f"-pix_fmt yuv420p -y {out_video}"
    )
    try:
        run_command(cm)
    except Exception as error:
        logger.error(str(error))
        remove_files(out_video)
        run_command(cm_alt)

    return out_video
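
# The list.txt written above uses the ffmpeg concat demuxer's syntax, one
# file/outpoint pair per displayed image; a hypothetical two-image listing:
#   file pdf_images/0000_00_aaa.png
#   outpoint 4.5
#   file pdf_images/0001_00_img.png
#   outpoint 4.5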


def merge_video_and_audio(video_doc, final_wav_file):

    # Re-encode the synthesized wav as mp3 so it can be stream-copied
    # into the mp4 container below
    fixed_audio = "fixed_audio.mp3"
    remove_files(fixed_audio)
    cm = f"ffmpeg -i {final_wav_file} -c:a libmp3lame {fixed_audio}"
    run_command(cm)

    vid_out = "video_book.mp4"
    remove_files(vid_out)
    cm = (
        f"ffmpeg -i {video_doc} -i {fixed_audio} -c:v copy -c:a copy "
        f"-map 0:v -map 1:a -shortest {vid_out}"
    )
    run_command(cm)

    return vid_out


# subtitles


def get_subtitle(
    language,
    segments_data,
    extension,
    filename=None,
    highlight_words=False,
):
    if not filename:
        filename = "task_subtitle"

    is_ass_extension = False
    if extension == "ass":
        is_ass_extension = True
        extension = "srt"

    sub_file = filename + "." + extension
    support_name = filename + ".mp3"
    remove_files(sub_file)

    writer = get_writer(extension, output_dir=".")
    word_options = {
        "highlight_words": highlight_words,
        "max_line_count": None,
        "max_line_width": None,
    }

    # Get data subs
    subtitle_data = copy.deepcopy(segments_data)
    subtitle_data["language"] = (
        "ja" if language in ["ja", "zh", "zh-TW"] else language
    )

    # Clean
    if not highlight_words:
        subtitle_data.pop("word_segments", None)
        for segment in subtitle_data["segments"]:
            for key in ["speaker", "chars", "words"]:
                segment.pop(key, None)

    writer(
        subtitle_data,
        support_name,
        word_options,
    )

    if is_ass_extension:
        temp_name = filename + ".ass"
        remove_files(temp_name)
        convert_sub = f'ffmpeg -i "{sub_file}" "{temp_name}" -y'
        run_command(convert_sub)
        sub_file = temp_name

    return sub_file
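
# Hedged usage sketch; segments_data is a whisperx-style dict as produced
# earlier in this module. "ass" output is written as SRT first and then
# converted with ffmpeg:
#   get_subtitle("en", result_diarize, "ass", filename="my_video")
#   -> "my_video.ass"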
| 747 | 
            +
             | 
def process_subtitles(
    deep_copied_result,
    align_language,
    result_diarize,
    output_format_subtitle,
    TRANSLATE_AUDIO_TO,
):
    name_ori = "sub_ori."
    name_tra = "sub_tra."
    remove_files(
        [name_ori + output_format_subtitle, name_tra + output_format_subtitle]
    )

    writer = get_writer(output_format_subtitle, output_dir=".")
    word_options = {
        "highlight_words": False,
        "max_line_count": None,
        "max_line_width": None,
    }

    # original language
    subs_copy_result = copy.deepcopy(deep_copied_result)
    subs_copy_result["language"] = (
        "zh" if align_language == "zh-TW" else align_language
    )
    for segment in subs_copy_result["segments"]:
        segment.pop("speaker", None)

    try:
        writer(
            subs_copy_result,
            name_ori[:-1] + ".mp3",
            word_options,
        )
    except Exception as error:
        logger.error(str(error))
        if str(error) == "list indices must be integers or slices, not str":
            logger.error(
                "Related to poor word segmentation"
                " in segments after alignment."
            )
        # Drop the malformed word data from the first segment and retry
        subs_copy_result["segments"][0].pop("words")
        writer(
            subs_copy_result,
            name_ori[:-1] + ".mp3",
            word_options,
        )

    # translated language
    subs_tra_copy_result = copy.deepcopy(result_diarize)
    subs_tra_copy_result["language"] = (
        "ja" if TRANSLATE_AUDIO_TO in ["ja", "zh", "zh-TW"] else align_language
    )
    subs_tra_copy_result.pop("word_segments", None)
    for segment in subs_tra_copy_result["segments"]:
        for key in ["speaker", "chars", "words"]:
            segment.pop(key, None)

    writer(
        subs_tra_copy_result,
        name_tra[:-1] + ".mp3",
        word_options,
    )

    return name_tra + output_format_subtitle

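# [Illustrative note, not part of the committed file] Judging by the calls
# above, the writer returned by get_writer derives its output filename from
# the stem of the path it is given, which is why "sub_ori.mp3" is passed to
# an "srt" writer to obtain "./sub_ori.srt". A minimal sketch, assuming a
# Whisper-style result dict:
#
#     writer = get_writer("srt", output_dir=".")
#     result = {"language": "en", "segments": [
#         {"start": 0.0, "end": 1.5, "text": "Hello world."},
#     ]}
#     writer(result, "sub_ori.mp3", {"highlight_words": False,
#                                    "max_line_count": None,
#                                    "max_line_width": None})
#     # -> writes ./sub_ori.srt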
def linguistic_level_segments(
    result_base,
    linguistic_unit="word",  # word or char
):
    linguistic_unit = linguistic_unit[:4]
    linguistic_unit_key = linguistic_unit + "s"
    result = copy.deepcopy(result_base)

    if linguistic_unit_key not in result["segments"][0].keys():
        raise ValueError("No alignment detected, can't process")

    segments_by_unit = []
    for segment in result["segments"]:
        segment_units = segment[linguistic_unit_key]
        # segment_speaker = segment.get("speaker", "SPEAKER_00")

        for unit in segment_units:

            text = unit[linguistic_unit]

            if "start" in unit.keys():
                segments_by_unit.append(
                    {
                        "start": unit["start"],
                        "end": unit["end"],
                        "text": text,
                        # "speaker": segment_speaker,
                    }
                )
            elif not segments_by_unit:
                pass
            else:
                segments_by_unit[-1]["text"] += text

    return {"segments": segments_by_unit}

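# [Illustrative example, not part of the committed file] Given an aligned
# result, linguistic_level_segments flattens it into one segment per word
# (or char); units without timestamps are merged into the previous unit:
#
#     aligned = {"segments": [{
#         "text": "hi there",
#         "words": [
#             {"word": "hi", "start": 0.0, "end": 0.4},
#             {"word": " there", "start": 0.5, "end": 0.9},
#         ],
#     }]}
#     linguistic_level_segments(aligned, "word")
#     # -> {"segments": [{"start": 0.0, "end": 0.4, "text": "hi"},
#     #                  {"start": 0.5, "end": 0.9, "text": " there"}]}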
def break_aling_segments(
    result: dict,
    break_characters: str = "",  # ":|,|.|"
):
    result_align = copy.deepcopy(result)

    break_characters_list = break_characters.split("|")
    break_characters_list = [i for i in break_characters_list if i != '']

    if not break_characters_list:
        logger.info("No valid break characters were specified.")
        return result

    logger.info(f"Redivide text segments by: {str(break_characters_list)}")

    # Create the new segments, applying the filters below
    normal = []

    def process_chars(chars, letter_new_start, num, text):
        start_key, end_key = "start", "end"
        start_value = end_value = None

        for char in chars:
            if start_key in char:
                start_value = char[start_key]
                break

        for char in reversed(chars):
            if end_key in char:
                end_value = char[end_key]
                break

        if not start_value or not end_value:
            raise Exception(
                f"Unable to obtain a valid timestamp for chars: {str(chars)}"
            )

        return {
            "start": start_value,
            "end": end_value,
            "text": text,
            "words": chars,
        }

    for i, segment in enumerate(result_align['segments']):

        logger.debug(f"- Process segment: {i}, text: {segment['text']}")
        # start = segment['start']
        letter_new_start = 0
        for num, char in enumerate(segment['chars']):

            if char["char"] is None:
                continue

            # if "start" in char:
            #     start = char["start"]

            # if "end" in char:
            #     end = char["end"]

            # Break by character
            if char['char'] in break_characters_list:

                text = segment['text'][letter_new_start:num+1]

                logger.debug(
                    f"Break in: {char['char']}, position: {num}, text: {text}"
                )

                chars = segment['chars'][letter_new_start:num+1]

                if not text:
                    logger.debug("No text")
                    continue

                if num == 0 and not text.strip():
                    logger.debug("blank space in start")
                    continue

                if len(text) == 1:
                    logger.debug(f"Short char append, num: {num}")
                    normal[-1]["text"] += text
                    # extend, not append, so "words" stays a flat list of char dicts
                    normal[-1]["words"].extend(chars)
                    continue

                # logger.debug(chars)
                normal_dict = process_chars(chars, letter_new_start, num, text)

                letter_new_start = num + 1

                normal.append(normal_dict)

            # If we reach the end of the segment, add the last run of chars.
            if num == len(segment["chars"]) - 1:

                text = segment['text'][letter_new_start:num+1]
                # Slice the chars here, before they are used below
                chars = segment['chars'][letter_new_start:num+1]

                # If the remaining text length is not the expected length
                if num not in [len(text)-1, len(text)] and text:
                    logger.debug(f'Remaining text: {text}')

                if not text:
                    logger.debug("No remaining text.")
                    continue

                if len(text) == 1:
                    logger.debug(f"Short char append, num: {num}")
                    normal[-1]["text"] += text
                    normal[-1]["words"].extend(chars)
                    continue

                normal_dict = process_chars(chars, letter_new_start, num, text)

                letter_new_start = num + 1

                normal.append(normal_dict)

    # Rename 'char' keys to 'word'
    for item in normal:
        words_list = item['words']
        for word_item in words_list:
            if 'char' in word_item:
                word_item['word'] = word_item.pop('char')

    # Convert to the default dict format
    break_segments = {"segments": normal}

    msg_count = (
        f"Segment count before: {len(result['segments'])}, "
        f"after: {len(break_segments['segments'])}."
    )
    logger.info(msg_count)

    return break_segments

soni_translate/text_to_speech.py
ADDED

@@ -0,0 +1,1574 @@
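# [Illustrative note, not part of the committed file] The segment dicts this
# module consumes carry the keys read throughout the functions below, e.g.:
#
#     segment = {
#         "text": "Hello world.",
#         "start": 0.0,
#         "end": 1.5,
#         "speaker": "SPEAKER_00",
#         "tts_name": "en-US-AriaNeural-Female",
#     }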
from gtts import gTTS
import edge_tts, asyncio, json, glob  # noqa
from tqdm import tqdm
import librosa, os, re, torch, gc, subprocess  # noqa
from .language_configuration import (
    fix_code_language,
    BARK_VOICES_LIST,
    VITS_VOICES_LIST,
)
from .utils import (
    download_manager,
    create_directories,
    copy_files,
    rename_file,
    remove_directory_contents,
    remove_files,
    run_command,
)
import numpy as np
from typing import Any, Dict
from pathlib import Path
import soundfile as sf
import platform
import logging
import traceback
from .logging_setup import logger


class TTS_OperationError(Exception):
    def __init__(self, message="The operation did not complete successfully."):
        self.message = message
        super().__init__(self.message)


def verify_saved_file_and_size(filename):
    if not os.path.exists(filename):
        raise TTS_OperationError(f"File '{filename}' was not saved.")
    if os.path.getsize(filename) == 0:
        raise TTS_OperationError(
            f"File '{filename}' has a zero size. "
            "Related to an incorrect TTS for the target language"
        )


def error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename):
    traceback.print_exc()
    logger.error(f"Error: {str(error)}")
    try:
        from tempfile import TemporaryFile

        tts = gTTS(segment["text"], lang=fix_code_language(TRANSLATE_AUDIO_TO))
        # tts.save(filename)
        f = TemporaryFile()
        tts.write_to_fp(f)

        # Reset the file pointer to the beginning of the file
        f.seek(0)

        # Read audio data from the TemporaryFile using soundfile
        audio_data, samplerate = sf.read(f)
        f.close()  # Close the TemporaryFile
        sf.write(
            filename, audio_data, samplerate, format="ogg", subtype="vorbis"
        )

        logger.warning(
            'TTS auxiliary will be utilized '
            f'rather than TTS: {segment["tts_name"]}'
        )
        verify_saved_file_and_size(filename)
    except Exception as error:
        logger.critical(f"Error: {str(error)}")
        sample_rate_aux = 22050
        duration = float(segment["end"]) - float(segment["start"])
        data = np.zeros(int(sample_rate_aux * duration)).astype(np.float32)
        sf.write(
            filename, data, sample_rate_aux, format="ogg", subtype="vorbis"
        )
        logger.error("Audio will be replaced -> [silent audio].")
        verify_saved_file_and_size(filename)


def pad_array(array, sr):
    # Despite the name, this trims the audio to its non-silent region,
    # keeping ~0.1 s of margin before and after it.

    if isinstance(array, list):
        array = np.array(array)

    if not array.shape[0]:
        raise ValueError("The generated audio does not contain any data")

    valid_indices = np.where(np.abs(array) > 0.001)[0]

    if len(valid_indices) == 0:
        logger.debug(f"No valid indices: {array}")
        return array

    try:
        pad_indice = int(0.1 * sr)
        start_pad = max(0, valid_indices[0] - pad_indice)
        end_pad = min(len(array), valid_indices[-1] + 1 + pad_indice)
        padded_array = array[start_pad:end_pad]
        return padded_array
    except Exception as error:
        logger.error(str(error))
        return array

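# [Illustrative check, not part of the committed file] pad_array keeps a
# ~0.1 s margin around the first and last samples above the 0.001 amplitude
# threshold, trimming the silence a TTS engine may add at either end:
#
#     sr = 16000
#     silence = np.zeros(sr)                    # 1 s of silence
#     tone = 0.5 * np.ones(sr)                  # 1 s of "speech"
#     audio = np.concatenate([silence, tone, silence])
#     trimmed = pad_array(audio, sr)
#     len(trimmed) / sr                         # ~1.2 s: tone + 0.1 s each side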
# =====================================
# EDGE TTS
# =====================================


def edge_tts_voices_list():
    try:
        completed_process = subprocess.run(
            ["edge-tts", "--list-voices"], capture_output=True, text=True
        )
        lines = completed_process.stdout.strip().split("\n")
    except Exception as error:
        logger.debug(str(error))
        lines = []

    voices = []
    for line in lines:
        if line.startswith("Name: "):
            voice_entry = {}
            voice_entry["Name"] = line.split(": ")[1]
        elif line.startswith("Gender: "):
            voice_entry["Gender"] = line.split(": ")[1]
            voices.append(voice_entry)

    formatted_voices = [
        f"{entry['Name']}-{entry['Gender']}" for entry in voices
    ]

    if not formatted_voices:
        logger.warning(
            "The list of Edge TTS voices could not be obtained, "
            "switching to an alternative method"
        )
        tts_voice_list = asyncio.new_event_loop().run_until_complete(
            edge_tts.list_voices()
        )
        formatted_voices = sorted(
            [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
        )

    if not formatted_voices:
        logger.error("Can't get the EDGE TTS voice list")

    return formatted_voices


def segments_egde_tts(filtered_edge_segments, TRANSLATE_AUDIO_TO, is_gui):
    for segment in tqdm(filtered_edge_segments["segments"]):
        speaker = segment["speaker"]  # noqa
        text = segment["text"]
        start = segment["start"]
        tts_name = segment["tts_name"]

        # make the tts audio
        filename = f"audio/{start}.ogg"
        temp_file = filename[:-3] + "mp3"

        logger.info(f"{text} >> {filename}")
        try:
            if is_gui:
                asyncio.run(
                    edge_tts.Communicate(
                        text, "-".join(tts_name.split("-")[:-1])
                    ).save(temp_file)
                )
            else:
                # nest_asyncio.apply() if not is_gui else None
                command = f'edge-tts -t "{text}" -v "{tts_name.replace("-Male", "").replace("-Female", "")}" --write-media "{temp_file}"'
                run_command(command)
            verify_saved_file_and_size(temp_file)

            data, sample_rate = sf.read(temp_file)
            data = pad_array(data, sample_rate)
            # os.remove(temp_file)

            # Save file
            sf.write(
                file=filename,
                samplerate=sample_rate,
                data=data,
                format="ogg",
                subtype="vorbis",
            )
            verify_saved_file_and_size(filename)

        except Exception as error:
            error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)

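# [Illustrative note, not part of the committed file] The voice names carry
# the "-Male"/"-Female" suffix added by edge_tts_voices_list, which both
# branches above strip before calling edge-tts, e.g.:
#
#     tts_name = "en-US-AriaNeural-Female"
#     "-".join(tts_name.split("-")[:-1])    # -> "en-US-AriaNeural"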
# =====================================
# BARK TTS
# =====================================


def segments_bark_tts(
    filtered_bark_segments, TRANSLATE_AUDIO_TO, model_id_bark="suno/bark-small"
):
    from transformers import AutoProcessor, BarkModel
    from optimum.bettertransformer import BetterTransformer

    device = os.environ.get("SONITR_DEVICE")
    torch_dtype_env = torch.float16 if device == "cuda" else torch.float32

    # load the bark model
    model = BarkModel.from_pretrained(
        model_id_bark, torch_dtype=torch_dtype_env
    ).to(device)
    processor = AutoProcessor.from_pretrained(
        model_id_bark, return_tensors="pt"
    )  # , padding=True
    if device == "cuda":
        # convert to bettertransformer
        model = BetterTransformer.transform(model, keep_original_model=False)
        # enable CPU offload
        # model.enable_cpu_offload()
    sampling_rate = model.generation_config.sample_rate

    # filtered_segments = filtered_bark_segments['segments']
    # Sorting the segments by 'tts_name'
    # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
    # logger.debug(sorted_segments)

    for segment in tqdm(filtered_bark_segments["segments"]):
        speaker = segment["speaker"]  # noqa
        text = segment["text"]
        start = segment["start"]
        tts_name = segment["tts_name"]

        inputs = processor(text, voice_preset=BARK_VOICES_LIST[tts_name]).to(
            device
        )

        # make the tts audio
        filename = f"audio/{start}.ogg"
        logger.info(f"{text} >> {filename}")
        try:
            # Infer
            with torch.inference_mode():
                speech_output = model.generate(
                    **inputs,
                    do_sample=True,
                    fine_temperature=0.4,
                    coarse_temperature=0.8,
                    pad_token_id=processor.tokenizer.pad_token_id,
                )
            # Save file
            data_tts = pad_array(
                speech_output.cpu().numpy().squeeze().astype(np.float32),
                sampling_rate,
            )
            sf.write(
                file=filename,
                samplerate=sampling_rate,
                data=data_tts,
                format="ogg",
                subtype="vorbis",
            )
            verify_saved_file_and_size(filename)
        except Exception as error:
            error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
        gc.collect()
        torch.cuda.empty_cache()
    try:
        del processor
        del model
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as error:
        logger.error(str(error))
        gc.collect()
        torch.cuda.empty_cache()


# =====================================
# VITS TTS
# =====================================


def uromanize(input_string):
    """Convert non-Roman strings to Roman using the `uroman` perl package."""
    # script_path = os.path.join(uroman_path, "bin", "uroman.pl")

    if not os.path.exists("./uroman"):
        logger.info(
            "Cloning the uroman repository https://github.com/isi-nlp/uroman.git"
            " to romanize the text"
        )
        process = subprocess.Popen(
            ["git", "clone", "https://github.com/isi-nlp/uroman.git"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = process.communicate()
    script_path = os.path.join("./uroman", "bin", "uroman.pl")

    command = ["perl", script_path]

    process = subprocess.Popen(
        command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # Execute the perl command
    stdout, stderr = process.communicate(input=input_string.encode())

    if process.returncode != 0:
        raise ValueError(f"Error {process.returncode}: {stderr.decode()}")

    # Return the output as a string and skip the new-line character at the end
    return stdout.decode()[:-1]

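# [Illustrative usage, not part of the committed file] uromanize shells out
# to the uroman Perl script, so `perl` (and `git`, on the first run) must be
# on PATH; the output is a plain romanization, roughly:
#
#     uromanize("Привет")   # -> "Privet"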
| 322 | 
            +
            def segments_vits_tts(filtered_vits_segments, TRANSLATE_AUDIO_TO):
         | 
| 323 | 
            +
                from transformers import VitsModel, AutoTokenizer
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                filtered_segments = filtered_vits_segments["segments"]
         | 
| 326 | 
            +
                # Sorting the segments by 'tts_name'
         | 
| 327 | 
            +
                sorted_segments = sorted(filtered_segments, key=lambda x: x["tts_name"])
         | 
| 328 | 
            +
                logger.debug(sorted_segments)
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                model_name_key = None
         | 
| 331 | 
            +
    for segment in tqdm(sorted_segments):
        speaker = segment["speaker"]  # noqa
        text = segment["text"]
        start = segment["start"]
        tts_name = segment["tts_name"]

        # Reload the model and tokenizer only when the voice changes
        # between segments.
        if tts_name != model_name_key:
            model_name_key = tts_name
            model = VitsModel.from_pretrained(VITS_VOICES_LIST[tts_name])
            tokenizer = AutoTokenizer.from_pretrained(
                VITS_VOICES_LIST[tts_name]
            )
            sampling_rate = model.config.sampling_rate

        # Checkpoints for non-Latin scripts expect romanized input.
        if tokenizer.is_uroman:
            romanize_text = uromanize(text)
            logger.debug(f"Romanized text: {romanize_text}")
            inputs = tokenizer(romanize_text, return_tensors="pt")
        else:
            inputs = tokenizer(text, return_tensors="pt")

        # Make the TTS audio
        filename = f"audio/{start}.ogg"
        logger.info(f"{text} >> {filename}")
        try:
            # Infer
            with torch.no_grad():
                speech_output = model(**inputs).waveform

            data_tts = pad_array(
                speech_output.cpu().numpy().squeeze().astype(np.float32),
                sampling_rate,
            )
            # Save file
            sf.write(
                file=filename,
                samplerate=sampling_rate,
                data=data_tts,
                format="ogg",
                subtype="vorbis",
            )
            verify_saved_file_and_size(filename)
        except Exception as error:
            error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
        gc.collect()
        torch.cuda.empty_cache()
    try:
        del tokenizer
        del model
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as error:
        logger.error(str(error))
        gc.collect()
        torch.cuda.empty_cache()


# =====================================
# Coqui XTTS
# =====================================

def coqui_xtts_voices_list():
    main_folder = "_XTTS_"
    pattern_coqui = re.compile(r".+\.(wav|mp3|ogg|m4a)$")
    pattern_automatic_speaker = re.compile(r"AUTOMATIC_SPEAKER_\d+\.wav$")

    # List only files in the directory matching the pattern but not matching
    # AUTOMATIC_SPEAKER_00.wav, AUTOMATIC_SPEAKER_01.wav, etc.
    wav_voices = [
        "_XTTS_/" + f
        for f in os.listdir(main_folder)
        if os.path.isfile(os.path.join(main_folder, f))
        and pattern_coqui.match(f)
        and not pattern_automatic_speaker.match(f)
    ]

    return ["_XTTS_/AUTOMATIC.wav"] + wav_voices
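
# Return-value sketch: the automatic-cloning entry always comes first,
# followed by any user-provided reference audio found in "_XTTS_"
# (the file name below is hypothetical):
# coqui_xtts_voices_list()  ->  ["_XTTS_/AUTOMATIC.wav", "_XTTS_/my_voice.wav"]
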

def seconds_to_hhmmss_ms(seconds):
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    seconds = seconds % 60
    milliseconds = int((seconds - int(seconds)) * 1000)
    return "%02d:%02d:%02d.%03d" % (hours, minutes, int(seconds), milliseconds)
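
# Usage sketch: format a float timestamp for the ffmpeg -ss/-to flags used
# by audio_trimming below (values are illustrative).
# seconds_to_hhmmss_ms(3661.5)  ->  "01:01:01.500"
# seconds_to_hhmmss_ms(75)      ->  "00:01:15.000"
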

def audio_trimming(audio_path, destination, start, end):
    if isinstance(start, (int, float)):
        start = seconds_to_hhmmss_ms(start)
    if isinstance(end, (int, float)):
        end = seconds_to_hhmmss_ms(end)

    if destination:
        file_directory = destination
    else:
        file_directory = os.path.dirname(audio_path)

    file_name = os.path.splitext(os.path.basename(audio_path))[0]
    file_ = f"{file_name}_trim.wav"
    # file_ = f'{os.path.splitext(audio_path)[0]}_trim.wav'
    output_path = os.path.join(file_directory, file_)

    # -t (duration from -ss) | -to (time stop) | -af silenceremove=1:0:-50dB (remove silence)
    command = f'ffmpeg -y -loglevel error -i "{audio_path}" -ss {start} -to {end} -acodec pcm_s16le -f wav "{output_path}"'
    run_command(command)

    return output_path
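
# Usage sketch (hypothetical path): cut seconds 5-12 out of a wav; numeric
# bounds are converted to HH:MM:SS.mmm before being handed to ffmpeg.
# audio_trimming("voice.wav", "", 5, 12)  ->  "voice_trim.wav"
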

def convert_to_xtts_good_sample(audio_path: str = "", destination: str = ""):
    if destination:
        file_directory = destination
    else:
        file_directory = os.path.dirname(audio_path)

    file_name = os.path.splitext(os.path.basename(audio_path))[0]
    file_ = f"{file_name}_good_sample.wav"
    # file_ = f'{os.path.splitext(audio_path)[0]}_good_sample.wav'
    mono_path = os.path.join(file_directory, file_)  # get root

    # Downmix to mono 22.05 kHz 16-bit WAV, the shape used here for an
    # XTTS reference sample.
    command = f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 1 -ar 22050 -sample_fmt s16 -f wav "{mono_path}"'
    run_command(command)

    return mono_path

def sanitize_file_name(file_name):
    import unicodedata

    # Normalize the string to NFKD form to separate combined characters into
    # base characters and diacritics
    normalized_name = unicodedata.normalize("NFKD", file_name)
    # Replace any non-ASCII characters or special symbols with an underscore
    sanitized_name = re.sub(r"[^\w\s.-]", "_", normalized_name)
    return sanitized_name
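
# Usage sketch: anything outside word characters, whitespace, dots, and
# hyphens becomes an underscore (diacritics are decomposed first by NFKD).
# sanitize_file_name("clip#01?.wav")  ->  "clip_01_.wav"
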

def create_wav_file_vc(
    sample_name="",  # name of the final file
    audio_wav="",  # path
    start=None,  # trim start
    end=None,  # trim end
    output_final_path="_XTTS_",
    get_vocals_dereverb=True,
):
    sample_name = sample_name if sample_name else "default_name"
    sample_name = sanitize_file_name(sample_name)
    audio_wav = audio_wav if isinstance(audio_wav, str) else audio_wav.name

    BASE_DIR = (
        "."  # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )

    output_dir = os.path.join(BASE_DIR, "clean_song_output")  # remove content
    # remove_directory_contents(output_dir)

    if start or end:
        # Cut file
        audio_segment = audio_trimming(audio_wav, output_dir, start, end)
    else:
        # Complete file
        audio_segment = audio_wav

    from .mdx_net import process_uvr_task

    try:
        # Isolate the vocals (and optionally dereverb) for a cleaner sample.
        _, _, _, _, audio_segment = process_uvr_task(
            orig_song_path=audio_segment,
            main_vocals=True,
            dereverb=get_vocals_dereverb,
        )
    except Exception as error:
        logger.error(str(error))

    sample = convert_to_xtts_good_sample(audio_segment)

    sample_name = f"{sample_name}.wav"
    sample_rename = rename_file(sample, sample_name)

    copy_files(sample_rename, output_final_path)

    final_sample = os.path.join(output_final_path, sample_name)
    if os.path.exists(final_sample):
        logger.info(final_sample)
        return final_sample
    else:
        raise Exception(f"Error wav: {final_sample}")
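
# Usage sketch (hypothetical values): build a clean mono reference sample
# for XTTS voice cloning from seconds 10-18 of the working audio; on
# success the reference lands in the "_XTTS_" folder.
# create_wav_file_vc(
#     sample_name="narrator",
#     audio_wav="audio.wav",
#     start=10.0,
#     end=18.0,
# )  # -> "_XTTS_/narrator.wav"
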

def create_new_files_for_vc(
    speakers_coqui,
    segments_base,
    dereverb_automatic=True,
):
    # Previous automatic samples are deleted beforehand via
    # delete_previous_automatic.
    output_dir = os.path.join(".", "clean_song_output")  # remove content
    remove_directory_contents(output_dir)

    for speaker in speakers_coqui:
        filtered_speaker = [
            segment
            for segment in segments_base
            if segment["speaker"] == speaker
        ]
        if len(filtered_speaker) > 4:
            filtered_speaker = filtered_speaker[1:]
        if filtered_speaker[0]["tts_name"] == "_XTTS_/AUTOMATIC.wav":
            name_automatic_wav = f"AUTOMATIC_{speaker}"
            if os.path.exists(f"_XTTS_/{name_automatic_wav}.wav"):
                logger.info(f"WAV automatic {speaker} exists")
                # path_wav = path_automatic_wav
                pass
            else:
                # Create the wav: prefer a segment between 7 and 12 seconds
                # long, trimmed by one second on each side.
                wav_ok = False
                for seg in filtered_speaker:
                    duration = float(seg["end"]) - float(seg["start"])
                    if duration > 7.0 and duration < 12.0:
                        logger.info(
                            f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {duration}, {seg["text"]}'
                        )
                        create_wav_file_vc(
                            sample_name=name_automatic_wav,
                            audio_wav="audio.wav",
                            start=(float(seg["start"]) + 1.0),
                            end=(float(seg["end"]) - 1.0),
                            get_vocals_dereverb=dereverb_automatic,
                        )
                        wav_ok = True
                        break

                if not wav_ok:
                    logger.info("Taking the first segment")
                    seg = filtered_speaker[0]
                    logger.info(
                        f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {seg["text"]}'
                    )
                    max_duration = float(seg["end"]) - float(seg["start"])
                    max_duration = max(2.0, min(max_duration, 9.0))

                    create_wav_file_vc(
                        sample_name=name_automatic_wav,
                        audio_wav="audio.wav",
                        start=(float(seg["start"])),
                        end=(float(seg["start"]) + max_duration),
                        get_vocals_dereverb=dereverb_automatic,
                    )

def segments_coqui_tts(
    filtered_coqui_segments,
    TRANSLATE_AUDIO_TO,
    model_id_coqui="tts_models/multilingual/multi-dataset/xtts_v2",
    speakers_coqui=None,
    delete_previous_automatic=True,
    dereverb_automatic=True,
    emotion=None,
):
    """XTTS
    Install:
    pip install -q TTS==0.21.1
    pip install -q numpy==1.23.5

    Notes:
    - tts_name is the wav|mp3|ogg|m4a file used as the voice reference for VC
    """
    from TTS.api import TTS

    TRANSLATE_AUDIO_TO = fix_code_language(TRANSLATE_AUDIO_TO, syntax="coqui")
    supported_lang_coqui = [
        "zh-cn",
        "en",
        "fr",
        "de",
        "it",
        "pt",
        "pl",
        "tr",
        "ru",
        "nl",
        "cs",
        "ar",
        "es",
        "hu",
        "ko",
        "ja",
    ]
    if TRANSLATE_AUDIO_TO not in supported_lang_coqui:
        raise TTS_OperationError(
            f"'{TRANSLATE_AUDIO_TO}' is not a supported language for Coqui XTTS"
        )
    # Emotion and speed can only be used with Coqui Studio models (discontinued).
    # emotions = ["Neutral", "Happy", "Sad", "Angry", "Dull"]

    if delete_previous_automatic:
        for spk in speakers_coqui:
            remove_files(f"_XTTS_/AUTOMATIC_{spk}.wav")

    directory_audios_vc = "_XTTS_"
    create_directories(directory_audios_vc)
    create_new_files_for_vc(
        speakers_coqui,
        filtered_coqui_segments["segments"],
        dereverb_automatic,
    )

    # Init TTS
    device = os.environ.get("SONITR_DEVICE")
    model = TTS(model_id_coqui).to(device)
    sampling_rate = 24000

    # filtered_segments = filtered_coqui_segments['segments']
    # Sorting the segments by 'tts_name'
    # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
    # logger.debug(sorted_segments)

    for segment in tqdm(filtered_coqui_segments["segments"]):
        speaker = segment["speaker"]
        text = segment["text"]
        start = segment["start"]
        tts_name = segment["tts_name"]
        if tts_name == "_XTTS_/AUTOMATIC.wav":
            tts_name = f"_XTTS_/AUTOMATIC_{speaker}.wav"

        # Make the TTS audio
        filename = f"audio/{start}.ogg"
        logger.info(f"{text} >> {filename}")
        try:
            # Infer
            wav = model.tts(
                text=text, speaker_wav=tts_name, language=TRANSLATE_AUDIO_TO
            )
            data_tts = pad_array(
                wav,
                sampling_rate,
            )
            # Save file
            sf.write(
                file=filename,
                samplerate=sampling_rate,
                data=data_tts,
                format="ogg",
                subtype="vorbis",
            )
            verify_saved_file_and_size(filename)
        except Exception as error:
            error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
        gc.collect()
        torch.cuda.empty_cache()
    try:
        del model
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as error:
        logger.error(str(error))
        gc.collect()
        torch.cuda.empty_cache()


# =====================================
# PIPER TTS
# =====================================

def piper_tts_voices_list():
    file_path = download_manager(
        url="https://huggingface.co/rhasspy/piper-voices/resolve/main/voices.json",
        path="./PIPER_MODELS",
    )

    with open(file_path, "r", encoding="utf8") as file:
        data = json.load(file)
    piper_id_models = [key + " VITS-onnx" for key in data.keys()]

    return piper_id_models
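
# Return-value sketch: the keys of voices.json are Piper voice ids, so the
# entries come back tagged with the suffix used for method routing, e.g.
# ["en_US-lessac-medium VITS-onnx", "es_ES-davefx-medium VITS-onnx", ...]
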

def replace_text_in_json(file_path, key_to_replace, new_text, condition=None):
    # Read the JSON file
    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    # Modify the specified key's value with the new text, optionally only
    # when the current value matches the given condition
    if key_to_replace in data:
        if condition:
            value_condition = condition
        else:
            value_condition = data[key_to_replace]

        if data[key_to_replace] == value_condition:
            data[key_to_replace] = new_text

    # Write the modified data back to the file with indentation for
    # readability
    with open(file_path, "w") as file:
        json.dump(data, file, indent=2)
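
# Usage sketch: conditionally patch a single key. Given a config containing
# {"phoneme_type": "PhonemeType.ESPEAK"}, the call below (the same one made
# in load_piper_model) rewrites the value to "espeak"; when the stored value
# does not match the condition, the file is rewritten unchanged.
# replace_text_in_json(config, "phoneme_type", "espeak", "PhonemeType.ESPEAK")
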

def load_piper_model(
    model: str,
    data_dir: list,
    download_dir: str = "",
    update_voices: bool = False,
):
    from piper import PiperVoice
    from piper.download import ensure_voice_exists, find_voice, get_voices

    try:
        import onnxruntime as rt

        if rt.get_device() == "GPU" and os.environ.get("SONITR_DEVICE") == "cuda":
            logger.debug("onnxruntime device > GPU")
            cuda = True
        else:
            logger.info(
                "onnxruntime device > CPU"
            )  # try pip install onnxruntime-gpu
            cuda = False
    except Exception as error:
        raise TTS_OperationError(f"onnxruntime error: {str(error)}")

    # Disable CUDA on Windows
    if platform.system() == "Windows":
        logger.info("Using CPU only with Piper TTS on Windows")
        cuda = False

    if not download_dir:
        # Download to the first data directory by default
        download_dir = data_dir[0]
    else:
        data_dir = [os.path.join(data_dir[0], download_dir)]

    # Default sidecar config path (<model>.json); assumed here so that
    # `config` is defined when the model file already exists on disk.
    config = f"{model}.json"

    # Download the voice if the file doesn't exist
    model_path = Path(model)
    if not model_path.exists():
        # Load voice info
        voices_info = get_voices(download_dir, update_voices=update_voices)

        # Resolve aliases for backwards compatibility with old voice names
        aliases_info: Dict[str, Any] = {}
        for voice_info in voices_info.values():
            for voice_alias in voice_info.get("aliases", []):
                aliases_info[voice_alias] = {"_is_alias": True, **voice_info}

        voices_info.update(aliases_info)
        ensure_voice_exists(model, data_dir, download_dir, voices_info)
        model, config = find_voice(model, data_dir)

        # Normalize a phoneme_type stored as the enum repr
        # "PhonemeType.ESPEAK" to the plain "espeak" string.
        replace_text_in_json(
            config, "phoneme_type", "espeak", "PhonemeType.ESPEAK"
        )

    # Load voice
    voice = PiperVoice.load(model, config_path=config, use_cuda=cuda)

    return voice
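
# Usage sketch (hypothetical voice id): download en_US-lessac-medium into
# ./PIPER_MODELS if it is missing, then load it for inference.
# voice = load_piper_model(
#     "en_US-lessac-medium", [str(Path.cwd())], "PIPER_MODELS", True
# )
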

def synthesize_text_to_audio_np_array(voice, text, synthesize_args):
    audio_stream = voice.synthesize_stream_raw(text, **synthesize_args)

    # Collect the audio bytes into a single NumPy array
    audio_data = b""
    for audio_bytes in audio_stream:
        audio_data += audio_bytes

    # Ensure correct data type and convert audio bytes to NumPy array
    audio_np = np.frombuffer(audio_data, dtype=np.int16)
    return audio_np
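
# Note: Piper streams 16-bit PCM, so the result is int16. If a float32
# waveform in [-1.0, 1.0] were needed downstream, it could be obtained with:
# audio_f32 = audio_np.astype(np.float32) / 32768.0
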

def segments_vits_onnx_tts(filtered_onnx_vits_segments, TRANSLATE_AUDIO_TO):
    """
    Install:
    pip install -q piper-tts==1.2.0 onnxruntime-gpu  # for cuda118
    """

    data_dir = [
        str(Path.cwd())
    ]  # data directory checked for downloaded models (default: current directory)
    download_dir = "PIPER_MODELS"
    # model_name = "en_US-lessac-medium"  # a tts_name, handled like the VITS voices
    update_voices = True  # download the latest voices.json during startup

    synthesize_args = {
        "speaker_id": None,
        "length_scale": 1.0,
        "noise_scale": 0.667,
        "noise_w": 0.8,
        "sentence_silence": 0.0,
    }

    filtered_segments = filtered_onnx_vits_segments["segments"]
    # Sort the segments by 'tts_name' so each model is loaded only once
    sorted_segments = sorted(filtered_segments, key=lambda x: x["tts_name"])
    logger.debug(sorted_segments)

    model_name_key = None
    for segment in tqdm(sorted_segments):
        speaker = segment["speaker"]  # noqa
        text = segment["text"]
        start = segment["start"]
        tts_name = segment["tts_name"].replace(" VITS-onnx", "")

        if tts_name != model_name_key:
            model_name_key = tts_name
            model = load_piper_model(
                tts_name, data_dir, download_dir, update_voices
            )
            sampling_rate = model.config.sample_rate

        # Make the TTS audio
        filename = f"audio/{start}.ogg"
        logger.info(f"{text} >> {filename}")
        try:
            # Infer
            speech_output = synthesize_text_to_audio_np_array(
                model, text, synthesize_args
            )
            data_tts = pad_array(
                speech_output,  # .cpu().numpy().squeeze().astype(np.float32),
                sampling_rate,
            )
            # Save file
            sf.write(
                file=filename,
                samplerate=sampling_rate,
                data=data_tts,
                format="ogg",
                subtype="vorbis",
            )
            verify_saved_file_and_size(filename)
        except Exception as error:
            error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
        gc.collect()
        torch.cuda.empty_cache()
    try:
        del model
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as error:
        logger.error(str(error))
        gc.collect()
        torch.cuda.empty_cache()


# =====================================
# OPENAI TTS
# =====================================

def segments_openai_tts(
    filtered_openai_tts_segments, TRANSLATE_AUDIO_TO
):
    from openai import OpenAI

    client = OpenAI()
    sampling_rate = 24000

    # filtered_segments = filtered_openai_tts_segments['segments']
    # Sorting the segments by 'tts_name'
    # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])

    for segment in tqdm(filtered_openai_tts_segments["segments"]):
        speaker = segment["speaker"]  # noqa
        text = segment["text"].strip()
        start = segment["start"]
        tts_name = segment["tts_name"]

        # Make the TTS audio
        filename = f"audio/{start}.ogg"
        logger.info(f"{text} >> {filename}")

        try:
            # Request: voice names carry a one-character prefix and the
            # " OpenAI-TTS" suffix, so keep only the bare voice id
            response = client.audio.speech.create(
                model="tts-1-hd" if "HD" in tts_name else "tts-1",
                voice=tts_name.split()[0][1:],
                response_format="wav",
                input=text,
            )

            audio_bytes = b''
            for data in response.iter_bytes(chunk_size=4096):
                audio_bytes += data

            speech_output = np.frombuffer(audio_bytes, dtype=np.int16)

            # Save file, skipping the initial values of the raw stream,
            # which include the WAV container header
            data_tts = pad_array(
                speech_output[240:],
                sampling_rate,
            )

            sf.write(
                file=filename,
                samplerate=sampling_rate,
                data=data_tts,
                format="ogg",
                subtype="vorbis",
            )
            verify_saved_file_and_size(filename)

        except Exception as error:
            error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)


# =====================================
# Select task TTS
# =====================================

def find_spkr(pattern, speaker_to_voice, segments):
    return [
        speaker
        for speaker, voice in speaker_to_voice.items()
        if pattern.match(voice) and any(
            segment["speaker"] == speaker for segment in segments
        )
    ]


def filter_by_speaker(speakers, segments):
    return {
        "segments": [
            segment
            for segment in segments
            if segment["speaker"] in speakers
        ]
    }
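
# Usage sketch with hypothetical data: route a speaker to a TTS method by
# matching its assigned voice name, then pull out that method's segments.
# voices = {"SPEAKER_00": "es-mms VITS", "SPEAKER_01": "en-US-Guy-Male"}
# segs = [{"speaker": "SPEAKER_00"}, {"speaker": "SPEAKER_01"}]
# find_spkr(re.compile(r".* VITS$"), voices, segs)  ->  ["SPEAKER_00"]
# filter_by_speaker(["SPEAKER_00"], segs)
#     ->  {"segments": [{"speaker": "SPEAKER_00"}]}
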

def audio_segmentation_to_voice(
    result_diarize,
    TRANSLATE_AUDIO_TO,
    is_gui,
    tts_voice00,
    tts_voice01="",
    tts_voice02="",
    tts_voice03="",
    tts_voice04="",
    tts_voice05="",
    tts_voice06="",
    tts_voice07="",
    tts_voice08="",
    tts_voice09="",
    tts_voice10="",
    tts_voice11="",
    dereverb_automatic=True,
    model_id_bark="suno/bark-small",
    model_id_coqui="tts_models/multilingual/multi-dataset/xtts_v2",
    delete_previous_automatic=True,
):

    remove_directory_contents("audio")

    # Mapping speakers to voice variables
    speaker_to_voice = {
        "SPEAKER_00": tts_voice00,
        "SPEAKER_01": tts_voice01,
        "SPEAKER_02": tts_voice02,
        "SPEAKER_03": tts_voice03,
        "SPEAKER_04": tts_voice04,
        "SPEAKER_05": tts_voice05,
        "SPEAKER_06": tts_voice06,
        "SPEAKER_07": tts_voice07,
        "SPEAKER_08": tts_voice08,
        "SPEAKER_09": tts_voice09,
        "SPEAKER_10": tts_voice10,
        "SPEAKER_11": tts_voice11,
    }

    # Assign 'SPEAKER_00' to segments without a 'speaker' key
    for segment in result_diarize["segments"]:
        if "speaker" not in segment:
            segment["speaker"] = "SPEAKER_00"
            logger.warning(
                "No speaker detected in segment: the first TTS voice will be"
                f" used for the segment at {segment['start'], segment['text']}"
            )
        # Assign the TTS name
        segment["tts_name"] = speaker_to_voice[segment["speaker"]]

    # Find the TTS method from the voice-name suffix
    pattern_edge = re.compile(r".*-(Male|Female)$")
    pattern_bark = re.compile(r".* BARK$")
    pattern_vits = re.compile(r".* VITS$")
    pattern_coqui = re.compile(r".+\.(wav|mp3|ogg|m4a)$")
    pattern_vits_onnx = re.compile(r".* VITS-onnx$")
    pattern_openai_tts = re.compile(r".* OpenAI-TTS$")

    all_segments = result_diarize["segments"]

    speakers_edge = find_spkr(pattern_edge, speaker_to_voice, all_segments)
    speakers_bark = find_spkr(pattern_bark, speaker_to_voice, all_segments)
    speakers_vits = find_spkr(pattern_vits, speaker_to_voice, all_segments)
    speakers_coqui = find_spkr(pattern_coqui, speaker_to_voice, all_segments)
    speakers_vits_onnx = find_spkr(
        pattern_vits_onnx, speaker_to_voice, all_segments
    )
    speakers_openai_tts = find_spkr(
        pattern_openai_tts, speaker_to_voice, all_segments
    )

    # Filter the segments belonging to each method
    filtered_edge = filter_by_speaker(speakers_edge, all_segments)
    filtered_bark = filter_by_speaker(speakers_bark, all_segments)
    filtered_vits = filter_by_speaker(speakers_vits, all_segments)
    filtered_coqui = filter_by_speaker(speakers_coqui, all_segments)
    filtered_vits_onnx = filter_by_speaker(speakers_vits_onnx, all_segments)
    filtered_openai_tts = filter_by_speaker(speakers_openai_tts, all_segments)

    # Infer
    if filtered_edge["segments"]:
        logger.info(f"EDGE TTS: {speakers_edge}")
        segments_egde_tts(filtered_edge, TRANSLATE_AUDIO_TO, is_gui)  # mp3
    if filtered_bark["segments"]:
        logger.info(f"BARK TTS: {speakers_bark}")
        segments_bark_tts(
            filtered_bark, TRANSLATE_AUDIO_TO, model_id_bark
        )  # wav
    if filtered_vits["segments"]:
        logger.info(f"VITS TTS: {speakers_vits}")
        segments_vits_tts(filtered_vits, TRANSLATE_AUDIO_TO)  # wav
    if filtered_coqui["segments"]:
        logger.info(f"Coqui TTS: {speakers_coqui}")
        segments_coqui_tts(
            filtered_coqui,
            TRANSLATE_AUDIO_TO,
            model_id_coqui,
            speakers_coqui,
            delete_previous_automatic,
            dereverb_automatic,
        )  # wav
    if filtered_vits_onnx["segments"]:
        logger.info(f"PIPER TTS: {speakers_vits_onnx}")
        segments_vits_onnx_tts(filtered_vits_onnx, TRANSLATE_AUDIO_TO)  # wav
    if filtered_openai_tts["segments"]:
        logger.info(f"OpenAI TTS: {speakers_openai_tts}")
        segments_openai_tts(filtered_openai_tts, TRANSLATE_AUDIO_TO)  # wav

    # Drop the helper key before returning the segments to the caller
    for result in result_diarize["segments"]:
        result.pop("tts_name", None)
    return [
        speakers_edge,
        speakers_bark,
        speakers_vits,
        speakers_coqui,
        speakers_vits_onnx,
        speakers_openai_tts,
    ]
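
# Dispatch sketch (hypothetical voices): the suffix of each voice string
# selects the backend: "-Male"/"-Female" -> Edge, " BARK" -> Bark,
# " VITS" -> MMS VITS, an audio file path -> XTTS cloning,
# " VITS-onnx" -> Piper, " OpenAI-TTS" -> OpenAI.
# valid_speakers = audio_segmentation_to_voice(
#     result_diarize,
#     TRANSLATE_AUDIO_TO="en",
#     is_gui=False,
#     tts_voice00="en-US-AriaNeural-Female",
#     tts_voice01="_XTTS_/AUTOMATIC.wav",
# )
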

def accelerate_segments(
    result_diarize,
    max_accelerate_audio,
    valid_speakers,
    acceleration_rate_regulation=False,
    folder_output="audio2",
):
    logger.info("Apply acceleration")

    (
        speakers_edge,
        speakers_bark,
        speakers_vits,
        speakers_coqui,
        speakers_vits_onnx,
        speakers_openai_tts,
    ) = valid_speakers

    create_directories(f"{folder_output}/audio/")
    remove_directory_contents(f"{folder_output}/audio/")

    audio_files = []
    speakers_list = []

    max_count_segments_idx = len(result_diarize["segments"]) - 1

    for i, segment in tqdm(enumerate(result_diarize["segments"])):
        text = segment["text"]  # noqa
        start = segment["start"]
        end = segment["end"]
        speaker = segment["speaker"]

        # Find the audio file name
        # if speaker in speakers_edge:
        filename = f"audio/{start}.ogg"
        # elif speaker in speakers_bark + speakers_vits + speakers_coqui + speakers_vits_onnx:
        #    filename = f"audio/{start}.wav"  # wav

        # Durations
        duration_true = end - start
        duration_tts = librosa.get_duration(filename=filename)

        # Acceleration percentage: how much faster the TTS clip must play
        # to fit its original time slot
        acc_percentage = duration_tts / duration_true

        # Smooth the rate by borrowing silence before the next segment
        if acceleration_rate_regulation and acc_percentage >= 1.3:
            try:
                next_segment = result_diarize["segments"][
                    min(max_count_segments_idx, i + 1)
                ]
                next_start = next_segment["start"]
                next_speaker = next_segment["speaker"]
                duration_with_next_start = next_start - start

                if duration_with_next_start > duration_true:
                    extra_time = duration_with_next_start - duration_true

                    if speaker == next_speaker:
                        # Same speaker: take half of the gap
                        smoth_duration = duration_true + (extra_time * 0.5)
                    else:
                        # Different speaker: take 7/10 of the gap
                        smoth_duration = duration_true + (extra_time * 0.7)
            +
                                logger.debug(
         | 
| 1151 | 
            +
                                    f"Base acc: {acc_percentage}, "
         | 
| 1152 | 
            +
                                    f"smoth acc: {duration_tts / smoth_duration}"
         | 
| 1153 | 
            +
                                )
         | 
| 1154 | 
            +
                                acc_percentage = max(1.2, (duration_tts / smoth_duration))
         | 
| 1155 | 
            +
             | 
| 1156 | 
            +
                        except Exception as error:
         | 
| 1157 | 
            +
                            logger.error(str(error))
         | 
| 1158 | 
            +
             | 
| 1159 | 
            +
                    if acc_percentage > max_accelerate_audio:
         | 
| 1160 | 
            +
                        acc_percentage = max_accelerate_audio
         | 
| 1161 | 
            +
                    elif acc_percentage <= 1.15 and acc_percentage >= 0.8:
         | 
| 1162 | 
            +
                        acc_percentage = 1.0
         | 
| 1163 | 
            +
                    elif acc_percentage <= 0.79:
         | 
| 1164 | 
            +
                        acc_percentage = 0.8
         | 
| 1165 | 
            +
             | 
| 1166 | 
            +
                    # Round
         | 
| 1167 | 
            +
                    acc_percentage = round(acc_percentage + 0.0, 1)
         | 
| 1168 | 
            +
             | 
| 1169 | 
            +
                    # Format read if need
         | 
| 1170 | 
            +
                    if speaker in speakers_edge:
         | 
| 1171 | 
            +
                        info_enc = sf.info(filename).format
         | 
| 1172 | 
            +
                    else:
         | 
| 1173 | 
            +
                        info_enc = "OGG"
         | 
| 1174 | 
            +
             | 
| 1175 | 
            +
                    # Apply aceleration or opposite to the audio file in folder_output folder
         | 
| 1176 | 
            +
                    if acc_percentage == 1.0 and info_enc == "OGG":
         | 
| 1177 | 
            +
                        copy_files(filename, f"{folder_output}{os.sep}audio")
         | 
| 1178 | 
            +
                    else:
         | 
| 1179 | 
            +
                        os.system(
         | 
| 1180 | 
            +
                            f"ffmpeg -y -loglevel panic -i {filename} -filter:a atempo={acc_percentage} {folder_output}/{filename}"
         | 
| 1181 | 
            +
                        )
         | 
| 1182 | 
            +
             | 
| 1183 | 
            +
                    if logger.isEnabledFor(logging.DEBUG):
         | 
| 1184 | 
            +
                        duration_create = librosa.get_duration(
         | 
| 1185 | 
            +
                            filename=f"{folder_output}/{filename}"
         | 
| 1186 | 
            +
                        )
         | 
| 1187 | 
            +
                        logger.debug(
         | 
| 1188 | 
            +
                            f"acc_percen is {acc_percentage}, tts duration "
         | 
| 1189 | 
            +
                            f"is {duration_tts}, new duration is {duration_create}"
         | 
| 1190 | 
            +
                            f", for {filename}"
         | 
| 1191 | 
            +
                        )
         | 
| 1192 | 
            +
             | 
| 1193 | 
            +
                    audio_files.append(f"{folder_output}/{filename}")
         | 
| 1194 | 
            +
                    speaker = "TTS Speaker {:02d}".format(int(speaker[-2:]) + 1)
         | 
| 1195 | 
            +
                    speakers_list.append(speaker)
         | 
| 1196 | 
            +
             | 
| 1197 | 
            +
                return audio_files, speakers_list
         | 
| 1198 | 
            +
             | 
| 1199 | 
            +
             | 
| 1200 | 
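# A minimal sketch (illustrative, not part of the original module) of the
# clamping rule used in accelerate_segments: ratios near 1.0 snap to 1.0 so
# the clip can be copied unchanged, slow-downs are floored at 0.8, and
# speed-ups are capped at max_accelerate_audio, then rounded to one decimal
# for ffmpeg's atempo filter. The default cap here is only an example value.
def _sketch_clamp_atempo(duration_tts, duration_true, max_accelerate_audio=1.9):
    rate = duration_tts / duration_true
    if rate > max_accelerate_audio:
        rate = max_accelerate_audio
    elif 0.8 <= rate <= 1.15:
        rate = 1.0
    elif rate <= 0.79:
        rate = 0.8
    # e.g. 3.9 s of TTS in a 3.0 s slot -> atempo=1.3
    return round(rate, 1)

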
# =====================================
# Tone color converter
# =====================================


def se_process_audio_segments(
    source_seg, tone_color_converter, device, remove_previous_processed=True
):
    # list the wav segments
    source_audio_segs = glob.glob(f"{source_seg}/*.wav")
    if not source_audio_segs:
        raise ValueError(
            f"No audio segments found in {str(source_seg)}"
        )

    source_se_path = os.path.join(source_seg, "se.pth")

    # if it already exists, do not create it again
    if os.path.isfile(source_se_path):
        se = torch.load(source_se_path).to(device)
        logger.debug(f"Previously created {source_se_path}")
    else:
        se = tone_color_converter.extract_se(source_audio_segs, source_se_path)

    return se


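# A minimal sketch (illustrative only) of the cache-or-compute pattern used
# by se_process_audio_segments: load a previously saved embedding when the
# file exists, otherwise compute it with a caller-supplied function
# (`compute_fn` is a hypothetical stand-in for the extractor) and persist it
# for the next run.
def _sketch_cached_embedding(cache_path, compute_fn):
    if os.path.isfile(cache_path):
        return torch.load(cache_path)
    embedding = compute_fn()
    torch.save(embedding, cache_path)
    return embedding

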
def create_wav_vc(
    valid_speakers,
    segments_base,
    audio_name,
    max_segments=10,
    target_dir="processed",
    get_vocals_dereverb=False,
):
    # valid_speakers = list({item['speaker'] for item in segments_base})

    # The automatic cleanup (delete_previous_automatic) runs before this function
    output_dir = os.path.join(".", target_dir)  # remove content
    # remove_directory_contents(output_dir)

    path_source_segments = []
    path_target_segments = []
    for speaker in valid_speakers:
        filtered_speaker = [
            segment
            for segment in segments_base
            if segment["speaker"] == speaker
        ]
        if len(filtered_speaker) > 4:
            filtered_speaker = filtered_speaker[1:]

        dir_name_speaker = speaker + audio_name
        dir_name_speaker_tts = "tts" + speaker + audio_name
        dir_path_speaker = os.path.join(output_dir, dir_name_speaker)
        dir_path_speaker_tts = os.path.join(output_dir, dir_name_speaker_tts)
        create_directories([dir_path_speaker, dir_path_speaker_tts])

        path_target_segments.append(dir_path_speaker)
        path_source_segments.append(dir_path_speaker_tts)

        # create wav
        max_segments_count = 0
        for seg in filtered_speaker:
            duration = float(seg["end"]) - float(seg["start"])
            if 3.0 < duration < 18.0:
                logger.info(
                    f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {duration}, {seg["text"]}'
                )
                name_new_wav = str(seg["start"])

                check_segment_audio_target_file = os.path.join(
                    dir_path_speaker, f"{name_new_wav}.wav"
                )

                if os.path.exists(check_segment_audio_target_file):
                    logger.debug(
                        "Segment vc source exists: "
                        f"{check_segment_audio_target_file}"
                    )
                else:
                    create_wav_file_vc(
                        sample_name=name_new_wav,
                        audio_wav="audio.wav",
                        start=(float(seg["start"]) + 1.0),
                        end=(float(seg["end"]) - 1.0),
                        output_final_path=dir_path_speaker,
                        get_vocals_dereverb=get_vocals_dereverb,
                    )

                    file_name_tts = f"audio2/audio/{str(seg['start'])}.ogg"
                    # copy_files(file_name_tts, os.path.join(output_dir, dir_name_speaker_tts))
                    convert_to_xtts_good_sample(
                        file_name_tts, dir_path_speaker_tts
                    )

                max_segments_count += 1
                if max_segments_count == max_segments:
                    break

        if max_segments_count == 0:
            logger.info("Taking the first segment")
            seg = filtered_speaker[0]
            logger.info(
                f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {seg["text"]}'
            )
            max_duration = float(seg["end"]) - float(seg["start"])
            max_duration = max(1.0, min(max_duration, 18.0))

            name_new_wav = str(seg["start"])
            create_wav_file_vc(
                sample_name=name_new_wav,
                audio_wav="audio.wav",
                start=(float(seg["start"])),
                end=(float(seg["start"]) + max_duration),
                output_final_path=dir_path_speaker,
                get_vocals_dereverb=get_vocals_dereverb,
            )

            file_name_tts = f"audio2/audio/{str(seg['start'])}.ogg"
            # copy_files(file_name_tts, os.path.join(output_dir, dir_name_speaker_tts))
            convert_to_xtts_good_sample(file_name_tts, dir_path_speaker_tts)

    logger.debug(f"Base: {str(path_source_segments)}")
    logger.debug(f"Target: {str(path_target_segments)}")

    return path_source_segments, path_target_segments


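# A minimal sketch (illustrative only) of the reference-selection rule in
# create_wav_vc: keep up to max_segments clips between 3 and 18 seconds for
# a speaker; if none qualify, fall back to that speaker's first segment.
def _sketch_pick_reference_segments(segments, max_segments=10):
    picked = [
        seg
        for seg in segments
        if 3.0 < float(seg["end"]) - float(seg["start"]) < 18.0
    ][:max_segments]
    return picked if picked else segments[:1]

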
def toneconverter_openvoice(
    result_diarize,
    preprocessor_max_segments,
    remove_previous_process=True,
    get_vocals_dereverb=False,
    model="openvoice",
):
    audio_path = "audio.wav"
    # se_path = "se.pth"
    target_dir = "processed"
    create_directories(target_dir)

    from openvoice import se_extractor
    from openvoice.api import ToneColorConverter

    audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{se_extractor.hash_numpy_array(audio_path)}"
    # se_path = os.path.join(target_dir, audio_name, 'se.pth')

    # create wav segments, original and target

    valid_speakers = list(
        {item["speaker"] for item in result_diarize["segments"]}
    )

    logger.info("OpenVoice preprocessor...")

    if remove_previous_process:
        remove_directory_contents(target_dir)

    path_source_segments, path_target_segments = create_wav_vc(
        valid_speakers,
        result_diarize["segments"],
        audio_name,
        max_segments=preprocessor_max_segments,
        get_vocals_dereverb=get_vocals_dereverb,
    )

    logger.info("OpenVoice loading model...")
    model_path_openvoice = "./OPENVOICE_MODELS"
    url_model_openvoice = "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter"

    if "v2" in model:
        model_path = os.path.join(model_path_openvoice, "v2")
        url_model_openvoice = url_model_openvoice.replace(
            "OpenVoice", "OpenVoiceV2"
        ).replace("checkpoints/", "")
    else:
        model_path = os.path.join(model_path_openvoice, "v1")
    create_directories(model_path)

    config_url = f"{url_model_openvoice}/config.json"
    checkpoint_url = f"{url_model_openvoice}/checkpoint.pth"

    config_path = download_manager(url=config_url, path=model_path)
    checkpoint_path = download_manager(
        url=checkpoint_url, path=model_path
    )

    device = os.environ.get("SONITR_DEVICE")
    tone_color_converter = ToneColorConverter(config_path, device=device)
    tone_color_converter.load_ckpt(checkpoint_path)

    logger.info("OpenVoice tone color converter:")
    global_progress_bar = tqdm(total=len(result_diarize["segments"]), desc="Progress")

    for source_seg, target_seg, speaker in zip(
        path_source_segments, path_target_segments, valid_speakers
    ):
        # source_se_path = os.path.join(source_seg, 'se.pth')
        source_se = se_process_audio_segments(source_seg, tone_color_converter, device)
        # target_se_path = os.path.join(target_seg, 'se.pth')
        target_se = se_process_audio_segments(target_seg, tone_color_converter, device)

        # Iterate through the segments
        encode_message = "@MyShell"
        filtered_speaker = [
            segment
            for segment in result_diarize["segments"]
            if segment["speaker"] == speaker
        ]
        for seg in filtered_speaker:
            src_path = (
                save_path
            ) = f"audio2/audio/{str(seg['start'])}.ogg"  # overwrite
            logger.debug(f"{src_path}")

            tone_color_converter.convert(
                audio_src_path=src_path,
                src_se=source_se,
                tgt_se=target_se,
                output_path=save_path,
                message=encode_message,
            )

            global_progress_bar.update(1)

    global_progress_bar.close()

    try:
        del tone_color_converter
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as error:
        logger.error(str(error))
        gc.collect()
        torch.cuda.empty_cache()


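# A minimal sketch of how toneconverter_openvoice derives the v2 checkpoint
# URL from the v1 base URL above: the repository name gains a "V2" suffix
# and the "checkpoints/" path level is dropped, so config.json and
# checkpoint.pth resolve directly under the converter folder.
_V1_URL = "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter"
_V2_URL = _V1_URL.replace("OpenVoice", "OpenVoiceV2").replace("checkpoints/", "")
# _V2_URL == "https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter"

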
def toneconverter_freevc(
    result_diarize,
    remove_previous_process=True,
    get_vocals_dereverb=False,
):
    audio_path = "audio.wav"
    target_dir = "processed"
    create_directories(target_dir)

    from openvoice import se_extractor

    audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{se_extractor.hash_numpy_array(audio_path)}"

    # create wav segments; the original audio is the target and the dubbing is the source
    valid_speakers = list(
        {item["speaker"] for item in result_diarize["segments"]}
    )

    logger.info("FreeVC preprocessor...")

    if remove_previous_process:
        remove_directory_contents(target_dir)

    path_source_segments, path_target_segments = create_wav_vc(
        valid_speakers,
        result_diarize["segments"],
        audio_name,
        max_segments=1,
        get_vocals_dereverb=get_vocals_dereverb,
    )

    logger.info("FreeVC loading model...")
    device_id = os.environ.get("SONITR_DEVICE")
    device = None if device_id == "cpu" else device_id
    try:
        from TTS.api import TTS
        tts = TTS(
            model_name="voice_conversion_models/multilingual/vctk/freevc24",
            progress_bar=False
        ).to(device)
    except Exception as error:
        logger.error(str(error))
        logger.error("Error loading the FreeVC model.")
        return

    logger.info("FreeVC process:")
    global_progress_bar = tqdm(total=len(result_diarize["segments"]), desc="Progress")

    for source_seg, target_seg, speaker in zip(
        path_source_segments, path_target_segments, valid_speakers
    ):

        filtered_speaker = [
            segment
            for segment in result_diarize["segments"]
            if segment["speaker"] == speaker
        ]

        files_and_directories = os.listdir(target_seg)
        wav_files = [file for file in files_and_directories if file.endswith(".wav")]
        original_wav_audio_segment = os.path.join(target_seg, wav_files[0])

        for seg in filtered_speaker:

            src_path = (
                save_path
            ) = f"audio2/audio/{str(seg['start'])}.ogg"  # overwrite
            logger.debug(f"{src_path} - {original_wav_audio_segment}")

            wav = tts.voice_conversion(
                source_wav=src_path,
                target_wav=original_wav_audio_segment,
            )

            sf.write(
                file=save_path,
                samplerate=tts.voice_converter.vc_config.audio.output_sample_rate,
                data=wav,
                format="ogg",
                subtype="vorbis",
            )

            global_progress_bar.update(1)

    global_progress_bar.close()

    try:
        del tts
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as error:
        logger.error(str(error))
        gc.collect()
        torch.cuda.empty_cache()


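# A minimal usage sketch (hypothetical paths) of the FreeVC round trip in
# toneconverter_freevc: convert one dubbed clip toward a reference speaker
# sample and overwrite it as OGG/Vorbis at the converter's output rate.
def _sketch_freevc_one_clip(tts, src_ogg, ref_wav):
    wav = tts.voice_conversion(source_wav=src_ogg, target_wav=ref_wav)
    sf.write(
        file=src_ogg,
        data=wav,
        samplerate=tts.voice_converter.vc_config.audio.output_sample_rate,
        format="ogg",
        subtype="vorbis",
    )

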
def toneconverter(
    result_diarize,
    preprocessor_max_segments,
    remove_previous_process=True,
    get_vocals_dereverb=False,
    method_vc="freevc"
):

    if method_vc == "freevc":
        if preprocessor_max_segments > 1:
            logger.info("FreeVC only uses one segment.")
        return toneconverter_freevc(
            result_diarize,
            remove_previous_process=remove_previous_process,
            get_vocals_dereverb=get_vocals_dereverb,
        )
    elif "openvoice" in method_vc:
        return toneconverter_openvoice(
            result_diarize,
            preprocessor_max_segments,
            remove_previous_process=remove_previous_process,
            get_vocals_dereverb=get_vocals_dereverb,
            model=method_vc,
        )


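# Illustrative call (hypothetical argument values): route voice conversion
# through the dispatcher above; "freevc" always uses a single reference
# segment, while any method name containing "openvoice" forwards
# preprocessor_max_segments and the model string to the OpenVoice path.
#
#   toneconverter(result_diarize, preprocessor_max_segments=6,
#                 method_vc="openvoice_v2")

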
if __name__ == "__main__":
    from segments import result_diarize

    audio_segmentation_to_voice(
        result_diarize,
        TRANSLATE_AUDIO_TO="en",
        max_accelerate_audio=2.1,
        is_gui=True,
        tts_voice00="en-facebook-mms VITS",
        tts_voice01="en-CA-ClaraNeural-Female",
        tts_voice02="en-GB-ThomasNeural-Male",
        tts_voice03="en-GB-SoniaNeural-Female",
        tts_voice04="en-NZ-MitchellNeural-Male",
        tts_voice05="en-GB-MaisieNeural-Female",
    )
    	
        soni_translate/translate_segments.py
    ADDED
    
    | @@ -0,0 +1,457 @@ | |
from tqdm import tqdm
from deep_translator import GoogleTranslator
from itertools import chain
import copy
from .language_configuration import fix_code_language, INVERTED_LANGUAGES
from .logging_setup import logger
import re
import json
import time

TRANSLATION_PROCESS_OPTIONS = [
    "google_translator_batch",
    "google_translator",
    "gpt-3.5-turbo-0125_batch",
    "gpt-3.5-turbo-0125",
    "gpt-4-turbo-preview_batch",
    "gpt-4-turbo-preview",
    "disable_translation",
]
DOCS_TRANSLATION_PROCESS_OPTIONS = [
    "google_translator",
    "gpt-3.5-turbo-0125",
    "gpt-4-turbo-preview",
    "disable_translation",
]


def translate_iterative(segments, target, source=None):
    """
    Translate text segments individually to the specified language.

    Parameters:
    - segments (list): A list of dictionaries with 'text' as a key for
        segment text.
    - target (str): Target language code.
    - source (str, optional): Source language code. Defaults to None.

    Returns:
    - list: Translated text segments in the target language.

    Notes:
    - Translates each segment using Google Translate.

    Example:
    segments = [{'text': 'first segment.'}, {'text': 'second segment.'}]
    translated_segments = translate_iterative(segments, 'es')
    """

    segments_ = copy.deepcopy(segments)

    if not source:
        logger.debug("No source language")
        source = "auto"

    translator = GoogleTranslator(source=source, target=target)

    for line in tqdm(range(len(segments_))):
        text = segments_[line]["text"]
        translated_line = translator.translate(text.strip())
        segments_[line]["text"] = translated_line

    return segments_


def verify_translate(
    segments,
    segments_copy,
    translated_lines,
    target,
    source
):
    """
    Verify integrity and apply the translated lines if the lengths match;
    otherwise switch to iterative translation.
    """
    if len(segments) == len(translated_lines):
        for line in range(len(segments_copy)):
            logger.debug(
                f"{segments_copy[line]['text']} >> "
                f"{translated_lines[line].strip()}"
            )
            segments_copy[line]["text"] = translated_lines[
                line].replace("\t", "").replace("\n", "").strip()
        return segments_copy
    else:
        logger.error(
            "The translation failed, switching to google_translate iterative. "
            f"{len(segments), len(translated_lines)}"
        )
        return translate_iterative(segments, target, source)


def translate_batch(segments, target, chunk_size=2000, source=None):
    """
    Translate a batch of text segments into the specified language in chunks,
        respecting the character limit.

    Parameters:
    - segments (list): List of dictionaries with 'text' as a key for segment
        text.
    - target (str): Target language code.
    - chunk_size (int, optional): Maximum character limit for each translation
        chunk (default is 2000; max 5000).
    - source (str, optional): Source language code. Defaults to None.

    Returns:
    - list: Translated text segments in the target language.

    Notes:
    - Splits input segments into chunks respecting the character limit for
        translation.
    - Translates the chunks using Google Translate.
    - If chunked translation fails, switches to iterative translation using
        `translate_iterative()`.

    Example:
    segments = [{'text': 'first segment.'}, {'text': 'second segment.'}]
    translated = translate_batch(segments, 'es', chunk_size=4000, source='en')
    """

    segments_copy = copy.deepcopy(segments)

    if not source:
        logger.debug("No source language")
        source = "auto"

    # Get the text
    text_lines = []
    for line in range(len(segments_copy)):
        text = segments_copy[line]["text"].strip()
        text_lines.append(text)

    # Chunk by the character limit
    text_merge = []
    actual_chunk = ""
    global_text_list = []
    actual_text_list = []
    for one_line in text_lines:
        one_line = " " if not one_line else one_line
        if (len(actual_chunk) + len(one_line)) <= chunk_size:
            if actual_chunk:
                actual_chunk += " ||||| "
            actual_chunk += one_line
            actual_text_list.append(one_line)
        else:
            text_merge.append(actual_chunk)
            actual_chunk = one_line
            global_text_list.append(actual_text_list)
            actual_text_list = [one_line]
    if actual_chunk:
        text_merge.append(actual_chunk)
        global_text_list.append(actual_text_list)

    # Translate the chunks
    progress_bar = tqdm(total=len(segments), desc="Translating")
    translator = GoogleTranslator(source=source, target=target)
    split_list = []
    try:
        for text, text_iterable in zip(text_merge, global_text_list):
            translated_line = translator.translate(text.strip())
            split_text = translated_line.split("|||||")
            if len(split_text) == len(text_iterable):
                progress_bar.update(len(split_text))
            else:
                logger.debug(
                    "Fixing chunk iteratively. Chunk length: "
                    f"{len(split_text)}, expected: {len(text_iterable)}"
                )
                split_text = []
                for txt_iter in text_iterable:
                    translated_txt = translator.translate(txt_iter.strip())
                    split_text.append(translated_txt)
                    progress_bar.update(1)
            split_list.append(split_text)
        progress_bar.close()
    except Exception as error:
        progress_bar.close()
        logger.error(str(error))
        logger.warning(
            "The translation in chunks failed, switching to iterative."
            " Likely related: too many requests"
        )  # use a proxy or a smaller chunk size
        return translate_iterative(segments, target, source)

    # Un-chunk
    translated_lines = list(chain.from_iterable(split_list))

    return verify_translate(
        segments, segments_copy, translated_lines, target, source
    )


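# A minimal sketch (illustrative only) of the sentinel trick used by
# translate_batch: join many lines with " ||||| ", translate once, and split
# the result back. The batch is only trusted when the piece count survives
# the round trip; otherwise the caller falls back to per-line translation.
def _sketch_pack_and_split(lines, translate_fn, sep=" ||||| "):
    packed = sep.join(lines)
    pieces = translate_fn(packed).split("|||||")
    if len(pieces) != len(lines):
        return None  # the delimiter was damaged in translation; retry per line
    return [piece.strip() for piece in pieces]

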
            +
def call_gpt_translate(
    client,
    model,
    system_prompt,
    user_prompt,
    original_text=None,
    batch_lines=None,
):

    # https://platform.openai.com/docs/guides/text-generation/json-mode
    response = client.chat.completions.create(
        model=model,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    result = response.choices[0].message.content
    logger.debug(f"Result: {str(result)}")

    try:
        translation = json.loads(result)
    except Exception as error:
        match_result = re.search(r'\{.*?\}', result)
        if match_result:
            logger.error(str(error))
            json_str = match_result.group(0)
            translation = json.loads(json_str)
        else:
            raise error

    # Get valid data: skip candidates that merely echo the source text and
    # keep the first conversation with the expected number of lines.
    if batch_lines:
        for conversation in translation.values():
            if isinstance(conversation, dict):
                conversation = list(conversation.values())[0]
            if (
                list(
                    original_text["conversation"][0].values()
                )[0].strip() ==
                list(conversation[0].values())[0].strip()
            ):
                continue
            if len(conversation) == batch_lines:
                break

        fix_conversation_length = []
        for line in conversation:
            for speaker_code, text_tr in line.items():
                fix_conversation_length.append({speaker_code: text_tr})

        logger.debug(f"Data batch: {str(fix_conversation_length)}")
        logger.debug(
            f"Lines Received: {len(fix_conversation_length)},"
            f" expected: {batch_lines}"
        )

        return fix_conversation_length

    else:
        if isinstance(translation, dict):
            translation = list(translation.values())[0]
        if isinstance(translation, list):
            translation = translation[0]
        if isinstance(translation, set):
            translation = list(translation)[0]
        if not isinstance(translation, str):
            raise ValueError(f"No valid response received: {str(translation)}")

        return translation

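call_gpt_translate requests JSON-mode output, but models occasionally wrap the object in extra prose; the re.search fallback above pulls the first {...} span out of the reply. A self-contained illustration of that rescue path (the noisy string is invented for the example):

    import json
    import re

    noisy = 'Sure! {"translated_text": "Hola"} Hope that helps.'
    try:
        data = json.loads(noisy)           # fails: the reply is not pure JSON
    except json.JSONDecodeError:
        match = re.search(r'\{.*?\}', noisy)
        data = json.loads(match.group(0)) if match else {}
    print(data["translated_text"])         # -> Hola
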
def gpt_sequential(segments, model, target, source=None):
    from openai import OpenAI

    translated_segments = copy.deepcopy(segments)

    client = OpenAI()
    progress_bar = tqdm(total=len(segments), desc="Translating")

    lang_tg = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[target]).strip()
    lang_sc = ""
    if source:
        lang_sc = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[source]).strip()

    fixed_target = fix_code_language(target)
    fixed_source = fix_code_language(source) if source else "auto"

    system_prompt = "Machine translation designed to output the translated_text JSON."

    for i, line in enumerate(translated_segments):
        text = line["text"].strip()
        start = line["start"]
        user_prompt = f"Translate the following {lang_sc} text into {lang_tg}, write the fully translated text and nothing more:\n{text}"

        time.sleep(0.5)

        try:
            translated_text = call_gpt_translate(
                client,
                model,
                system_prompt,
                user_prompt,
            )

        except Exception as error:
            logger.error(
                f"{str(error)} >> The text of segment {start} "
                "is being corrected with Google Translate"
            )
            translator = GoogleTranslator(
                source=fixed_source, target=fixed_target
            )
            translated_text = translator.translate(text.strip())

        translated_segments[i]["text"] = translated_text.strip()
        progress_bar.update(1)

    progress_bar.close()

    return translated_segments

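A hypothetical call, assuming OPENAI_API_KEY is set in the environment (the bare OpenAI() constructor reads it automatically) and segments shaped as elsewhere in this module; the language keys are placeholders, since the GUI supplies its own codes from language_configuration:

    segments = [{"text": "Good morning.", "start": 0.0}]
    result = gpt_sequential(segments, "gpt-3.5-turbo-0125", target="es", source="en")
    print(result[0]["text"])  # e.g. "Buenos días."
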
def gpt_batch(segments, model, target, token_batch_limit=900, source=None):
    from openai import OpenAI
    import tiktoken

    token_batch_limit = max(100, (token_batch_limit - 40) // 2)
    progress_bar = tqdm(total=len(segments), desc="Translating")
    segments_copy = copy.deepcopy(segments)
    encoding = tiktoken.get_encoding("cl100k_base")
    client = OpenAI()

    lang_tg = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[target]).strip()
    lang_sc = ""
    if source:
        lang_sc = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[source]).strip()

    fixed_target = fix_code_language(target)
    fixed_source = fix_code_language(source) if source else "auto"

    name_speaker = "ABCDEFGHIJKL"

    translated_lines = []
    text_data_dict = []
    num_tokens = 0
    count_sk = {char: 0 for char in "ABCDEFGHIJKL"}

    for i, line in enumerate(segments_copy):
        text = line["text"]
        speaker = line["speaker"]
        last_start = line["start"]
        # text_data_dict.append({str(int(speaker[-1])+1): text})
        index_sk = int(speaker[-2:])
        character_sk = name_speaker[index_sk]
        count_sk[character_sk] += 1
        code_sk = character_sk + str(count_sk[character_sk])
        text_data_dict.append({code_sk: text})
        num_tokens += len(encoding.encode(text)) + 7
        if num_tokens >= token_batch_limit or i == len(segments_copy) - 1:
            try:
                batch_lines = len(text_data_dict)
                batch_conversation = {"conversation": copy.deepcopy(text_data_dict)}
                # Reset vars
                num_tokens = 0
                text_data_dict = []
                count_sk = {char: 0 for char in "ABCDEFGHIJKL"}
                # Process translation
                # https://arxiv.org/pdf/2309.03409.pdf
                system_prompt = f"Machine translation designed to output the translated_conversation key JSON containing a list of {batch_lines} items."
                user_prompt = f"Translate each of the following text values in conversation{' from' if lang_sc else ''} {lang_sc} to {lang_tg}:\n{batch_conversation}"
                logger.debug(f"Prompt: {str(user_prompt)}")

                conversation = call_gpt_translate(
                    client,
                    model,
                    system_prompt,
                    user_prompt,
                    original_text=batch_conversation,
                    batch_lines=batch_lines,
                )

                if len(conversation) < batch_lines:
                    raise ValueError(
                        "Incomplete result received. Batch lines: "
                        f"{len(conversation)}, expected: {batch_lines}"
                    )

                # Use a separate loop variable (idx_tr) so the outer index
                # `i` is not shadowed; the fallback below depends on it.
                for idx_tr, translated_text in enumerate(conversation):
                    if idx_tr + 1 > batch_lines:
                        break
                    translated_lines.append(list(translated_text.values())[0])

                progress_bar.update(batch_lines)

            except Exception as error:
                logger.error(str(error))

                first_start = segments_copy[max(0, i - (batch_lines - 1))]["start"]
                logger.warning(
                    f"The batch from {first_start} to {last_start} "
                    "failed; correcting it with Google Translate"
                )

                translator = GoogleTranslator(
                    source=fixed_source,
                    target=fixed_target
                )

                for txt_source in batch_conversation["conversation"]:
                    translated_txt = translator.translate(
                        list(txt_source.values())[0].strip()
                    )
                    translated_lines.append(translated_txt.strip())
                    progress_bar.update(1)

    progress_bar.close()

    return verify_translate(
        segments, segments_copy, translated_lines, fixed_target, fixed_source
    )

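The speaker coding in gpt_batch is compact enough to trace by hand: the last two digits of a diarization label (e.g. a pyannote-style "SPEAKER_00") pick a letter, and a per-speaker counter disambiguates repeats, so the model sees keys like A1, B1, A2. A worked example:

    name_speaker = "ABCDEFGHIJKL"
    count_sk = {char: 0 for char in name_speaker}

    for speaker in ["SPEAKER_00", "SPEAKER_01", "SPEAKER_00"]:
        character_sk = name_speaker[int(speaker[-2:])]
        count_sk[character_sk] += 1
        print(character_sk + str(count_sk[character_sk]))
    # A1
    # B1
    # A2
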
def translate_text(
    segments,
    target,
    translation_process="google_translator_batch",
    chunk_size=4500,
    source=None,
    token_batch_limit=1000,
):
    """Translates text segments using a specified process."""
    match translation_process:
        case "google_translator_batch":
            return translate_batch(
                segments,
                fix_code_language(target),
                chunk_size,
                fix_code_language(source)
            )
        case "google_translator":
            return translate_iterative(
                segments,
                fix_code_language(target),
                fix_code_language(source)
            )
        case model if model in ["gpt-3.5-turbo-0125", "gpt-4-turbo-preview"]:
            return gpt_sequential(segments, model, target, source)
        case model if model in ["gpt-3.5-turbo-0125_batch", "gpt-4-turbo-preview_batch"]:
            return gpt_batch(
                segments,
                translation_process.replace("_batch", ""),
                target,
                token_batch_limit,
                source
            )
        case "disable_translation":
            return segments
        case _:
            raise ValueError("No valid translation process")
    	
        soni_translate/utils.py
    ADDED

    @@ -0,0 +1,487 @@
import os, zipfile, rarfile, shutil, subprocess, shlex, sys  # noqa
from .logging_setup import logger
from urllib.parse import urlparse
from IPython.utils import capture
import re

VIDEO_EXTENSIONS = [
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
    ".wmv",
    ".flv",
    ".webm",
    ".m4v",
    ".mpeg",
    ".mpg",
    ".3gp"
]

AUDIO_EXTENSIONS = [
    ".mp3",
    ".wav",
    ".aiff",
    ".aif",
    ".flac",
    ".aac",
    ".ogg",
    ".wma",
    ".m4a",
    ".alac",
    ".pcm",
    ".opus",
    ".ape",
    ".amr",
    ".ac3",
    ".vox",
    ".caf"
]

SUBTITLE_EXTENSIONS = [
    ".srt",
    ".vtt",
    ".ass"
]


def run_command(command):
    logger.debug(command)
    if isinstance(command, str):
        command = shlex.split(command)

    sub_params = {
        "stdout": subprocess.PIPE,
        "stderr": subprocess.PIPE,
        "creationflags": subprocess.CREATE_NO_WINDOW
        if sys.platform == "win32"
        else 0,
    }
    process_command = subprocess.Popen(command, **sub_params)
    output, errors = process_command.communicate()
    if (
        process_command.returncode != 0
    ):  # or not os.path.exists(mono_path) or os.path.getsize(mono_path) == 0:
        logger.error("Error running command")
        raise Exception(errors.decode())

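run_command accepts either a string (split with shlex) or an argv list, and raises with the captured stderr on a non-zero exit. A hypothetical invocation, assuming ffmpeg is on PATH and input.mp4 exists:

    run_command("ffmpeg -y -i input.mp4 -vn -ac 1 -ar 44100 output.wav")
    # Equivalent list form, which skips the shlex step:
    run_command(["ffmpeg", "-y", "-i", "input.mp4", "output.wav"])
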
def print_tree_directory(root_dir, indent=""):
    if not os.path.exists(root_dir):
        logger.error(f"{indent} Invalid directory or file: {root_dir}")
        return

    items = os.listdir(root_dir)

    for index, item in enumerate(sorted(items)):
        item_path = os.path.join(root_dir, item)
        is_last_item = index == len(items) - 1

        if os.path.isfile(item_path) and item_path.endswith(".zip"):
            with zipfile.ZipFile(item_path, "r") as zip_file:
                print(
                    f"{indent}{'└──' if is_last_item else '├──'} {item} (zip file)"
                )
                zip_contents = zip_file.namelist()
                for zip_item in sorted(zip_contents):
                    print(
                        f"{indent}{'    ' if is_last_item else '│   '}{zip_item}"
                    )
        else:
            print(f"{indent}{'└──' if is_last_item else '├──'} {item}")

            if os.path.isdir(item_path):
                new_indent = indent + ("    " if is_last_item else "│   ")
                print_tree_directory(item_path, new_indent)


def upload_model_list():
    weight_root = "weights"
    models = []
    for name in os.listdir(weight_root):
        if name.endswith(".pth"):
            models.append("weights/" + name)
    if models:
        logger.debug(models)

    index_root = "logs"
    index_paths = [None]
    for name in os.listdir(index_root):
        if name.endswith(".index"):
            index_paths.append("logs/" + name)
    if index_paths:
        logger.debug(index_paths)

    return models, index_paths


def manual_download(url, dst):
    if "drive.google" in url:
        logger.info("Drive url")
        if "folders" in url:
            logger.info("folder")
            os.system(f'gdown --folder "{url}" -O {dst} --fuzzy -c')
        else:
            logger.info("single")
            os.system(f'gdown "{url}" -O {dst} --fuzzy -c')
    elif "huggingface" in url:
        logger.info("HuggingFace url")
        if "/blob/" in url or "/resolve/" in url:
            if "/blob/" in url:
                url = url.replace("/blob/", "/resolve/")
            download_manager(url=url, path=dst, overwrite=True, progress=True)
        else:
            os.system(f"git clone {url} {dst+'repo/'}")
    elif "http" in url:
        logger.info("URL")
        download_manager(url=url, path=dst, overwrite=True, progress=True)
    elif os.path.exists(url):
        logger.info("Path")
        copy_files(url, dst)
    else:
        logger.error(f"No valid URL: {url}")

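One detail worth noting in manual_download: Hugging Face "blob" URLs point at an HTML page, so they are rewritten to "resolve" URLs before downloading. A hypothetical call (the URL and filename are invented for illustration):

    manual_download(
        "https://huggingface.co/user/repo/blob/main/model.pth",  # hypothetical URL
        "downloads/",
    )
    # Internally fetched as .../resolve/main/model.pth
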
def download_list(text_downloads):

    if os.environ.get("ZERO_GPU") == "TRUE":
        raise RuntimeError("This option is disabled in this demo.")

    try:
        urls = [elem.strip() for elem in text_downloads.split(",")]
    except Exception as error:
        raise ValueError(f"No valid URL. {str(error)}")

    create_directories(["downloads", "logs", "weights"])

    path_download = "downloads/"
    for url in urls:
        manual_download(url, path_download)

    # Tree
    print("####################################")
    print_tree_directory("downloads", indent="")
    print("####################################")

    # Place files
    select_zip_and_rar_files("downloads/")

    models, _ = upload_model_list()

    # HF Space: delete the cloned repo files
    remove_directory_contents("downloads/repo")

    return f"Downloaded = {models}"


def select_zip_and_rar_files(directory_path="downloads/"):
    # filter
    zip_files = []
    rar_files = []

    for file_name in os.listdir(directory_path):
        if file_name.endswith(".zip"):
            zip_files.append(file_name)
        elif file_name.endswith(".rar"):
            rar_files.append(file_name)

    # extract
    for file_name in zip_files:
        file_path = os.path.join(directory_path, file_name)
        with zipfile.ZipFile(file_path, "r") as zip_ref:
            zip_ref.extractall(directory_path)

    for file_name in rar_files:
        file_path = os.path.join(directory_path, file_name)
        with rarfile.RarFile(file_path, "r") as rar_ref:
            rar_ref.extractall(directory_path)

    # set in path
    def move_files_with_extension(src_dir, extension, destination_dir):
        for root, _, files in os.walk(src_dir):
            for file_name in files:
                if file_name.endswith(extension):
                    source_file = os.path.join(root, file_name)
                    destination = os.path.join(destination_dir, file_name)
                    shutil.move(source_file, destination)

    move_files_with_extension(directory_path, ".index", "logs/")
    move_files_with_extension(directory_path, ".pth", "weights/")

    return "Download complete"


def is_file_with_extensions(string_path, extensions):
    return any(string_path.lower().endswith(ext) for ext in extensions)


def is_video_file(string_path):
    return is_file_with_extensions(string_path, VIDEO_EXTENSIONS)


def is_audio_file(string_path):
    return is_file_with_extensions(string_path, AUDIO_EXTENSIONS)


def is_subtitle_file(string_path):
    return is_file_with_extensions(string_path, SUBTITLE_EXTENSIONS)

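Since is_file_with_extensions lowercases the path before matching, the three wrappers are case-insensitive on the extension:

    assert is_video_file("clip.MP4")
    assert is_audio_file("voice.WaV")
    assert is_subtitle_file("episode.srt")
    assert not is_audio_file("notes.txt")
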
def get_directory_files(directory):
    audio_files = []
    video_files = []
    sub_files = []

    for item in os.listdir(directory):
        item_path = os.path.join(directory, item)

        if os.path.isfile(item_path):

            if is_audio_file(item_path):
                audio_files.append(item_path)

            elif is_video_file(item_path):
                video_files.append(item_path)

            elif is_subtitle_file(item_path):
                sub_files.append(item_path)

    logger.info(
        f"Files in path ({directory}): "
        f"{str(audio_files + video_files + sub_files)}"
    )

    return audio_files, video_files, sub_files


def get_valid_files(paths):
    valid_paths = []
    for path in paths:
        if os.path.isdir(path):
            audio_files, video_files, sub_files = get_directory_files(path)
            valid_paths.extend(audio_files)
            valid_paths.extend(video_files)
            valid_paths.extend(sub_files)
        else:
            valid_paths.append(path)

    return valid_paths


def extract_video_links(link):

    params_dlp = {"quiet": False, "no_warnings": True, "noplaylist": False}

    try:
        from yt_dlp import YoutubeDL
        with capture.capture_output() as cap:
            with YoutubeDL(params_dlp) as ydl:
                info_dict = ydl.extract_info(  # noqa
                    link, download=False, process=True
                )

        urls = re.findall(r'\[youtube\] Extracting URL: (.*?)\n', cap.stdout)
        logger.info(f"List of videos in ({link}): {str(urls)}")
        del cap
    except Exception as error:
        logger.error(f"{link} >> {str(error)}")
        urls = [link]

    return urls


def get_link_list(urls):
    valid_links = []
    for url_video in urls:
        if "youtube.com" in url_video and "/watch?v=" not in url_video:
            url_links = extract_video_links(url_video)
            valid_links.extend(url_links)
        else:
            valid_links.append(url_video)
    return valid_links

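get_link_list leaves direct video URLs alone and expands anything on youtube.com without /watch?v= (playlists, channels) through yt-dlp. A hypothetical example (the IDs are placeholders):

    links = get_link_list([
        "https://www.youtube.com/watch?v=VIDEO_ID",           # passed through
        "https://www.youtube.com/playlist?list=PLAYLIST_ID",  # expanded per video
    ])
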
# =====================================
# Download Manager
# =====================================


def load_file_from_url(
    url: str,
    model_dir: str,
    file_name: str | None = None,
    overwrite: bool = False,
    progress: bool = True,
) -> str:
    """Download a file from `url` into `model_dir`,
    using the file present if possible.

    Returns the path to the downloaded file.
    """
    os.makedirs(model_dir, exist_ok=True)
    if not file_name:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)
    cached_file = os.path.abspath(os.path.join(model_dir, file_name))

    # Overwrite
    if os.path.exists(cached_file):
        if overwrite or os.path.getsize(cached_file) == 0:
            remove_files(cached_file)

    # Download
    if not os.path.exists(cached_file):
        logger.info(f'Downloading: "{url}" to {cached_file}\n')
        from torch.hub import download_url_to_file

        download_url_to_file(url, cached_file, progress=progress)
    else:
        logger.debug(cached_file)

    return cached_file


def friendly_name(file: str):
    if file.startswith("http"):
        file = urlparse(file).path

    file = os.path.basename(file)
    model_name, extension = os.path.splitext(file)
    return model_name, extension


def download_manager(
    url: str,
    path: str,
    extension: str = "",
    overwrite: bool = False,
    progress: bool = True,
):
    url = url.strip()

    name, ext = friendly_name(url)
    name += ext if not extension else f".{extension}"

    if url.startswith("http"):
        filename = load_file_from_url(
            url=url,
            model_dir=path,
            file_name=name,
            overwrite=overwrite,
            progress=progress,
        )
    else:
        filename = path

    return filename

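download_manager is effectively a cached fetch: a non-empty file already at the destination short-circuits the download unless overwrite=True. A hypothetical use (the URL is invented):

    model_path = download_manager(
        url="https://huggingface.co/user/repo/resolve/main/voice.onnx",  # hypothetical
        path="mdx_models",
    )
    # A second call returns the cached path without re-downloading.
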
            # =====================================
         | 
| 379 | 
            +
            # File management
         | 
| 380 | 
            +
            # =====================================
         | 
| 381 | 
            +
             | 
| 382 | 
            +
             | 
| 383 | 
            +
            # only remove files
         | 
| 384 | 
            +
            def remove_files(file_list):
         | 
| 385 | 
            +
                if isinstance(file_list, str):
         | 
| 386 | 
            +
                    file_list = [file_list]
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                for file in file_list:
         | 
| 389 | 
            +
                    if os.path.exists(file):
         | 
| 390 | 
            +
                        os.remove(file)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
             | 
| 393 | 
            +
            def remove_directory_contents(directory_path):
         | 
| 394 | 
            +
                """
         | 
| 395 | 
            +
                Removes all files and subdirectories within a directory.
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                Parameters:
         | 
| 398 | 
            +
                directory_path (str): Path to the directory whose
         | 
| 399 | 
            +
                contents need to be removed.
         | 
| 400 | 
            +
                """
         | 
| 401 | 
            +
                if os.path.exists(directory_path):
         | 
| 402 | 
            +
                    for filename in os.listdir(directory_path):
         | 
| 403 | 
            +
                        file_path = os.path.join(directory_path, filename)
         | 
| 404 | 
            +
                        try:
         | 
| 405 | 
            +
                            if os.path.isfile(file_path):
         | 
| 406 | 
            +
                                os.remove(file_path)
         | 
| 407 | 
            +
                            elif os.path.isdir(file_path):
         | 
| 408 | 
            +
                                shutil.rmtree(file_path)
         | 
| 409 | 
            +
                        except Exception as e:
         | 
| 410 | 
            +
                            logger.error(f"Failed to delete {file_path}. Reason: {e}")
         | 
| 411 | 
            +
                    logger.info(f"Content in '{directory_path}' removed.")
         | 
| 412 | 
            +
                else:
        logger.error(f"Directory '{directory_path}' does not exist.")


# Create directories if they do not exist
def create_directories(directory_path):
    if isinstance(directory_path, str):
        directory_path = [directory_path]
    for one_dir_path in directory_path:
        if not os.path.exists(one_dir_path):
            os.makedirs(one_dir_path)
            logger.debug(f"Directory '{one_dir_path}' created.")


def move_files(source_dir, destination_dir, extension=""):
    """
    Moves file(s) from the source directory to the destination directory.

    Parameters:
    source_dir (str): Path to the source directory.
    destination_dir (str): Path to the destination directory.
    extension (str): If set, only move files with this extension.
    """
    create_directories(destination_dir)

    for filename in os.listdir(source_dir):
        source_path = os.path.join(source_dir, filename)
        destination_path = os.path.join(destination_dir, filename)
        if extension and not filename.endswith(extension):
            continue
        os.replace(source_path, destination_path)


def copy_files(source_path, destination_path):
    """
    Copies a file or multiple files from a source path to a destination
    directory.

    Parameters:
    source_path (str or list): Path or list of paths to the source
    file(s) or directory.
    destination_path (str): Path to the destination directory.
    """
    create_directories(destination_path)

    if isinstance(source_path, str):
        source_path = [source_path]

    if os.path.isdir(source_path[0]):
        # Copy all files from the source directory to the destination
        base_path = source_path[0]
        source_path = os.listdir(source_path[0])
        source_path = [
            os.path.join(base_path, file_name) for file_name in source_path
        ]

    for one_source_path in source_path:
        if os.path.exists(one_source_path):
            shutil.copy2(one_source_path, destination_path)
            logger.debug(
                f"File '{one_source_path}' copied to '{destination_path}'."
            )
        else:
            logger.error(f"File '{one_source_path}' does not exist.")


def rename_file(current_name, new_name):
    file_directory = os.path.dirname(current_name)

    if os.path.exists(current_name):
        dir_new_name_file = os.path.join(file_directory, new_name)
        os.rename(current_name, dir_new_name_file)
        logger.debug(f"File '{current_name}' renamed to '{new_name}'.")
        return dir_new_name_file
    else:
        logger.error(f"File '{current_name}' does not exist.")
        return None
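A minimal usage sketch of the four helpers above (folder and file names are hypothetical, not part of the repository):

    from soni_translate.utils import (
        create_directories, move_files, copy_files, rename_file
    )

    create_directories(["outputs", "backup"])          # accepts str or list
    move_files("outputs", "backup", extension=".wav")  # move only .wav files
    copy_files("outputs", "backup")                    # a directory's files...
    copy_files(["a.srt", "b.srt"], "backup")           # ...or an explicit list
    new_path = rename_file("backup/a.srt", "c.srt")    # new path, or None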
    	
        vci_pipeline.py
    ADDED
    
@@ -0,0 +1,454 @@
import numpy as np, parselmouth, torch, pdb, sys
from time import time as ttime
import torch.nn.functional as F
import scipy.signal as signal
import pyworld, os, traceback, faiss, librosa, torchcrepe
from functools import lru_cache
from soni_translate.logging_setup import logger

now_dir = os.getcwd()
sys.path.append(now_dir)

# 5th-order high-pass Butterworth filter at 48 Hz (designed for 16 kHz audio)
bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)

# Waveform cache keyed by path, so the lru_cache'd harvest call below only
# receives hashable arguments.
input_audio_path2wav = {}


@lru_cache
def cache_harvest_f0(input_audio_path, fs, f0max, f0min, frame_period):
    audio = input_audio_path2wav[input_audio_path]
    f0, t = pyworld.harvest(
        audio,
        fs=fs,
        f0_ceil=f0max,
        f0_floor=f0min,
        frame_period=frame_period,
    )
    f0 = pyworld.stonemask(audio, f0, t, fs)
    return f0
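The path-keyed dict works around the fact that lru_cache cannot hash NumPy arrays: the waveform is stashed in input_audio_path2wav and only the (hashable) path is passed to the cached function. A minimal sketch of how a caller uses it, matching the harvest call in get_f0 below (the file name is hypothetical):

    wav = np.zeros(16000, dtype=np.double)     # 1 s of silence at 16 kHz
    input_audio_path2wav["example.wav"] = wav  # register under its path
    f0 = cache_harvest_f0("example.wav", 16000, 1100, 50, 10)

A repeated call with the same arguments returns the memoized f0 without rerunning pyworld's harvest.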
def change_rms(data1, sr1, data2, sr2, rate):
    # 1 is the input audio, 2 is the output audio;
    # rate is the proportion of the output's own loudness envelope to keep.
    # print(data1.max(), data2.max())
    rms1 = librosa.feature.rms(
        y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2
    )  # one RMS frame every half second
    rms2 = librosa.feature.rms(
        y=data2, frame_length=sr2 // 2 * 2, hop_length=sr2 // 2
    )
    rms1 = torch.from_numpy(rms1)
    rms1 = F.interpolate(
        rms1.unsqueeze(0), size=data2.shape[0], mode="linear"
    ).squeeze()
    rms2 = torch.from_numpy(rms2)
    rms2 = F.interpolate(
        rms2.unsqueeze(0), size=data2.shape[0], mode="linear"
    ).squeeze()
    rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)  # avoid divide by zero
    data2 *= (
        torch.pow(rms1, torch.tensor(1 - rate))
        * torch.pow(rms2, torch.tensor(rate - 1))
    ).numpy()
    return data2
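The scaling term is rms1^(1-rate) * rms2^(rate-1), a per-sample geometric interpolation between the two loudness envelopes: at rate=1 the factor is 1 and the output keeps its own envelope, while at rate=0 the factor is rms1/rms2 and the output is forced onto the input's envelope. A quick numeric check of both endpoints (values chosen purely for illustration):

    rms1, rms2 = 0.2, 0.5          # input vs. output envelope at one sample
    for rate in (1.0, 0.5, 0.0):
        factor = rms1 ** (1 - rate) * rms2 ** (rate - 1)
        print(rate, round(factor, 3))
    # 1.0 -> 1.0    (output envelope untouched)
    # 0.5 -> 0.632  (geometric mean: sqrt(rms1 / rms2))
    # 0.0 -> 0.4    (output rescaled to the input envelope: 0.2 / 0.5)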
class VC(object):
    def __init__(self, tgt_sr, config):
        self.x_pad, self.x_query, self.x_center, self.x_max, self.is_half = (
            config.x_pad,
            config.x_query,
            config.x_center,
            config.x_max,
            config.is_half,
        )
        self.sr = 16000  # hubert input sampling rate
        self.window = 160  # samples per frame
        self.t_pad = self.sr * self.x_pad  # padding before and after each segment
        self.t_pad_tgt = tgt_sr * self.x_pad
        self.t_pad2 = self.t_pad * 2
        self.t_query = self.sr * self.x_query  # search window around each cut point
        self.t_center = self.sr * self.x_center  # spacing between candidate cut points
        self.t_max = self.sr * self.x_max  # duration threshold above which audio is split
        self.device = config.device
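In concrete terms, with the half-precision defaults from Config in voice_main.py (x_pad=3, x_query=10, x_center=60, x_max=65), the timing works out as follows:

    sr, window = 16000, 160
    frame_ms = window / sr * 1000   # 10.0 ms per frame
    tf0 = sr // window              # 100 f0 frames per second (used in get_f0)
    t_pad = sr * 3                  # x_pad=3  -> 48000 samples = 3 s padding
    t_max = sr * 65                 # x_max=65 -> audio longer than 65 s is split

So segments get 3 s of reflection padding, candidate cut points are spaced every 60 s with a ±10 s search window for the quietest sample, and splitting only kicks in above 65 s of audio.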
    def get_f0(
        self,
        input_audio_path,
        x,
        p_len,
        f0_up_key,
        f0_method,
        filter_radius,
        inp_f0=None,
    ):
        global input_audio_path2wav
        time_step = self.window / self.sr * 1000
        f0_min = 50
        f0_max = 1100
        f0_mel_min = 1127 * np.log(1 + f0_min / 700)
        f0_mel_max = 1127 * np.log(1 + f0_max / 700)
        if f0_method == "pm":
            f0 = (
                parselmouth.Sound(x, self.sr)
                .to_pitch_ac(
                    time_step=time_step / 1000,
                    voicing_threshold=0.6,
                    pitch_floor=f0_min,
                    pitch_ceiling=f0_max,
                )
                .selected_array["frequency"]
            )
            pad_size = (p_len - len(f0) + 1) // 2
            if pad_size > 0 or p_len - len(f0) - pad_size > 0:
                f0 = np.pad(
                    f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
                )
        elif f0_method == "harvest":
            input_audio_path2wav[input_audio_path] = x.astype(np.double)
            f0 = cache_harvest_f0(input_audio_path, self.sr, f0_max, f0_min, 10)
            if filter_radius > 2:
                f0 = signal.medfilt(f0, 3)
        elif f0_method == "crepe":
            model = "full"
            # Pick a batch size that doesn't cause memory errors on your GPU
            batch_size = 512
            # Compute pitch using the first GPU
            audio = torch.tensor(np.copy(x))[None].float()
            f0, pd = torchcrepe.predict(
                audio,
                self.sr,
                self.window,
                f0_min,
                f0_max,
                model,
                batch_size=batch_size,
                device=self.device,
                return_periodicity=True,
            )
            pd = torchcrepe.filter.median(pd, 3)
            f0 = torchcrepe.filter.mean(f0, 3)
            f0[pd < 0.1] = 0  # zero out frames with low periodicity (unvoiced)
            f0 = f0[0].cpu().numpy()
        elif "rmvpe" in f0_method:
            if not hasattr(self, "model_rmvpe"):
                from lib.rmvpe import RMVPE

                logger.info("Loading vocal pitch estimator model")
                self.model_rmvpe = RMVPE(
                    "rmvpe.pt", is_half=self.is_half, device=self.device
                )
            thred = 0.03
            if "+" in f0_method:
                f0 = self.model_rmvpe.pitch_based_audio_inference(
                    x, thred, f0_min, f0_max
                )
            else:
                f0 = self.model_rmvpe.infer_from_audio(x, thred)

        f0 *= pow(2, f0_up_key / 12)  # pitch shift in semitones
        # with open("test.txt", "w") as f: f.write("\n".join(str(i) for i in f0.tolist()))
        tf0 = self.sr // self.window  # f0 points per second
        if inp_f0 is not None:
            delta_t = np.round(
                (inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1
            ).astype("int16")
            replace_f0 = np.interp(
                list(range(delta_t)), inp_f0[:, 0] * 100, inp_f0[:, 1]
            )
            shape = f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]
            f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[
                :shape
            ]
        # with open("test_opt.txt", "w") as f: f.write("\n".join(str(i) for i in f0.tolist()))
        f0bak = f0.copy()
        f0_mel = 1127 * np.log(1 + f0 / 700)
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
            f0_mel_max - f0_mel_min
        ) + 1
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > 255] = 255
        try:
            f0_coarse = np.rint(f0_mel).astype(np.int)
        except AttributeError:  # np.int was removed in NumPy 1.24+
            f0_coarse = np.rint(f0_mel).astype(int)
        return f0_coarse, f0bak  # quantized (1-255) and continuous f0
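The optional inp_f0 override is a two-column array of (time in seconds, f0 in Hz); the pipeline method below builds it from a plain text file with one comma-separated pair per line. A hypothetical file and the matching parse:

    # example_f0.txt (hypothetical contents):
    # 0.00,220.0
    # 0.01,221.5
    # 0.02,223.0
    with open("example_f0.txt", "r") as f:
        lines = f.read().strip("\n").split("\n")
    inp_f0 = np.array(
        [[float(i) for i in line.split(",")] for line in lines], dtype="float32"
    )
    # Column 0 * 100 gives the frame index at 100 f0 frames per second,
    # which is how get_f0 maps the override onto its internal grid.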
    def vc(
        self,
        model,
        net_g,
        sid,
        audio0,
        pitch,
        pitchf,
        times,
        index,
        big_npy,
        index_rate,
        version,
        protect,
    ):
        feats = torch.from_numpy(audio0)
        if self.is_half:
            feats = feats.half()
        else:
            feats = feats.float()
        if feats.dim() == 2:  # stereo: average the two channels
            feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        feats = feats.view(1, -1)
        padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)

        inputs = {
            "source": feats.to(self.device),
            "padding_mask": padding_mask,
            "output_layer": 9 if version == "v1" else 12,
        }
        t0 = ttime()
        with torch.no_grad():
            logits = model.extract_features(**inputs)
            feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
        if protect < 0.5 and pitch is not None and pitchf is not None:
            feats0 = feats.clone()  # keep unblended features for consonant protection
        if index is not None and big_npy is not None and index_rate != 0:
            npy = feats[0].cpu().numpy()
            if self.is_half:
                npy = npy.astype("float32")

            # Retrieve the 8 nearest training features for every frame and
            # blend them, weighted by inverse squared distance.
            score, ix = index.search(npy, k=8)
            weight = np.square(1 / score)
            weight /= weight.sum(axis=1, keepdims=True)
            npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)

            if self.is_half:
                npy = npy.astype("float16")
            feats = (
                torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
                + (1 - index_rate) * feats
            )

        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
        if protect < 0.5 and pitch is not None and pitchf is not None:
            feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
                0, 2, 1
            )
        t1 = ttime()
        p_len = audio0.shape[0] // self.window
        if feats.shape[1] < p_len:
            p_len = feats.shape[1]
            if pitch is not None and pitchf is not None:
                pitch = pitch[:, :p_len]
                pitchf = pitchf[:, :p_len]

        if protect < 0.5 and pitch is not None and pitchf is not None:
            pitchff = pitchf.clone()
            pitchff[pitchf > 0] = 1
            pitchff[pitchf < 1] = protect
            pitchff = pitchff.unsqueeze(-1)
            feats = feats * pitchff + feats0 * (1 - pitchff)
            feats = feats.to(feats0.dtype)
        p_len = torch.tensor([p_len], device=self.device).long()
        with torch.no_grad():
            if pitch is not None and pitchf is not None:
                audio1 = (
                    (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
                    .data.cpu()
                    .float()
                    .numpy()
                )
            else:
                audio1 = (
                    (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
                )
        del feats, p_len, padding_mask
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        t2 = ttime()
        times[0] += t1 - t0  # feature extraction time
        times[2] += t2 - t1  # synthesis time
        return audio1
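The retrieval step in vc is a soft k-NN: each HuBERT frame is replaced by an inverse-square-distance weighted average of its 8 nearest neighbors from the training set, then mixed with the original frame by index_rate. A self-contained sketch of the same weighting on toy data (the dimensions are illustrative; real RVC features are 256- or 768-dimensional):

    import numpy as np
    import faiss

    d = 4                                                # toy feature dimension
    big_npy = np.random.rand(100, d).astype("float32")   # "training" features
    index = faiss.IndexFlatL2(d)
    index.add(big_npy)

    feats = np.random.rand(10, d).astype("float32")      # frames to convert
    score, ix = index.search(feats, 8)    # squared L2 distances and ids
    weight = np.square(1 / score)         # inverse-square-distance weights
    weight /= weight.sum(axis=1, keepdims=True)
    blended = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)

    index_rate = 0.75                     # how much retrieved content to use
    feats = index_rate * blended + (1 - index_rate) * feats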
    def pipeline(
        self,
        model,
        net_g,
        sid,
        audio,
        input_audio_path,
        times,
        f0_up_key,
        f0_method,
        file_index,
        # file_big_npy,
        index_rate,
        if_f0,
        filter_radius,
        tgt_sr,
        resample_sr,
        rms_mix_rate,
        version,
        protect,
        f0_file=None,
    ):
        if (
            file_index != ""
            # and file_big_npy != ""
            # and os.path.exists(file_big_npy)
            and os.path.exists(file_index)
            and index_rate != 0
        ):
            try:
                index = faiss.read_index(file_index)
                # big_npy = np.load(file_big_npy)
                big_npy = index.reconstruct_n(0, index.ntotal)
            except Exception:
                traceback.print_exc()
                index = big_npy = None
        else:
            index = big_npy = None
            logger.warning("File index not found; proceeding without retrieval")

        audio = signal.filtfilt(bh, ah, audio)  # high-pass filter the input
        audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
        opt_ts = []
        if audio_pad.shape[0] > self.t_max:
            # Long audio: near every t_center mark, cut at the quietest sample
            # within the +/- t_query search window.
            audio_sum = np.zeros_like(audio)
            for i in range(self.window):
                audio_sum += audio_pad[i : i - self.window]
            for t in range(self.t_center, audio.shape[0], self.t_center):
                opt_ts.append(
                    t
                    - self.t_query
                    + np.where(
                        np.abs(audio_sum[t - self.t_query : t + self.t_query])
                        == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min()
                    )[0][0]
                )
        s = 0
        audio_opt = []
        t = None
        t1 = ttime()
        audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
        p_len = audio_pad.shape[0] // self.window
        inp_f0 = None
        if hasattr(f0_file, "name"):
            try:
                with open(f0_file.name, "r") as f:
                    lines = f.read().strip("\n").split("\n")
                inp_f0 = []
                for line in lines:
                    inp_f0.append([float(i) for i in line.split(",")])
                inp_f0 = np.array(inp_f0, dtype="float32")
            except Exception:
                traceback.print_exc()
        sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
        pitch, pitchf = None, None
        if if_f0 == 1:
            pitch, pitchf = self.get_f0(
                input_audio_path,
                audio_pad,
                p_len,
                f0_up_key,
                f0_method,
                filter_radius,
                inp_f0,
            )
            pitch = pitch[:p_len]
            pitchf = pitchf[:p_len]
            if self.device == "mps":
                pitchf = pitchf.astype(np.float32)
            pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
            pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
        t2 = ttime()
        times[1] += t2 - t1
        for t in opt_ts:
            t = t // self.window * self.window  # align the cut to a frame boundary
            if if_f0 == 1:
                audio_opt.append(
                    self.vc(
                        model,
                        net_g,
                        sid,
                        audio_pad[s : t + self.t_pad2 + self.window],
                        pitch[:, s // self.window : (t + self.t_pad2) // self.window],
                        pitchf[:, s // self.window : (t + self.t_pad2) // self.window],
                        times,
                        index,
                        big_npy,
                        index_rate,
                        version,
                        protect,
                    )[self.t_pad_tgt : -self.t_pad_tgt]
                )
            else:
                audio_opt.append(
                    self.vc(
                        model,
                        net_g,
                        sid,
                        audio_pad[s : t + self.t_pad2 + self.window],
                        None,
                        None,
                        times,
                        index,
                        big_npy,
                        index_rate,
                        version,
                        protect,
                    )[self.t_pad_tgt : -self.t_pad_tgt]
                )
            s = t
        # Last (or only) segment.
        if if_f0 == 1:
            audio_opt.append(
                self.vc(
                    model,
                    net_g,
                    sid,
                    audio_pad[t:],
                    pitch[:, t // self.window :] if t is not None else pitch,
                    pitchf[:, t // self.window :] if t is not None else pitchf,
                    times,
                    index,
                    big_npy,
                    index_rate,
                    version,
                    protect,
                )[self.t_pad_tgt : -self.t_pad_tgt]
            )
        else:
            audio_opt.append(
                self.vc(
                    model,
                    net_g,
                    sid,
                    audio_pad[t:],
                    None,
                    None,
                    times,
                    index,
                    big_npy,
                    index_rate,
                    version,
                    protect,
                )[self.t_pad_tgt : -self.t_pad_tgt]
            )
        audio_opt = np.concatenate(audio_opt)
        if rms_mix_rate != 1:
            audio_opt = change_rms(audio, 16000, audio_opt, tgt_sr, rms_mix_rate)
        if resample_sr >= 16000 and tgt_sr != resample_sr:
            audio_opt = librosa.resample(
                audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
            )
        # Scale to int16, backing off if the float signal would clip.
        audio_max = np.abs(audio_opt).max() / 0.99
        max_int16 = 32768
        if audio_max > 1:
            max_int16 /= audio_max
        audio_opt = (audio_opt * max_int16).astype(np.int16)
        del pitch, pitchf, sid
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio_opt
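voice_main.py (next) wires this pipeline up. A condensed, hypothetical invocation showing how the pieces fit together; the paths are placeholders, and net_g, tgt_sr, if_f0 and version are assumed to come from the checkpoint loaded by load_trained_model below:

    config = Config(only_cpu=True)          # defined in voice_main.py
    hubert = load_hu_bert(config)
    converter = VC(tgt_sr, config)
    audio = load_audio("input.wav", 16000)  # 16 kHz float32 waveform
    times = [0, 0, 0]                       # feature / f0 / synthesis timings
    out = converter.pipeline(
        hubert, net_g, 0, audio, "input.wav", times,
        0,              # f0_up_key: pitch shift in semitones
        "rmvpe",        # f0_method
        "added.index",  # file_index (hypothetical path)
        0.75,           # index_rate
        if_f0, 3,       # if_f0, filter_radius
        tgt_sr, 0,      # tgt_sr, resample_sr (0 keeps tgt_sr)
        0.25,           # rms_mix_rate
        version, 0.33,  # version, protect
    )
    sf.write("output.wav", out, tgt_sr)     # int16 samples at tgt_sr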
    	
        voice_main.py
    ADDED
    
@@ -0,0 +1,732 @@
from soni_translate.logging_setup import logger
import torch
import gc
import numpy as np
import os
import shutil
import warnings
import threading
from tqdm import tqdm
from lib.infer_pack.models import (
    SynthesizerTrnMs256NSFsid,
    SynthesizerTrnMs256NSFsid_nono,
    SynthesizerTrnMs768NSFsid,
    SynthesizerTrnMs768NSFsid_nono,
)
from lib.audio import load_audio
import soundfile as sf
import edge_tts
import asyncio
from soni_translate.utils import remove_directory_contents, create_directories
from scipy import signal
from time import time as ttime
import faiss
from vci_pipeline import VC, change_rms, bh, ah
import librosa

warnings.filterwarnings("ignore")


class Config:
    def __init__(self, only_cpu=False):
        self.device = "cuda:0"
        self.is_half = True
        self.n_cpu = 0
        self.gpu_name = None
        self.gpu_mem = None
        (
            self.x_pad,
            self.x_query,
            self.x_center,
            self.x_max,
        ) = self.device_config(only_cpu)

    def device_config(self, only_cpu) -> tuple:
        if torch.cuda.is_available() and not only_cpu:
            i_device = int(self.device.split(":")[-1])
            self.gpu_name = torch.cuda.get_device_name(i_device)
            if (
                ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
                or "P40" in self.gpu_name.upper()
                or "1060" in self.gpu_name
                or "1070" in self.gpu_name
                or "1080" in self.gpu_name
            ):
                logger.info(
                    "16/10 series GPUs and the P40 perform better "
                    "in single precision; disabling half precision."
                )
                self.is_half = False
            else:
                self.gpu_name = None
            self.gpu_mem = int(
                torch.cuda.get_device_properties(i_device).total_memory
                / 1024
                / 1024
                / 1024
                + 0.4
            )
        elif torch.backends.mps.is_available() and not only_cpu:
            logger.info("No supported NVIDIA GPU found, using MPS for inference")
            self.device = "mps"
        else:
            logger.info("No supported NVIDIA GPU found, using CPU for inference")
            self.device = "cpu"
            self.is_half = False

        if self.n_cpu == 0:
            self.n_cpu = os.cpu_count()

        if self.is_half:
            # 6 GB VRAM configuration
            x_pad = 3
            x_query = 10
            x_center = 60
            x_max = 65
        else:
            # 5 GB VRAM configuration
            x_pad = 1
            x_query = 6
            x_center = 38
            x_max = 41

        if self.gpu_mem is not None and self.gpu_mem <= 4:
            # Tighter settings for GPUs with 4 GB of VRAM or less
            x_pad = 1
            x_query = 5
            x_center = 30
            x_max = 32

        logger.info(
            f"Config: Device is {self.device}, "
            f"half precision is {self.is_half}"
        )

        return x_pad, x_query, x_center, x_max
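A quick sketch of the device-selection behavior; the values shown assume a CPU-only run:

    config = Config(only_cpu=True)
    # logs: "No supported NVIDIA GPU found, using CPU for inference"
    # logs: "Config: Device is cpu, half precision is False"
    print(config.x_pad, config.x_query, config.x_center, config.x_max)
    # 1 6 38 41  (the full-precision settings above)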
| 107 | 
            +
            BASE_DOWNLOAD_LINK = "https://huggingface.co/r3gm/sonitranslate_voice_models/resolve/main/"
         | 
| 108 | 
            +
            BASE_MODELS = [
         | 
| 109 | 
            +
                "hubert_base.pt",
         | 
| 110 | 
            +
                "rmvpe.pt"
         | 
| 111 | 
            +
            ]
         | 
| 112 | 
            +
            BASE_DIR = "."
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def load_hu_bert(config):
         | 
| 116 | 
            +
                from fairseq import checkpoint_utils
         | 
| 117 | 
            +
                from soni_translate.utils import download_manager
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                for id_model in BASE_MODELS:
         | 
| 120 | 
            +
                    download_manager(
         | 
| 121 | 
            +
                        os.path.join(BASE_DOWNLOAD_LINK, id_model), BASE_DIR
         | 
| 122 | 
            +
                    )
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
         | 
| 125 | 
            +
                    ["hubert_base.pt"],
         | 
| 126 | 
            +
                    suffix="",
         | 
| 127 | 
            +
                )
         | 
| 128 | 
            +
                hubert_model = models[0]
         | 
| 129 | 
            +
                hubert_model = hubert_model.to(config.device)
         | 
| 130 | 
            +
                if config.is_half:
         | 
| 131 | 
            +
                    hubert_model = hubert_model.half()
         | 
| 132 | 
            +
                else:
         | 
| 133 | 
            +
                    hubert_model = hubert_model.float()
         | 
| 134 | 
            +
                hubert_model.eval()
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                return hubert_model
         | 
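
# Minimal usage sketch (illustrative only): load_hu_bert needs an object
# exposing `device` and `is_half`, so a Config instance works. This assumes
# Config's constructor takes the only_cpu flag, as ClassVoices.apply_conf
# does further down.
#
#     config = Config(True)  # only_cpu=True
#     hubert_model = load_hu_bert(config)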


def load_trained_model(model_path, config):

    if not model_path:
        raise ValueError("No model found")

    logger.info("Loading %s", model_path)
    cpt = torch.load(model_path, map_location="cpu")
    tgt_sr = cpt["config"][-1]
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
    if_f0 = cpt.get("f0", 1)

    version = cpt.get("version", "v1")
    if version == "v1":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs256NSFsid(
                *cpt["config"], is_half=config.is_half
            )
        else:
            net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
    elif version == "v2":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs768NSFsid(
                *cpt["config"], is_half=config.is_half
            )
        else:
            net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
    else:
        raise ValueError(f"Unsupported model version: {version}")

    # enc_q is only used during training; drop it for inference
    del net_g.enc_q

    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g.eval().to(config.device)

    if config.is_half:
        net_g = net_g.half()
    else:
        net_g = net_g.float()

    vc = VC(tgt_sr, config)
    n_spk = cpt["config"][-3]

    return n_spk, tgt_sr, net_g, vc, cpt, version
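
# Illustrative call (hypothetical checkpoint path): the returned tuple is
# what ClassVoices.__call__ unpacks before dispatching inference threads.
#
#     n_spk, tgt_sr, net_g, vc, cpt, version = load_trained_model(
#         "weights/my_voice.pth", config
#     )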


class ClassVoices:
    def __init__(self, only_cpu=False):
        self.model_config = {}
        self.config = None
        self.only_cpu = only_cpu

    def apply_conf(
        self,
        tag="base_model",
        file_model="",
        pitch_algo="pm",
        pitch_lvl=0,
        file_index="",
        index_influence=0.66,
        respiration_median_filtering=3,
        envelope_ratio=0.25,
        consonant_breath_protection=0.33,
        resample_sr=0,
        file_pitch_algo="",
    ):

        if not file_model:
            raise ValueError("Model not found")

        if file_index is None:
            file_index = ""

        if file_pitch_algo is None:
            file_pitch_algo = ""

        if not self.config:
            self.config = Config(self.only_cpu)
            self.hu_bert_model = None
            self.model_pitch_estimator = None

        self.model_config[tag] = {
            "file_model": file_model,
            "pitch_algo": pitch_algo,
            "pitch_lvl": pitch_lvl,  # no decimal
            "file_index": file_index,
            "index_influence": index_influence,
            "respiration_median_filtering": respiration_median_filtering,
            "envelope_ratio": envelope_ratio,
            "consonant_breath_protection": consonant_breath_protection,
            "resample_sr": resample_sr,
            "file_pitch_algo": file_pitch_algo,
        }
        return f"CONFIGURATION APPLIED FOR {tag}: {file_model}"
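
    # Example registration (hypothetical file names): each tag stores an
    # independent parameter set, so several voices can be configured and
    # then selected per audio file via tag_list in __call__.
    #
    #     voices = ClassVoices()
    #     voices.apply_conf(
    #         tag="speaker1",
    #         file_model="weights/speaker1.pth",
    #         pitch_algo="rmvpe",
    #         file_index="logs/speaker1.index",
    #     )
    #     voices.apply_conf(
    #         tag="speaker2",
    #         file_model="weights/speaker2.pth",
    #         pitch_algo="pm",
    #     )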

    def infer(
        self,
        task_id,
        params,
        # loaded model
        n_spk,
        tgt_sr,
        net_g,
        pipe,
        cpt,
        version,
        if_f0,
        # loaded index
        index_rate,
        index,
        big_npy,
        # loaded f0 file
        inp_f0,
        # audio file
        input_audio_path,
        overwrite,
    ):

        f0_method = params["pitch_algo"]
        f0_up_key = params["pitch_lvl"]
        filter_radius = params["respiration_median_filtering"]
        resample_sr = params["resample_sr"]
        rms_mix_rate = params["envelope_ratio"]
        protect = params["consonant_breath_protection"]

        if not os.path.exists(input_audio_path):
            raise ValueError(
                "The audio file was not found or is not "
                f"a valid file: {input_audio_path}"
            )

        f0_up_key = int(f0_up_key)

        audio = load_audio(input_audio_path, 16000)

        # Normalize audio
        audio_max = np.abs(audio).max() / 0.95
        if audio_max > 1:
            audio /= audio_max

        times = [0, 0, 0]

        # High-pass filter the signal, pad it, compute sliding-window sums,
        # and pick low-energy time indices where the audio can be split
        audio = signal.filtfilt(bh, ah, audio)
        audio_pad = np.pad(
            audio, (pipe.window // 2, pipe.window // 2), mode="reflect"
        )
        opt_ts = []
        if audio_pad.shape[0] > pipe.t_max:
            audio_sum = np.zeros_like(audio)
            for i in range(pipe.window):
                audio_sum += audio_pad[i:i - pipe.window]
            for t in range(pipe.t_center, audio.shape[0], pipe.t_center):
                opt_ts.append(
                    t
                    - pipe.t_query
                    + np.where(
                        np.abs(audio_sum[t - pipe.t_query: t + pipe.t_query])
                        == np.abs(audio_sum[t - pipe.t_query: t + pipe.t_query]).min()
                    )[0][0]
                )

        s = 0
        audio_opt = []
        t = None
        t1 = ttime()

        sid_value = 0
        sid = torch.tensor(sid_value, device=pipe.device).unsqueeze(0).long()

        # Pad the audio symmetrically; p_len is its length in windows
        audio_pad = np.pad(audio, (pipe.t_pad, pipe.t_pad), mode="reflect")
        p_len = audio_pad.shape[0] // pipe.window

        # Estimate pitch from the audio signal
        pitch, pitchf = None, None
        if if_f0 == 1:
            pitch, pitchf = pipe.get_f0(
                input_audio_path,
                audio_pad,
                p_len,
                f0_up_key,
                f0_method,
                filter_radius,
                inp_f0,
            )
            pitch = pitch[:p_len]
            pitchf = pitchf[:p_len]
            if pipe.device == "mps":
                pitchf = pitchf.astype(np.float32)
            pitch = torch.tensor(
                pitch, device=pipe.device
            ).unsqueeze(0).long()
            pitchf = torch.tensor(
                pitchf, device=pipe.device
            ).unsqueeze(0).float()

        t2 = ttime()
        times[1] += t2 - t1

        # Convert each chunk between consecutive split points
        for t in opt_ts:
            t = t // pipe.window * pipe.window
            if if_f0 == 1:
                pitch_slice = pitch[
                    :, s // pipe.window: (t + pipe.t_pad2) // pipe.window
                ]
                pitchf_slice = pitchf[
                    :, s // pipe.window: (t + pipe.t_pad2) // pipe.window
                ]
            else:
                pitch_slice = None
                pitchf_slice = None

            audio_slice = audio_pad[s:t + pipe.t_pad2 + pipe.window]
            audio_opt.append(
                pipe.vc(
                    self.hu_bert_model,
                    net_g,
                    sid,
                    audio_slice,
                    pitch_slice,
                    pitchf_slice,
                    times,
                    index,
                    big_npy,
                    index_rate,
                    version,
                    protect,
                )[pipe.t_pad_tgt:-pipe.t_pad_tgt]
            )
            s = t

        # Convert the remaining tail (pitch is None when if_f0 == 0)
        pitch_end_slice = pitch[
            :, t // pipe.window:
        ] if t is not None and pitch is not None else pitch
        pitchf_end_slice = pitchf[
            :, t // pipe.window:
        ] if t is not None and pitchf is not None else pitchf

        audio_opt.append(
            pipe.vc(
                self.hu_bert_model,
                net_g,
                sid,
                audio_pad[t:],
                pitch_end_slice,
                pitchf_end_slice,
                times,
                index,
                big_npy,
                index_rate,
                version,
                protect,
            )[pipe.t_pad_tgt:-pipe.t_pad_tgt]
        )

        audio_opt = np.concatenate(audio_opt)
        if rms_mix_rate != 1:
            audio_opt = change_rms(
                audio, 16000, audio_opt, tgt_sr, rms_mix_rate
            )
        if resample_sr >= 16000 and tgt_sr != resample_sr:
            audio_opt = librosa.resample(
                audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
            )
        audio_max = np.abs(audio_opt).max() / 0.99
        max_int16 = 32768
        if audio_max > 1:
            max_int16 /= audio_max
        audio_opt = (audio_opt * max_int16).astype(np.int16)
        del pitch, pitchf, sid
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if resample_sr >= 16000 and tgt_sr != resample_sr:
            final_sr = resample_sr
        else:
            final_sr = tgt_sr

        logger.debug(
            "Time: npy:%ss, f0:%ss, infer:%ss"
            % (times[0], times[1], times[2])
        )

        if overwrite:
            output_audio_path = input_audio_path  # Overwrite
        else:
            basename = os.path.basename(input_audio_path)
            dirname = os.path.dirname(input_audio_path)

            name, extension = os.path.splitext(basename)
            new_basename = name + "_edited" + extension
            new_path = os.path.join(dirname, new_basename)
            logger.info(str(new_path))

            output_audio_path = new_path

        # Save file
        sf.write(
            file=output_audio_path,
            samplerate=final_sr,
            data=audio_opt
        )

        self.model_config[task_id]["result"].append(output_audio_path)
        self.output_list.append(output_audio_path)
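
    # Concurrency note: infer runs as a threading.Thread target from
    # __call__. The intermediate arrays created here are local to each
    # call, and the only shared state touched is the two result lists,
    # whose append operations are atomic under CPython's GIL, presumably
    # the reason no explicit lock is used.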

    def make_test(
        self,
        tts_text,
        tts_voice,
        model_path,
        index_path,
        transpose,
        f0_method,
    ):

        folder_test = "test"
        tag = "test_edge"
        tts_file = "test/test.wav"
        tts_edited = "test/test_edited.wav"

        create_directories(folder_test)
        remove_directory_contents(folder_test)

        if os.getenv("DEMO") == "SET_LIMIT":
            if len(tts_text) > 60:
                tts_text = tts_text[:60]
                logger.warning("DEMO mode: text limited to 60 characters")

        try:
            asyncio.run(edge_tts.Communicate(
                tts_text, "-".join(tts_voice.split('-')[:-1])
            ).save(tts_file))
        except Exception as e:
            raise ValueError(
                "No audio was received. Please try a different "
                f"TTS voice than {tts_voice}. Error: {str(e)}"
            )

        shutil.copy(tts_file, tts_edited)

        self.apply_conf(
            tag=tag,
            file_model=model_path,
            pitch_algo=f0_method,
            pitch_lvl=transpose,
            file_index=index_path,
            index_influence=0.66,
            respiration_median_filtering=3,
            envelope_ratio=0.25,
            consonant_breath_protection=0.33,
        )

        self(
            audio_files=tts_edited,
            tag_list=tag,
            overwrite=True
        )

        return tts_edited, tts_file

    def run_threads(self, threads):
        # Start threads
        for thread in threads:
            thread.start()

        # Wait for all threads to finish
        for thread in threads:
            thread.join()

        gc.collect()
        torch.cuda.empty_cache()

    def unload_models(self):
        self.hu_bert_model = None
        self.model_pitch_estimator = None
        gc.collect()
        torch.cuda.empty_cache()

    def __call__(
        self,
        audio_files=None,
        tag_list=None,
        overwrite=False,
        parallel_workers=1,
    ):
        logger.info(f"Parallel workers: {parallel_workers}")

        self.output_list = []

        if not self.model_config:
            raise ValueError("No model has been configured for inference")

        # Normalize arguments (None defaults avoid shared mutable lists)
        audio_files = audio_files or []
        tag_list = tag_list or []
        if isinstance(audio_files, str):
            audio_files = [audio_files]
        if isinstance(tag_list, str):
            tag_list = [tag_list]

        if not audio_files:
            raise ValueError("No audio found to convert")
        if not tag_list:
            tag_list = [list(self.model_config.keys())[-1]] * len(audio_files)

        if len(audio_files) > len(tag_list):
            logger.info("Extend tag list to match audio files")
            extend_number = len(audio_files) - len(tag_list)
            tag_list = tag_list + [tag_list[0]] * extend_number

        if len(audio_files) < len(tag_list):
            logger.info("Cut tag list to match audio files")
            tag_list = tag_list[:len(audio_files)]

        # Sort by tag so that each model only has to be loaded once
        tag_file_pairs = list(zip(tag_list, audio_files))
        sorted_tag_file = sorted(tag_file_pairs, key=lambda x: x[0])

        # Base params
        if not self.hu_bert_model:
            self.hu_bert_model = load_hu_bert(self.config)

        cache_params = None
        threads = []
        progress_bar = tqdm(total=len(tag_list), desc="Progress")
        for id_tag, input_audio_path in sorted_tag_file:

            if id_tag not in self.model_config.keys():
                logger.info(
                    f"No configured model for {id_tag} with {input_audio_path}"
                )
                continue

            # Flush pending threads when the worker limit is reached or
            # when the next file needs a different model
            if (
                len(threads) >= parallel_workers
                or (cache_params != id_tag and cache_params is not None)
            ):

                self.run_threads(threads)
                progress_bar.update(len(threads))

                threads = []

            if cache_params != id_tag:

                self.model_config[id_tag]["result"] = []

                # Unload the previous model
                (
                    n_spk,
                    tgt_sr,
                    net_g,
                    pipe,
                    cpt,
                    version,
                    if_f0,
                    index_rate,
                    index,
                    big_npy,
                    inp_f0,
                ) = [None] * 11
                gc.collect()
                torch.cuda.empty_cache()

                # Model params
                params = self.model_config[id_tag]

                model_path = params["file_model"]
                f0_method = params["pitch_algo"]
                file_index = params["file_index"]
                index_rate = params["index_influence"]
                f0_file = params["file_pitch_algo"]

                # Load model
                (
                    n_spk,
                    tgt_sr,
                    net_g,
                    pipe,
                    cpt,
                    version
                ) = load_trained_model(model_path, self.config)
                if_f0 = cpt.get("f0", 1)  # pitch data

                # Load index
                if os.path.exists(file_index) and index_rate != 0:
                    try:
                        index = faiss.read_index(file_index)
                        big_npy = index.reconstruct_n(0, index.ntotal)
                    except Exception as error:
                        logger.error(f"Index: {str(error)}")
                        index_rate = 0
                        index = big_npy = None
                else:
                    logger.warning("Index file not found or not used")
                    index_rate = 0
                    index = big_npy = None

                # Load f0 file
                inp_f0 = None
                if os.path.exists(f0_file):
                    try:
                        with open(f0_file, "r") as f:
                            lines = f.read().strip("\n").split("\n")
                        inp_f0 = []
                        for line in lines:
                            inp_f0.append([float(i) for i in line.split(",")])
                        inp_f0 = np.array(inp_f0, dtype="float32")
                    except Exception as error:
                        logger.error(f"f0 file: {str(error)}")

                if "rmvpe" in f0_method:
                    if not self.model_pitch_estimator:
                        from lib.rmvpe import RMVPE

                        logger.info("Loading vocal pitch estimator model")
                        self.model_pitch_estimator = RMVPE(
                            "rmvpe.pt",
                            is_half=self.config.is_half,
                            device=self.config.device
                        )

                    pipe.model_rmvpe = self.model_pitch_estimator

                cache_params = id_tag

            thread = threading.Thread(
                target=self.infer,
                args=(
                    id_tag,
                    params,
                    # loaded model
                    n_spk,
                    tgt_sr,
                    net_g,
                    pipe,
                    cpt,
                    version,
                    if_f0,
                    # loaded index
                    index_rate,
                    index,
                    big_npy,
                    # loaded f0 file
                    inp_f0,
                    # audio file
                    input_audio_path,
                    overwrite,
                )
            )

            threads.append(thread)

        # Run the remaining threads
        if threads:
            self.run_threads(threads)

        progress_bar.update(len(threads))
        progress_bar.close()

        final_result = []
        valid_tags = set(tag_list)
        for tag in valid_tags:
            if (
                tag in self.model_config.keys()
                and "result" in self.model_config[tag].keys()
            ):
                final_result.extend(self.model_config[tag]["result"])

        return final_result

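
# End-to-end usage sketch (hypothetical paths; mirrors the flow exercised
# by make_test above). After registering voices with apply_conf as shown
# earlier, calling the instance converts the files and returns the output
# paths:
#
#     converted = voices(
#         audio_files=["segment_01.wav", "segment_02.wav"],
#         tag_list=["speaker1", "speaker2"],
#         overwrite=False,
#         parallel_workers=2,
#     )
#     # e.g. ["segment_01_edited.wav", "segment_02_edited.wav"]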