Spaces:
Runtime error
Runtime error
import torch | |
from transformers import VitsModel, AutoTokenizer, VitsConfig | |
import soundfile as sf | |
import numpy as np | |
import gradio as gr | |
import os | |
from thaicleantext import clean_thai_text | |
def load_tts_model(pth_path, speed=1.0):
    """Build a VITS model and its tokenizer from a ``.pth`` checkpoint.

    The checkpoint is read as a dict with a ``'config'`` entry (keyword
    arguments for ``VitsConfig``) and a ``'model_state'`` entry (the state
    dict loaded into the model). The tokenizer is always fetched from the
    ``"VIZINTZOR/tts-tha-vits"`` repository, whatever checkpoint is given.

    Args:
        pth_path: Path of the checkpoint file to load.
        speed: Value assigned to the model's ``speaking_rate`` attribute.

    Returns:
        ``(model, tokenizer, None)`` on success, or
        ``(None, None, message)`` when any step raises.
    """
    try:
        # Tensors are mapped onto the CPU regardless of where they were saved.
        cpu = torch.device('cpu')
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(pth_path, map_location=cpu)
        net = VitsModel(VitsConfig(**checkpoint['config']))
        net.load_state_dict(checkpoint['model_state'])
        net.eval()
        net.speaking_rate = speed
        text_tokenizer = AutoTokenizer.from_pretrained("VIZINTZOR/tts-tha-vits")
    except Exception as exc:
        # Failures are reported as a message for the caller to display.
        return None, None, f"Error loading model: {str(exc)}"
    return net, text_tokenizer, None
def generate_speech(model, tokenizer, text, speed, volume, output_file="output.wav"):
    """Synthesize *text* with *model* and write the audio to *output_file*.

    The waveform returned by the model is peak-normalized, scaled by
    *volume*, and written with ``sf.write`` at ``model.config.sampling_rate``.

    Args:
        model: Model called as ``model(**inputs)``; its ``.waveform`` output
            is used and its ``speaking_rate`` attribute is set to *speed*.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``.
        text: Text to synthesize.
        speed: Value assigned to ``model.speaking_rate``.
        volume: Multiplier applied after normalization.
        output_file: Destination path of the audio file.

    Returns:
        ``(output_file, None)`` on success, or ``(None, message)`` when any
        step raises.
    """
    try:
        model.speaking_rate = speed
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            waveform = model(**inputs).waveform
        waveform = waveform.squeeze().cpu().numpy()

        if waveform.size == 0:
            # Report this explicitly instead of numpy's opaque reduction error.
            raise ValueError("model produced an empty waveform")

        # Normalize to [-1, 1]. A silent (all-zero) waveform has a zero peak;
        # dividing by it would fill the output with NaN, so leave it as is.
        peak = np.max(np.abs(waveform))
        if peak > 0:
            waveform = waveform / peak
        waveform = waveform * volume  # Apply volume adjustment

        sample_rate = model.config.sampling_rate
        sf.write(output_file, waveform, sample_rate)
        return output_file, None
    except Exception as e:
        # Failures are reported as a message for the caller to display.
        return None, f"Error generating speech: {str(e)}"
def get_available_models(model_dir="./models"):
    """Return the paths of the ``.pth`` files located directly in *model_dir*.

    Args:
        model_dir: Directory to scan (sub-directories are not searched).

    Returns:
        A list of ``os.path.join(model_dir, name)`` paths, sorted by file
        name so the result is stable across runs. The list is empty when
        *model_dir* does not exist or is not a directory.
    """
    # isdir (rather than exists) also rejects a path that is a regular file,
    # which os.listdir would otherwise fail on.
    if not os.path.isdir(model_dir):
        return []
    return [
        os.path.join(model_dir, name)
        for name in sorted(os.listdir(model_dir))
        if name.endswith('.pth')
    ]
def tts_interface(text, model_path, speed, volume):
    """Gradio callback: synthesize *text* with the selected model.

    Args:
        text: Text entered by the user; it is passed through
            ``clean_thai_text`` before synthesis.
        model_path: Path of the checkpoint chosen in the model dropdown.
        speed: Speaking rate forwarded to the model.
        volume: Output volume multiplier.

    Returns:
        ``(audio_path, status_message)``. ``audio_path`` is ``None`` whenever
        the input is unusable or loading/generation fails, in which case the
        status message describes the problem.
    """
    # Reject unusable input up front so no model is loaded for nothing and
    # the user gets a clear message instead of a low-level error.
    if not text or not text.strip():
        return None, "Please enter some text."
    if not model_path:
        return None, "Please select a model."

    model, tokenizer, error = load_tts_model(model_path, speed)
    if model is None or tokenizer is None:
        return None, error

    # NOTE(review): every request writes to the same file name; concurrent
    # requests may overwrite each other's audio.
    output_file = "output.wav"
    text = clean_thai_text(text)
    audio_file, error = generate_speech(model, tokenizer, text, speed, volume, output_file)
    if audio_file:
        return audio_file, "Audio generated successfully!"
    return None, error
# Create Gradio interface
# Scan the models directory once and reuse the result for both the dropdown
# choices and its default value.
available_models = get_available_models()

with gr.Blocks(title="Text-to-Speech Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Text-to-Speech Generator")
    gr.Markdown("Enter text, select a model, adjust speed and volume, and generate audio!")

    with gr.Row():
        # Left column: text to synthesize and model selection.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter your text here...",
                lines=5
            )
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=available_models,
                # Preselect the first model when at least one is available.
                value=available_models[0] if available_models else None
            )
        # Right column: synthesis controls.
        with gr.Column(scale=1):
            speed_slider = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.05,
                label="Speaking Speed",
                info="1.0 is normal speed"
            )
            volume_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=1,
                step=0.05,
                label="Volume",
                info="Adjust output volume"
            )

    generate_btn = gr.Button("Generate Audio", variant="primary")

    with gr.Row():
        audio_output = gr.Audio(label="Generated Audio")
        status_output = gr.Textbox(label="Status", interactive=False)

    # Connect the button to the function
    generate_btn.click(
        fn=tts_interface,
        inputs=[text_input, model_dropdown, speed_slider, volume_slider],
        outputs=[audio_output, status_output]
    )

# Launch the interface
demo.launch()