Wismut's picture
fixed title
1c73802
import gradio as gr
import os
import torch
# Import eSpeak TTS pipeline
from tts_cli import (
build_model as build_model_espeak,
generate_long_form_tts as generate_long_form_tts_espeak,
)
# Import OpenPhonemizer TTS pipeline
from tts_cli_op import (
build_model as build_model_open,
generate_long_form_tts as generate_long_form_tts_open,
)
from pretrained_models import Kokoro
# ---------------------------------------------------------------------
# Path to models and voicepacks
# ---------------------------------------------------------------------
# Directory holding the model checkpoints (.pth files).
MODELS_DIR = "pretrained_models/Kokoro"
# Voicepacks (.pt files) live in a "voices" sub-folder of the model directory.
VOICES_DIR = f"{MODELS_DIR}/voices"
# ---------------------------------------------------------------------
# List the models (.pth) and voices (.pt)
# ---------------------------------------------------------------------
def get_models(models_dir=None):
    """Return the sorted file names of the model checkpoints (``.pth``).

    Args:
        models_dir: Directory to scan. When omitted, ``MODELS_DIR`` is used.

    Returns:
        A sorted list of file names (not full paths) ending in ``.pth``.
        The list is empty when the directory is missing or holds no
        matching files, so the UI can start with an empty dropdown instead
        of crashing at import/launch time.
    """
    if models_dir is None:
        models_dir = MODELS_DIR
    try:
        entries = os.listdir(models_dir)
    except (FileNotFoundError, NotADirectoryError):
        return []
    # Keep regular files only: a sub-directory that happens to end in
    # ".pth" cannot be loaded as a checkpoint.
    return sorted(
        name
        for name in entries
        if name.endswith(".pth")
        and os.path.isfile(os.path.join(models_dir, name))
    )
def get_voices(voices_dir=None):
    """Return the sorted file names of the voicepacks (``.pt``).

    Args:
        voices_dir: Directory to scan. When omitted, ``VOICES_DIR`` is used.

    Returns:
        A sorted list of file names (not full paths) ending in ``.pt``.
        The list is empty when the directory is missing or holds no
        matching files, so the UI can start with an empty dropdown instead
        of crashing at import/launch time.
    """
    if voices_dir is None:
        voices_dir = VOICES_DIR
    try:
        entries = os.listdir(voices_dir)
    except (FileNotFoundError, NotADirectoryError):
        return []
    # Keep regular files only: a sub-directory that happens to end in
    # ".pt" cannot be loaded as a voicepack.
    return sorted(
        name
        for name in entries
        if name.endswith(".pt")
        and os.path.isfile(os.path.join(voices_dir, name))
    )
