# Copyright (c) 2025 SparkAudio & DragonLineageAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import soundfile as sf
import logging
import gradio as gr
import platform
import numpy as np
from pathlib import Path
from datetime import datetime
import tempfile # To handle temporary audio files for Gradio
# --- Import Transformers ---
from transformers import AutoProcessor, AutoModel
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hugging Face model to serve, and the local directory used to cache the
# downloaded weights inside the Space.
model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
cache_dir = "model_cache"

# Maps the Gradio slider values (1-5) onto the string levels the model
# expects for pitch/speed. Adjust the strings if the model uses different
# ones (e.g. "slow" / "fast").
LEVELS_MAP_UI = {
    1: "very_low",
    2: "low",
    3: "moderate",
    4: "high",
    5: "very_high",
}
# --- Model Loading ---
def load_model_and_processor(model_id, cache_dir):
    """Load the Spark-TTS processor and model, link them, and pick a device.

    Args:
        model_id: Hugging Face model identifier to download/load.
        cache_dir: Local directory used to cache the downloaded files.

    Returns:
        Tuple ``(processor, model, device)`` with the model already moved
        to the selected ``torch.device`` and set to eval mode.

    Raises:
        Exception: re-raised from ``from_pretrained`` if either load fails.
    """
    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        logging.info("Processor loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading processor: {e}")
        raise

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        model.eval()  # inference only — disable dropout etc.
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise

    # The processor needs a reference to the model to run decoding.
    processor.model = model
    logging.info("Model reference set in processor.")

    # Keep the processor's sampling rate in sync with the model config.
    if hasattr(model.config, 'sample_rate') and processor.sampling_rate != model.config.sample_rate:
        logging.warning(f"Processor SR ({processor.sampling_rate}) != Model Config SR ({model.config.sample_rate}). Updating processor.")
        processor.sampling_rate = model.config.sample_rate

    # Device preference: CUDA, then Apple MPS (macOS only), then CPU.
    if torch.cuda.is_available():
        device_name = "cuda"
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Selected device: {device}")
    model.to(device)
    logging.info(f"Model moved to device: {device}")
    return processor, model, device
# --- Load Model Globally (once per Space instance) ---
try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
    MODEL_LOADED = True
except Exception as e:
    # BUGFIX: define the globals even on failure. The Gradio callbacks pass
    # `processor`/`model`/`device` unconditionally, so without these
    # placeholders a failed load surfaces as a NameError instead of the
    # friendly "Model not loaded." message guarded by MODEL_LOADED.
    processor, model, device = None, None, None
    MODEL_LOADED = False
    logging.error(f"Failed to load model/processor: {e}")
    # The UI displays a visible error banner when MODEL_LOADED is False.
