# TTS_MMS_VITS / app-new.py
# Last commit by VIZINTZOR: 466cec3 "Rename app.py to app-new.py"
import torch
from transformers import VitsModel, AutoTokenizer, VitsConfig
import soundfile as sf
import numpy as np
import gradio as gr
import os
from thaicleantext import clean_thai_text
def load_tts_model(pth_path, speed=1.0):
    """Build a VITS model and its tokenizer from a ``.pth`` checkpoint.

    The checkpoint is expected to be a dict holding ``'config'`` (keyword
    arguments for ``VitsConfig``) and ``'model_state'`` (a state dict for
    ``VitsModel``). Weights are mapped onto the CPU, the model is put in
    eval mode and its ``speaking_rate`` is set to *speed*. The tokenizer is
    always obtained via ``AutoTokenizer.from_pretrained("VIZINTZOR/tts-tha-vits")``
    regardless of which checkpoint is loaded.

    NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.

    Args:
        pth_path: Filesystem path of the checkpoint.
        speed: Initial speaking rate assigned to ``model.speaking_rate``.

    Returns:
        ``(model, tokenizer, None)`` on success, or
        ``(None, None, "Error loading model: <reason>")`` if any step raises.
    """
    try:
        checkpoint = torch.load(pth_path, map_location=torch.device("cpu"))
        vits = VitsModel(VitsConfig(**checkpoint["config"]))
        vits.load_state_dict(checkpoint["model_state"])
        vits.eval()
        vits.speaking_rate = speed
        tok = AutoTokenizer.from_pretrained("VIZINTZOR/tts-tha-vits")
    except Exception as exc:
        # Errors are reported to the UI as a status string rather than raised.
        return None, None, f"Error loading model: {exc!s}"
    return vits, tok, None
def generate_speech(model, tokenizer, text, speed, volume, output_file="output.wav"):
    """Synthesize *text* with *model* and write the audio to *output_file*.

    Steps:
      1. Set ``model.speaking_rate`` to *speed*.
      2. Tokenize *text* to PyTorch tensors and run the model without
         gradient tracking.
      3. Peak-normalize the waveform to [-1, 1].
      4. Multiply it by *volume*.
      5. Write it with ``soundfile`` at ``model.config.sampling_rate``.

    Args:
        model: A loaded VITS model (see ``load_tts_model``).
        tokenizer: Tokenizer matching *model*.
        text: Text to synthesize.
        speed: Speaking rate assigned to the model before inference.
        volume: Gain multiplier applied after normalization.
        output_file: Destination path for the audio file.

    Returns:
        ``(output_file, None)`` on success, or
        ``(None, "Error generating speech: <reason>")`` if any step raises.
    """
    try:
        model.speaking_rate = speed
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            waveform = model(**inputs).waveform
        waveform = waveform.squeeze().cpu().numpy()

        # Normalize to [-1, 1]. Guard against a silent (all-zero) output:
        # dividing by a zero peak would turn every sample into NaN and write a
        # corrupt file while still reporting success.
        peak = np.max(np.abs(waveform))
        if peak > 0:
            waveform = waveform / peak

        waveform = waveform * volume  # Apply volume adjustment
        sample_rate = model.config.sampling_rate
        sf.write(output_file, waveform, sample_rate)
        return output_file, None
    except Exception as e:
        # Errors are reported to the UI as a status string rather than raised.
        return None, f"Error generating speech: {str(e)}"
def get_available_models(model_dir="./models"):
    """Return the paths of the ``.pth`` checkpoint files in *model_dir*.

    Each entry is ``os.path.join(model_dir, name)``. The names are sorted
    because ``os.listdir`` makes no ordering guarantee, and callers may use
    the first element as a default. Entries whose names end in ``.pth`` but
    are not regular files (e.g. directories) are skipped.

    Returns an empty list when *model_dir* does not exist or is not a
    directory.
    """
    # isdir (not exists): a plain file at model_dir would make listdir raise.
    if not os.path.isdir(model_dir):
        return []
    return [
        os.path.join(model_dir, name)
        for name in sorted(os.listdir(model_dir))
        if name.endswith(".pth") and os.path.isfile(os.path.join(model_dir, name))
    ]
def tts_interface(text, model_path, speed, volume):
    """Gradio callback: synthesize *text* using the checkpoint at *model_path*.

    Inputs are validated before the (expensive) model load:
      - A missing model path (the dropdown is ``None`` when no checkpoints
        were found) returns a clear error.
      - Blank text returns a clear error.
    The text is then passed through ``clean_thai_text``, the model is loaded
    with ``load_tts_model`` and audio is written via ``generate_speech``.

    Args:
        text: Raw user input.
        model_path: Path to a ``.pth`` checkpoint, or ``None``/empty.
        speed: Speaking rate forwarded to the model.
        volume: Volume multiplier forwarded to ``generate_speech``.

    Returns:
        ``(audio_path, "Audio generated successfully!")`` on success, or
        ``(None, error_message)`` otherwise.
    """
    # Fail fast with a readable message instead of letting torch.load(None)
    # or an empty tokenizer input surface an opaque error.
    if not model_path:
        return None, "Error: no model selected."
    if not text or not text.strip():
        return None, "Error: please enter some text."

    cleaned = clean_thai_text(text)
    # Presumably cleaning can leave nothing to synthesize; re-check the result.
    if not cleaned or not cleaned.strip():
        return None, "Error: please enter some text."

    model, tokenizer, error = load_tts_model(model_path, speed)
    if model is None or tokenizer is None:
        return None, error

    output_file = "output.wav"
    audio_file, error = generate_speech(model, tokenizer, cleaned, speed, volume, output_file)
    if audio_file:
        return audio_file, "Audio generated successfully!"
    return None, error
# Create Gradio interface
# Scan the models directory once so the dropdown's choices and its default
# value come from the same listing.
available_models = get_available_models()

with gr.Blocks(title="Text-to-Speech Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Text-to-Speech Generator")
    gr.Markdown("Enter text, select a model, adjust speed and volume, and generate audio!")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter your text here...",
                lines=5,
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=available_models,
                # Default to the first checkpoint; None when none were found.
                value=available_models[0] if available_models else None,
            )
        with gr.Column(scale=1):
            speed_slider = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.05,
                label="Speaking Speed",
                info="1.0 is normal speed",
            )
            volume_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=1.0,
                step=0.05,
                label="Volume",
                info="Adjust output volume",
            )
            generate_btn = gr.Button("Generate Audio", variant="primary")
    with gr.Row():
        audio_output = gr.Audio(label="Generated Audio")
        status_output = gr.Textbox(label="Status", interactive=False)
    # Connect the button to the function
    generate_btn.click(
        fn=tts_interface,
        inputs=[text_input, model_dropdown, speed_slider, volume_slider],
        outputs=[audio_output, status_output],
    )

# Launch the interface only when run as a script, so importing this module
# (e.g. for tests or reload tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()