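"""Gradio demo for AronaTTS: a Blue Archive "Arona" voice synthesizer built on
an MS-iSTFT-VITS checkpoint. It accepts Japanese and Korean text wrapped in
[JA]...[JA] / [KO]...[KO] language tags."""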

import os

import gradio as gr
import torch
from torch import no_grad, LongTensor

import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence
from text.symbols import symbols

limitation = os.getenv("SYSTEM") == "spaces"  # limit text length when running in Hugging Face Spaces


def get_text(text, hps):
    # Convert tagged text ([JA]...[JA] / [KO]...[KO]) into a tensor of symbol IDs
    # using the cleaners named in the model config.
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        # Interleave a blank token (ID 0) between symbols, as the model expects.
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm
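
# Illustration of the blank-token step above: with add_blank enabled, symbol IDs
# [a, b, c] become [0, a, 0, b, 0, c, 0] (a blank between and around every
# symbol), which matches the spacing the model was trained with.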


def create_tts_fn(net_g, hps, speaker_ids):
    def tts_fn(text, speaker, speed):
        if limitation:
            # Refuse overly long inputs when running on shared Spaces hardware.
            text_len = len(text)
            max_len = 5000
            if text_len > max_len:
                return "Error: Text is too long", None
        speaker_id = speaker_ids[speaker]
        stn_tst = get_text(text, hps)
        with no_grad():
            x_tst = stn_tst.unsqueeze(0)
            x_tst_lengths = LongTensor([stn_tst.size(0)])
            sid = LongTensor([speaker_id])
            # noise_scale / noise_scale_w control sampling variance;
            # length_scale is the inverse of the requested speed.
            audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
                                length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
        del stn_tst, x_tst, x_tst_lengths, sid
        return "Success", (hps.data.sampling_rate, audio)
    return tts_fn
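
# A minimal usage sketch of the closure above (names as wired up in __main__
# below; the sample sentence is only illustrative):
#
#   tts = create_tts_fn(net_g, hps, speaker_ids)
#   msg, (sr, wav) = tts("[KO]μ•ˆλ…•ν•˜μ„Έμš”.[KO]", 0, 1.0)
#   # msg == "Success"; wav is a float32 numpy array at sample rate sr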

css = """
#advanced-btn {
    color: white;
    border-color: black;
    background: black;
    font-size: .7rem !important;
    line-height: 19px;
    margin-top: 24px;
    margin-bottom: 12px;
    padding: 2px 8px;
    border-radius: 14px !important;
}
#advanced-options {
    display: none;
    margin-bottom: 20px;
}
"""


if __name__ == '__main__':
    models_tts = []
    name = 'AronaTTS'
    lang = '일본어 / ν•œκ΅­μ–΄ (Japanese / Korean)'
    # Language tags wrap each segment: [JA]...[JA] for Japanese, [KO]...[KO] for Korean.
    example = '[JA]ε…ˆη”Ÿγ€δ»Šζ—₯γ―ε€©ζ°—γŒζœ¬ε½“γ«γ„γ„γ§γ™γ­γ€‚[JA][KO]μ„ μƒλ‹˜, μ•ˆλ…•ν•˜μ„Έμš”. my name is arona[KO]'
    config_path = "pretrained_model/arona_ms_istft_vits.json"
    model_path = "pretrained_model/arona_ms_istft_vits.pth"
    cover_path = "pretrained_model/cover.gif"

    hps = utils.get_hparams_from_file(config_path)
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    utils.load_checkpoint(model_path, net_g, None)
    net_g.eval()  # inference only; no optimizer state is restored

    speaker_ids = [0]
    speakers = [name]
    models_tts.append((name, cover_path, speakers, lang, example,
                       hps.symbols, create_tts_fn(net_g, hps, speaker_ids)))
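    # models_tts keeps the per-model layout of multi-model VITS demos, so more
    # voices could be appended here (name, cover, speakers, lang, example,
    # symbols, tts_fn) and the UI loop below would render them unchanged.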

    app = gr.Blocks(css=css)
    with app:
        gr.Markdown("# BlueArchive Arona TTS Using VITS Model\n"
                    "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=openduckparty.AronaTTS)\n\n")
        for i, (name, cover_path, speakers, lang, example, model_symbols, tts_fn
                ) in enumerate(models_tts):
            with gr.Column():
                gr.Markdown(f"## {name}\n\n"
                            f"![cover](file/{cover_path})\n\n"
                            f"lang: {lang}")
                tts_input1 = gr.TextArea(label="Text (max 5000 characters)", value=example,
                                         elem_id=f"tts-input{i}")
                tts_input2 = gr.Dropdown(label="Speaker", choices=speakers,
                                         type="index", value=speakers[0])
                tts_input3 = gr.Slider(label="Speed", value=1, minimum=0.1, maximum=2, step=0.1)
                tts_submit = gr.Button("Generate", variant="primary")
                tts_output1 = gr.Textbox(label="Output Message")
                tts_output2 = gr.Audio(label="Output Audio")
                tts_submit.click(tts_fn, [tts_input1, tts_input2, tts_input3],
                                 [tts_output1, tts_output2])
                # Leftover helper: JavaScript that inserts a phoneme symbol at the
                # cursor of the text box. Nothing in this trimmed UI passes it to a
                # click handler, so the string is defined but never used.
                _js = f"""
                (i, phonemes) => {{
                    let root = document.querySelector("body > gradio-app");
                    if (root.shadowRoot != null)
                        root = root.shadowRoot;
                    let text_input = root.querySelector("#tts-input{i}").querySelector("textarea");
                    let startPos = text_input.selectionStart;
                    let endPos = text_input.selectionEnd;
                    let oldTxt = text_input.value;
                    let result = oldTxt.substring(0, startPos) + phonemes[i] + oldTxt.substring(endPos);
                    text_input.value = result;
                    let x = window.scrollX, y = window.scrollY;
                    text_input.focus();
                    text_input.selectionStart = startPos + phonemes[i].length;
                    text_input.selectionEnd = startPos + phonemes[i].length;
                    text_input.blur();
                    window.scrollTo(x, y);
                    return [];
                }}"""

    app.queue(concurrency_count=3).launch(show_api=False)