# --- Core TTS Functions ---
def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
):
    """Performs voice cloning TTS using Transformers.

    Args:
        text: Text to synthesize in the cloned voice.
        prompt_speech_path: Filesystem path to the reference audio clip.
        prompt_text: Optional transcript of the reference audio.
        processor: Loaded Spark-TTS processor (with ``processor.model`` set).
        model: Loaded Spark-TTS model.
        device: torch.device the model lives on.

    Returns:
        Tuple ``(audio_path, error_message)`` — exactly one of the two is
        None. ``audio_path`` points at a temporary WAV file for Gradio.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text:
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."
    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")
    try:
        # Treat an empty or near-empty transcript as absent.
        prompt_text_clean = None if not prompt_text or len(prompt_text.strip()) < 2 else prompt_text.strip()
        # 1. Preprocess using Processor
        inputs = processor(
            text=text.lower(),
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean.lower() if prompt_text_clean else prompt_text_clean,
            return_tensors="pt"
        ).to(device)  # Move processor output to model device
        # The prompt's global tokens are needed again at decode time.
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")
        # BUGFIX: fall back to EOS when the tokenizer defines no PAD token —
        # the original comment promised this but the code never did it, and
        # generate() misbehaves with pad_token_id=None during sampling.
        pad_id = processor.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = processor.tokenizer.eos_token_id
        # 2. Generate using Model
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,  # Safeguard, might need adjustment
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=pad_id,
            )
        # 3. Decode using Processor
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1]  # Length of the prompt portion
        )
        # Save audio to a temporary file so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, output_clone["audio"], output_clone["sampling_rate"])
            output_path = tmpfile.name
        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None  # Return path and no error message
    except Exception as e:
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
def run_voice_creation_tts(
    text,
    gender,
    pitch_level, # Expecting 1-5
    speed_level, # Expecting 1-5
    processor,
    model,
    device,
):
    """Performs voice creation TTS using Transformers.

    Args:
        text: Text to synthesize.
        gender: "male" or "female".
        pitch_level: Integer 1-5, mapped to a string via LEVELS_MAP_UI.
        speed_level: Integer 1-5, mapped to a string via LEVELS_MAP_UI.
        processor: Loaded Spark-TTS processor (with ``processor.model`` set).
        model: Loaded Spark-TTS model.
        device: torch.device the model lives on.

    Returns:
        Tuple ``(audio_path, error_message)`` — exactly one of the two is
        None. ``audio_path`` points at a temporary WAV file for Gradio.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text:
        return None, "Error: Please provide text to synthesize."
    # Map numeric levels to string representations; default to "moderate"
    # for out-of-range values.
    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")
    logging.info("Starting voice creation inference...")
    logging.info(f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), Speed: {speed_str} (Level {speed_level})")
    try:
        # 1. Preprocess (no audio/text prompt in creation mode)
        inputs = processor(
            text=text.lower(),
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt"
        ).to(device)
        # BUGFIX (consistent with the clone path): fall back to EOS when the
        # tokenizer defines no PAD token; generate() misbehaves with
        # pad_token_id=None during sampling.
        pad_id = processor.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = processor.tokenizer.eos_token_id
        # 2. Generate
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,  # Safeguard
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=pad_id,
            )
        # 3. Decode (no prompt global tokens needed here)
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1]  # Length of the prompt portion
        )
        # Save audio to a temporary file so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, output_create["audio"], output_create["sampling_rate"])
            output_path = tmpfile.name
        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None  # Return path and no error message
    except Exception as e:
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
# --- Gradio UI ---
def build_ui():
    """Build the Gradio Blocks UI with two tabs: Voice Clone and Voice Creation.

    Returns:
        gr.Blocks: the assembled demo, ready to be `.launch()`ed.
    """
    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS Demo (Transformers)</h1>') # Changed title slightly
        gr.Markdown(
            "Powered by [DragonLineageAI/Vi-Spark-TTS-0.5B-v2](https://huggingface.co/DragonLineageAI/Vi-Spark-TTS-0.5B-v2). "
            "Choose a tab for Voice Cloning or Voice Creation."
        )
        # Surface a visible banner when the global model load failed at startup.
        if not MODEL_LOADED:
            gr.Markdown("## ⚠️ Error: Model failed to load. Please check the Space logs.")
        with gr.Tabs():
            # --- Voice Clone Tab ---
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload Reference Audio or Record"
                )
                gr.Markdown(
                    "Provide a short audio clip (5-20 seconds) of the voice you want to clone. "
                    "Optionally, provide the transcript of that audio for better results, especially if the language is the same as the text you want to synthesize."
                )
                # Two alternative sources for the reference audio; the callback
                # prefers the uploaded file over the microphone recording.
                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Prompt Audio File (WAV/MP3)",
                    )
                    prompt_wav_record = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Or Record Prompt Audio",
                    )
                with gr.Row():
                    text_input_clone = gr.Textbox(
                        label="Text to Synthesize",
                        lines=4,
                        placeholder="Enter text here..."
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of Prompt Speech (Optional)",
                        lines=2,
                        placeholder="Enter the transcript of the prompt audio (if available).",
                        info="Recommended for cloning in the same language." # Added info here
                    )
                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_clone = gr.Textbox(label="Status", interactive=False) # For status/error messages
                # Button is disabled when the model failed to load.
                generate_button_clone = gr.Button("Generate Cloned Voice", variant="primary", interactive=MODEL_LOADED)
                def voice_clone_callback(text, prompt_text, audio_upload, audio_record):
                    """Gradio handler: pick the reference audio and run cloning.

                    Returns a (audio_path, status_message) pair matching the
                    [audio_output_clone, status_clone] outputs.
                    """
                    # Prioritize uploaded file, fallback to recorded file
                    prompt_speech = audio_upload if audio_upload else audio_record
                    if not prompt_speech:
                        # Return None for the audio component and the error message for the status component
                        return None, "Error: Please upload or record a reference audio."
                    # Call the core TTS function (uses the module-level
                    # processor/model/device loaded at startup).
                    output_path, error_msg = run_voice_clone_tts(
                        text,
                        prompt_speech,
                        prompt_text,
                        processor,
                        model,
                        device
                    )
                    if error_msg:
                        return None, error_msg # Return error message to status_clone
                    else:
                        # Return the audio file path and a success message (or empty)
                        return output_path, "Audio generated successfully!"
                generate_button_clone.click(
                    voice_clone_callback,
                    inputs=[
                        text_input_clone,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone, status_clone], # Update both audio and status
                )
                # Examples need actual audio files in an 'examples' directory in your Space repo
                # Make sure 'examples/sample_prompt.wav' exists or change the path
                gr.Examples(
                    examples=[
                        ["Hello, this is a test of voice cloning.", "I am a sample reference voice.", "examples/sample_prompt.wav", None],
                        ["You can experiment with different voices and texts.", None, None, "examples/sample_record.wav"], # Assuming a recorded sample exists
                        ["The quality of the clone depends on the reference audio.", "This is the reference text.", "examples/another_prompt.wav", None]
                    ],
                    inputs=[text_input_clone, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                    outputs=[audio_output_clone, status_clone],
                    fn=voice_clone_callback,
                    cache_examples=False, # Disable caching if examples might change or for demos
                    label="Clone Examples"
                )
            # --- Voice Creation Tab ---
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create Your Own Voice Based on the Following Parameters"
                )
                gr.Markdown(
                    "Select gender, adjust pitch and speed to generate a new synthetic voice."
                )
                with gr.Row():
                    with gr.Column(scale=1):
                        gender = gr.Radio(
                            choices=["male", "female"], value="female", label="Gender"
                        )
                        # Sliders map 1-5 onto the model's pitch/speed levels
                        # via LEVELS_MAP_UI inside run_voice_creation_tts.
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch (1=Lowest, 5=Highest)"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed (1=Slowest, 5=Fastest)"
                        )
                    with gr.Column(scale=2):
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=5,
                            placeholder="Enter text here...",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                audio_output_creation = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_create = gr.Textbox(label="Status", interactive=False) # For status/error messages
                create_button = gr.Button("Create New Voice", variant="primary", interactive=MODEL_LOADED)
                def voice_creation_callback(text, gender, pitch_val, speed_val):
                    """Gradio handler: run parametric voice creation.

                    Returns a (audio_path, status_message) pair matching the
                    [audio_output_creation, status_create] outputs.
                    """
                    # Call the core TTS function
                    output_path, error_msg = run_voice_creation_tts(
                        text,
                        gender,
                        int(pitch_val), # Convert slider value to int
                        int(speed_val), # Convert slider value to int
                        processor,
                        model,
                        device
                    )
                    if error_msg:
                        return None, error_msg
                    else:
                        return output_path, "Audio generated successfully!"
                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                )
                gr.Examples(
                    examples=[
                        ["This is a female voice with average pitch and speed.", "female", 3, 3],
                        ["This is a male voice, speaking quickly with a slightly higher pitch.", "male", 4, 4],
                        ["A deep and slow female voice.", "female", 1, 2],
                        ["A very high-pitched and fast male voice.", "male", 5, 5]
                    ],
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                    fn=voice_creation_callback,
                    cache_examples=False,
                    label="Creation Examples"
                )
    return demo
# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Assemble the UI and start the Gradio server (blocks until shutdown).
    app = build_ui()
    app.launch()