import logging
import os
import platform
import tempfile
from datetime import datetime
from pathlib import Path

import gradio as gr
import numpy as np
import soundfile as sf
import torch

from transformers import AutoProcessor, AutoModel

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
cache_dir = "model_cache"

# Maps the 1-5 UI slider values onto the categorical pitch/speed levels the
# model's prompt format expects.
LEVELS_MAP_UI = {
    1: "very_low",
    2: "low",
    3: "moderate",
    4: "high",
    5: "very_high",
}


def load_model_and_processor(model_id, cache_dir):
    """Loads the processor and model using Transformers."""
    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir
        )
        logging.info("Processor loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading processor: {e}")
        raise

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        model.eval()
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise

    # The custom processor needs a reference to the model for audio
    # tokenization and decoding.
    processor.model = model
    logging.info("Model reference set in processor.")

    # Keep the processor's sampling rate consistent with the model config.
    if hasattr(model.config, 'sample_rate') and processor.sampling_rate != model.config.sample_rate:
        logging.warning(f"Processor SR ({processor.sampling_rate}) != Model Config SR ({model.config.sample_rate}). Updating processor.")
        processor.sampling_rate = model.config.sample_rate

    # Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    logging.info(f"Selected device: {device}")
    model.to(device)
    logging.info(f"Model moved to device: {device}")

    return processor, model, device


try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
    MODEL_LOADED = True
except Exception as e:
    # Define the globals even on failure so the UI callbacks below can
    # reference them without raising a NameError.
    processor, model, device = None, None, None
    MODEL_LOADED = False
    logging.error(f"Failed to load model/processor: {e}")


def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
):
    """Performs voice cloning TTS using Transformers."""
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text:
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."

    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")

    try:
        # Treat transcripts of fewer than two characters as absent.
        prompt_text_clean = None if not prompt_text or len(prompt_text.strip()) < 2 else prompt_text.strip()

        inputs = processor(
            text=text.lower(),
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean.lower() if prompt_text_clean else None,
            return_tensors="pt"
        ).to(device)

        # The global speaker tokens are reused at decode time; pop them so
        # they are not forwarded to generate().
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id
            )

        # Strip the prompt tokens and decode the generated audio tokens to a waveform.
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1]
        )

        # Write the waveform to a temporary WAV file for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, output_clone["audio"], output_clone["sampling_rate"])
            output_path = tmpfile.name

        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None

    except Exception as e:
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
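
# A minimal programmatic usage sketch (comment only; assumes the model loaded
# and that "reference.wav" is a hypothetical local clip of the target voice):
#
#   wav_path, err = run_voice_clone_tts(
#       "xin chào, đây là giọng nói được nhân bản.",
#       "reference.wav", None, processor, model, device,
#   )
#   if err is None:
#       print(f"Cloned audio written to {wav_path}")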


def run_voice_creation_tts(
    text,
    gender,
    pitch_level,
    speed_level,
    processor,
    model,
    device,
):
    """Performs voice creation TTS using Transformers."""
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text:
        return None, "Error: Please provide text to synthesize."

    # Map the 1-5 slider levels to the categorical strings the processor expects.
    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")

    logging.info("Starting voice creation inference...")
    logging.info(f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), Speed: {speed_str} (Level {speed_level})")

    try:
        inputs = processor(
            text=text.lower(),
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id
            )

        # No speaker prompt here, so decoding only needs the input length offset.
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1]
        )

        # Write the waveform to a temporary WAV file for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, output_create["audio"], output_create["sampling_rate"])
            output_path = tmpfile.name

        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None

    except Exception as e:
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
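
# A minimal programmatic usage sketch (comment only; assumes the model loaded):
#
#   wav_path, err = run_voice_creation_tts(
#       "xin chào", "female", 3, 4, processor, model, device,
#   )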


def build_ui():
    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS Demo (Transformers)</h1>')
        gr.Markdown(
            "Powered by [DragonLineageAI/Vi-Spark-TTS-0.5B-v2](https://huggingface.co/DragonLineageAI/Vi-Spark-TTS-0.5B-v2). "
            "Choose a tab for Voice Cloning or Voice Creation."
        )

        if not MODEL_LOADED:
            gr.Markdown("## ⚠️ Error: Model failed to load. Please check the Space logs.")

        with gr.Tabs():
            with gr.TabItem("Voice Clone"):
                gr.Markdown("### Upload Reference Audio or Record")
                gr.Markdown(
                    "Provide a short audio clip (5-20 seconds) of the voice you want to clone. "
                    "Optionally, provide a transcript of that clip for better results, "
                    "especially when it is in the same language as the text you want to synthesize."
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Prompt Audio File (WAV/MP3)",
                    )
                    prompt_wav_record = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Or Record Prompt Audio",
                    )

                with gr.Row():
                    text_input_clone = gr.Textbox(
                        label="Text to Synthesize",
                        lines=4,
                        placeholder="Enter text here..."
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of Prompt Speech (Optional)",
                        lines=2,
                        placeholder="Enter the transcript of the prompt audio (if available).",
                        info="Recommended for cloning in the same language."
                    )

                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_clone = gr.Textbox(label="Status", interactive=False)

                generate_button_clone = gr.Button("Generate Cloned Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_clone_callback(text, prompt_text, audio_upload, audio_record):
                    # Prefer the uploaded file; fall back to the microphone recording.
                    prompt_speech = audio_upload if audio_upload else audio_record
                    if not prompt_speech:
                        return None, "Error: Please upload or record a reference audio."

                    output_path, error_msg = run_voice_clone_tts(
                        text,
                        prompt_speech,
                        prompt_text,
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                generate_button_clone.click(
                    voice_clone_callback,
                    inputs=[
                        text_input_clone,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone, status_clone],
                )

                gr.Examples(
                    examples=[
                        ["Hello, this is a test of voice cloning.", "I am a sample reference voice.", "examples/sample_prompt.wav", None],
                        ["You can experiment with different voices and texts.", None, None, "examples/sample_record.wav"],
                        ["The quality of the clone depends on the reference audio.", "This is the reference text.", "examples/another_prompt.wav", None],
                    ],
                    inputs=[text_input_clone, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                    outputs=[audio_output_clone, status_clone],
                    fn=voice_clone_callback,
                    cache_examples=False,
                    label="Clone Examples"
                )

            with gr.TabItem("Voice Creation"):
                gr.Markdown("### Create Your Own Voice Based on the Following Parameters")
                gr.Markdown(
                    "Select a gender, then adjust the pitch and speed to generate a new synthetic voice."
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        gender = gr.Radio(
                            choices=["male", "female"], value="female", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch (1=Lowest, 5=Highest)"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed (1=Slowest, 5=Fastest)"
                        )
                    with gr.Column(scale=2):
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=5,
                            placeholder="Enter text here...",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )

                audio_output_creation = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_create = gr.Textbox(label="Status", interactive=False)

                create_button = gr.Button("Create New Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_creation_callback(text, gender, pitch_val, speed_val):
                    output_path, error_msg = run_voice_creation_tts(
                        text,
                        gender,
                        int(pitch_val),
                        int(speed_val),
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                )

                gr.Examples(
                    examples=[
                        ["This is a female voice with average pitch and speed.", "female", 3, 3],
                        ["This is a male voice, speaking quickly with a slightly higher pitch.", "male", 4, 4],
                        ["A deep and slow female voice.", "female", 1, 2],
                        ["A very high-pitched and fast male voice.", "male", 5, 5],
                    ],
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                    fn=voice_creation_callback,
                    cache_examples=False,
                    label="Creation Examples"
                )

    return demo


if __name__ == "__main__":
    demo = build_ui()
    demo.launch()
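    # Note: demo.launch(share=True), a standard Gradio option, would also create
    # a temporary public link; the defaults are kept here.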