# ---------------------------------------------------------------------
# Phonemizer name -> (build_model function, generate_long_form_tts function)
# ---------------------------------------------------------------------
ENGINES = dict(
    espeak=(build_model_espeak, generate_long_form_tts_espeak),
    openphonemizer=(build_model_open, generate_long_form_tts_open),
)
# ---------------------------------------------------------------------
# The main inference function called by Gradio
# ---------------------------------------------------------------------
def tts_inference(text, engine, model_file, voice_file, speed=1.0):
    """Synthesize speech for ``text`` and return it for a Gradio Audio output.

    Args:
        text: Input string to speak.
        engine: Key into ``ENGINES``: "espeak" or "openphonemizer".
        model_file: Name of the selected ``.pth`` file inside ``MODELS_DIR``.
        voice_file: Name of the selected ``.pt`` file inside ``VOICES_DIR``.
        speed: Speech speed, forwarded unchanged to the generate function.

    Returns:
        ``(sample_rate, audio)``, where ``audio`` is the first value returned
        by the engine's generate function (presumably a NumPy array --
        confirm against tts_cli / tts_cli_op).

    Raises:
        gr.Error: If ``text`` is blank or no model/voice file is selected.
        KeyError: If ``engine`` is not a key of ``ENGINES``.
    """
    # Fail early with a message the UI can display, rather than an opaque
    # TypeError from os.path.join(None) or a failure deep inside the engine.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")
    if not model_file or not voice_file:
        raise gr.Error("Please select both a model (.pth) and a voice (.pt).")

    # 1) Map engine to the correct build_model + generate_long_form_tts
    build_fn, gen_fn = ENGINES[engine]

    # 2) Prepare paths
    model_path = os.path.join(MODELS_DIR, model_file)
    voice_path = os.path.join(VOICES_DIR, voice_file)

    # 3) Decide device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 4) Load model; it is a mapping of sub-modules, each switched to eval
    #    mode when it supports it.
    model = build_fn(model_path, device=device)
    for submodule in model.values():
        if hasattr(submodule, "eval"):
            submodule.eval()

    # 5) Load voicepack
    # NOTE(review): torch.load unpickles the file. If voicepacks are plain
    # tensors, pass weights_only=True so an untrusted .pt cannot run code.
    voicepack = torch.load(voice_path, map_location=device)
    if hasattr(voicepack, "eval"):
        voicepack.eval()

    # 6) Generate TTS. Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        audio, _phonemes = gen_fn(model, text, voicepack, speed=speed)

    # Kokoro-82M's model card plays its output at 24 kHz; the previous
    # 22050 placeholder would play audio slower and lower-pitched.
    # NOTE(review): confirm the generate functions do not resample.
    sample_rate = 24000
    return (sample_rate, audio)  # Gradio expects (sample_rate, np_array)
# ---------------------------------------------------------------------
# Build Gradio App
# ---------------------------------------------------------------------
def create_gradio_app():
    """Assemble the Kokoro TTS demo UI and return the Gradio Blocks object.

    The page has a text box, dropdowns for phonemizer, model and voice, a
    speed slider, and a "Generate" button that calls ``tts_inference`` and
    sends the result to an audio player. The model and voice dropdowns are
    filled from ``get_models()`` and ``get_voices()``; each defaults to its
    first entry, or to no selection when the list is empty.
    """
    available_models = get_models()
    available_voices = get_voices()
    default_model = next(iter(available_models), None)
    default_voice = next(iter(available_voices), None)

    # Centre the Markdown headings used for the title and the footer.
    page_css = (
        "\n"
        "h4 {\n"
        "text-align: center;\n"
        "display:block;\n"
        "}\n"
        "h2 {\n"
        "text-align: center;\n"
        "display:block;\n"
        "}\n"
    )
    sample_text = (
        "Hello, world! Testing both eSpeak and OpenPhonemizer. "
        "Can you believe that we live in 2025 and have access to advanced AI?"
    )

    with gr.Blocks(theme=gr.themes.Ocean(), css=page_css) as demo:
        gr.Markdown("## Kokoro TTS Demo: Choose phonemizer, model, and voice")

        # Inputs, in display order.
        text_box = gr.Textbox(label="Input Text", value=sample_text, lines=3)
        phonemizer_choice = gr.Dropdown(
            label="Phonemizer",
            choices=["espeak", "openphonemizer"],
            value="openphonemizer",
        )
        model_choice = gr.Dropdown(
            label="Model (.pth)",
            choices=available_models,
            value=default_model,
        )
        voice_choice = gr.Dropdown(
            label="Voice (.pt)",
            choices=available_voices,
            value=default_voice,
        )
        speed_control = gr.Slider(
            label="Speech Speed",
            minimum=0.5,
            maximum=2.0,
            value=1.0,
            step=0.1,
        )

        # Trigger and output.
        run_button = gr.Button("Generate")
        audio_player = gr.Audio(label="TTS Output")

        # Order must match the positional parameters of tts_inference.
        inference_inputs = [
            text_box,
            phonemizer_choice,
            model_choice,
            voice_choice,
            speed_control,
        ]
        run_button.click(
            fn=tts_inference,
            inputs=inference_inputs,
            outputs=audio_player,
        )

        gr.Markdown(
            "#### Kokoro TTS Demo based on [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)"
        )

    return demo
# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Build the UI and launch the Gradio app when run as a script.
    demo = create_gradio_app()
    demo.launch()