Spaces:
Runtime error
Runtime error
saya
committed on
Commit
·
3fd0c0d
1
Parent(s):
f4da48c
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,19 @@
|
|
| 1 |
# coding=utf-8
|
| 2 |
import os
|
| 3 |
import re
|
|
|
|
| 4 |
import utils
|
| 5 |
import commons
|
| 6 |
import json
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
from models import SynthesizerTrn
|
| 9 |
from text import text_to_sequence
|
| 10 |
from torch import no_grad, LongTensor
|
| 11 |
import logging
|
| 12 |
logging.getLogger('numba').setLevel(logging.WARNING)
|
|
|
|
|
|
|
| 13 |
hps_ms = utils.get_hparams_from_file(r'config/config.json')
|
| 14 |
|
| 15 |
def get_text(text, hps):
|
|
@@ -22,10 +26,11 @@ def get_text(text, hps):
|
|
| 22 |
def create_tts_fn(net_g_ms, speaker_id):
|
| 23 |
def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
|
| 24 |
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
| 29 |
if language == 0:
|
| 30 |
text = f"[ZH]{text}[ZH]"
|
| 31 |
elif language == 1:
|
|
@@ -34,11 +39,11 @@ def create_tts_fn(net_g_ms, speaker_id):
|
|
| 34 |
text = f"{text}"
|
| 35 |
stn_tst, clean_text = get_text(text, hps_ms)
|
| 36 |
with no_grad():
|
| 37 |
-
x_tst = stn_tst.unsqueeze(0)
|
| 38 |
-
x_tst_lengths = LongTensor([stn_tst.size(0)])
|
| 39 |
-
sid = LongTensor([speaker_id])
|
| 40 |
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
| 41 |
-
length_scale=length_scale)[0][0, 0].data.float().numpy()
|
| 42 |
|
| 43 |
return "Success", (22050, audio)
|
| 44 |
return tts_fn
|
|
@@ -72,23 +77,29 @@ download_audio_js = """
|
|
| 72 |
"""
|
| 73 |
|
| 74 |
if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
models = []
|
| 76 |
with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
|
| 77 |
models_info = json.load(f)
|
| 78 |
for i, info in models_info.items():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
net_g_ms = SynthesizerTrn(
|
| 80 |
len(hps_ms.symbols),
|
| 81 |
hps_ms.data.filter_length // 2 + 1,
|
| 82 |
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
| 83 |
n_speakers=hps_ms.data.n_speakers,
|
| 84 |
**hps_ms.model)
|
| 85 |
-
_ = net_g_ms.eval()
|
| 86 |
-
sid = info['sid']
|
| 87 |
-
name_en = info['name_en']
|
| 88 |
-
name_zh = info['name_zh']
|
| 89 |
-
title = info['title']
|
| 90 |
-
cover = f"pretrained_models/{i}/{info['cover']}"
|
| 91 |
utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
|
|
|
|
| 92 |
models.append((sid, name_en, name_zh, title, cover, net_g_ms, create_tts_fn(net_g_ms, sid)))
|
| 93 |
with gr.Blocks() as app:
|
| 94 |
gr.Markdown(
|
|
|
|
| 1 |
# coding=utf-8
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
+
import argparse
|
| 5 |
import utils
|
| 6 |
import commons
|
| 7 |
import json
|
| 8 |
+
import torch
|
| 9 |
import gradio as gr
|
| 10 |
from models import SynthesizerTrn
|
| 11 |
from text import text_to_sequence
|
| 12 |
from torch import no_grad, LongTensor
|
| 13 |
import logging
|
| 14 |
logging.getLogger('numba').setLevel(logging.WARNING)
|
| 15 |
+
limitation = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces
|
| 16 |
+
|
| 17 |
hps_ms = utils.get_hparams_from_file(r'config/config.json')
|
| 18 |
|
| 19 |
def get_text(text, hps):
|
|
|
|
| 26 |
def create_tts_fn(net_g_ms, speaker_id):
|
| 27 |
def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
|
| 28 |
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
| 29 |
+
if limitation:
|
| 30 |
+
text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
|
| 31 |
+
max_len = 100
|
| 32 |
+
if text_len > max_len:
|
| 33 |
+
return "Error: Text is too long", None
|
| 34 |
if language == 0:
|
| 35 |
text = f"[ZH]{text}[ZH]"
|
| 36 |
elif language == 1:
|
|
|
|
| 39 |
text = f"{text}"
|
| 40 |
stn_tst, clean_text = get_text(text, hps_ms)
|
| 41 |
with no_grad():
|
| 42 |
+
x_tst = stn_tst.unsqueeze(0).to(device)
|
| 43 |
+
x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
|
| 44 |
+
sid = LongTensor([speaker_id]).to(device)
|
| 45 |
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
| 46 |
+
length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
|
| 47 |
|
| 48 |
return "Success", (22050, audio)
|
| 49 |
return tts_fn
|
|
|
|
| 77 |
"""
|
| 78 |
|
| 79 |
if __name__ == '__main__':
|
| 80 |
+
parser = argparse.ArgumentParser()
|
| 81 |
+
parser.add_argument('--device', type=str, default='cpu')
|
| 82 |
+
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
| 83 |
+
args = parser.parse_args()
|
| 84 |
+
device = torch.device(args.device)
|
| 85 |
+
|
| 86 |
models = []
|
| 87 |
with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
|
| 88 |
models_info = json.load(f)
|
| 89 |
for i, info in models_info.items():
|
| 90 |
+
sid = info['sid']
|
| 91 |
+
name_en = info['name_en']
|
| 92 |
+
name_zh = info['name_zh']
|
| 93 |
+
title = info['title']
|
| 94 |
+
cover = f"pretrained_models/{i}/{info['cover']}"
|
| 95 |
net_g_ms = SynthesizerTrn(
|
| 96 |
len(hps_ms.symbols),
|
| 97 |
hps_ms.data.filter_length // 2 + 1,
|
| 98 |
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
| 99 |
n_speakers=hps_ms.data.n_speakers,
|
| 100 |
**hps_ms.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
|
| 102 |
+
_ = net_g_ms.eval().to(device)
|
| 103 |
models.append((sid, name_en, name_zh, title, cover, net_g_ms, create_tts_fn(net_g_ms, sid)))
|
| 104 |
with gr.Blocks() as app:
|
| 105 |
gr.Markdown(
|