ancv's picture
Update app.py
dc56bd0 verified
# Copyright (c) 2025 SparkAudio & DragonLineageAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import soundfile as sf
import logging
import gradio as gr
import platform
import numpy as np
from pathlib import Path
from datetime import datetime
import tempfile # To handle temporary audio files for Gradio
# --- Import Transformers ---
from transformers import AutoProcessor, AutoModel
# --- Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# Hub repository id (or local path) of the checkpoint loaded at startup.
model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
# Directory handed to from_pretrained(cache_dir=...) for downloaded files.
cache_dir = "model_cache"
# Gradio slider level (1-5) -> label passed to the processor as pitch/speed.
# NOTE(review): the labels must match what the model's remote processor code
# accepts (alternatives could be e.g. "slow"/"fast") — confirm against the model card.
LEVELS_MAP_UI = dict(
    enumerate(("very_low", "low", "moderate", "high", "very_high"), start=1)
)
# --- Model Loading ---
def load_model_and_processor(model_id, cache_dir):
    """Load the processor and model for ``model_id`` and prepare them for inference.

    Steps:
    - Put the model in eval mode.
    - Attach it to the processor as ``processor.model``.
    - If ``model.config`` defines ``sample_rate`` and it differs from the
      processor's, overwrite ``processor.sampling_rate`` with it.
    - Move the model to the first available device: CUDA, then Apple MPS
      (macOS only), then CPU.

    Args:
        model_id: Repository id or local path passed to ``from_pretrained``.
        cache_dir: Download cache directory passed to ``from_pretrained``.

    Returns:
        Tuple ``(processor, model, device)``.

    Raises:
        Exception: Any error while loading the processor or model is logged
            and re-raised unchanged.
    """
    # Both loaders run the repository's custom code and share one cache dir.
    load_kwargs = {"trust_remote_code": True, "cache_dir": cache_dir}

    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(model_id, **load_kwargs)
        logging.info("Processor loaded successfully.")
    except Exception as err:
        logging.error(f"Error loading processor: {err}")
        raise

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(model_id, **load_kwargs)
        model.eval()  # inference only
        logging.info("Model loaded successfully.")
    except Exception as err:
        logging.error(f"Error loading model: {err}")
        raise

    # The original author marks this link as crucial; presumably the
    # processor's decode() calls into the model. TODO confirm in remote code.
    processor.model = model
    logging.info("Model reference set in processor.")

    # When both define a rate and they disagree, trust the model config.
    config = model.config
    if hasattr(config, "sample_rate") and processor.sampling_rate != config.sample_rate:
        logging.warning(
            f"Processor SR ({processor.sampling_rate}) != Model Config SR "
            f"({config.sample_rate}). Updating processor."
        )
        processor.sampling_rate = config.sample_rate

    if torch.cuda.is_available():
        backend = "cuda"
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    logging.info(f"Selected device: {device}")

    model.to(device)
    logging.info(f"Model moved to device: {device}")
    return processor, model, device
# --- Load Model Globally (once per Space instance) ---
try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
    MODEL_LOADED = True
except Exception as e:
    MODEL_LOADED = False
    # Bind placeholders so code that references these globals (the UI
    # callbacks pass them as arguments) does not fail with NameError; the TTS
    # functions check MODEL_LOADED first and return an error message instead.
    processor = model = device = None
    # logging.exception keeps the traceback, which the plain message lost.
    logging.exception(f"Failed to load model/processor: {e}")
    # build_ui() shows a warning banner and disables its buttons when
    # MODEL_LOADED is False.
# --- Core TTS Functions ---
def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
    *,
    max_new_tokens=3000,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
):
    """Synthesize ``text`` in the voice of a reference recording (voice cloning).

    Args:
        text: Text to synthesize; lower-cased before preprocessing. Empty or
            whitespace-only text is rejected.
        prompt_speech_path: Path to the reference audio clip.
        prompt_text: Optional transcript of the reference clip. Sent as ``None``
            when missing or shorter than 2 characters after stripping;
            otherwise stripped and lower-cased.
        processor: Loaded processor exposing ``tokenizer`` and ``decode``.
        model: Loaded model exposing ``generate``.
        device: Torch device that the processor outputs are moved to.
        max_new_tokens, temperature, top_k, top_p: Sampling settings forwarded
            to ``model.generate``. The defaults equal the previously
            hard-coded values.

    Returns:
        ``(wav_path, None)`` on success, where ``wav_path`` is a temporary WAV
        file that is not deleted automatically. ``(None, error_message)`` on
        failure; exceptions during inference are logged and converted to a
        message, not propagated.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."

    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")
    try:
        # A transcript shorter than 2 characters is treated as absent.
        prompt_text_clean = None if not prompt_text or len(prompt_text.strip()) < 2 else prompt_text.strip()

        # 1. Preprocess and move the returned tensors to the model's device.
        inputs = processor(
            text=text.lower(),
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean.lower() if prompt_text_clean else None,
            return_tensors="pt",
        ).to(device)
        # Pop the prompt's global tokens so they are not forwarded to
        # generate(); decode() receives them below.
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")

        # Fall back to EOS for padding when the tokenizer has no pad token.
        tokenizer = processor.tokenizer
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # 2. Sample output tokens; max_new_tokens caps runaway generations.
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
            )

        # 3. Decode to audio. input_ids_len is the prompt length (presumably
        # used to skip the prompt tokens — confirm in the remote processor).
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1],
        )

        # Reserve a unique .wav path and close the handle before writing.
        # On Windows, a NamedTemporaryFile cannot be reopened by name while open.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            output_path = tmpfile.name
        sf.write(output_path, output_clone["audio"], output_clone["sampling_rate"])

        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None
    except Exception as e:
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
def run_voice_creation_tts(
    text,
    gender,
    pitch_level,  # Expecting 1-5
    speed_level,  # Expecting 1-5
    processor,
    model,
    device,
    *,
    max_new_tokens=3000,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
):
    """Synthesize ``text`` with a new voice described by gender, pitch and speed.

    Args:
        text: Text to synthesize; lower-cased before preprocessing. Empty or
            whitespace-only text is rejected.
        gender: Gender label passed through to the processor (the UI offers
            "male" / "female").
        pitch_level: Level 1-5, mapped to a label through ``LEVELS_MAP_UI``.
            Any other value falls back to "moderate".
        speed_level: Level 1-5, mapped the same way as ``pitch_level``.
        processor: Loaded processor exposing ``tokenizer`` and ``decode``.
        model: Loaded model exposing ``generate``.
        device: Torch device that the processor outputs are moved to.
        max_new_tokens, temperature, top_k, top_p: Sampling settings forwarded
            to ``model.generate``. The defaults equal the previously
            hard-coded values.

    Returns:
        ``(wav_path, None)`` on success, where ``wav_path`` is a temporary WAV
        file that is not deleted automatically. ``(None, error_message)`` on
        failure; exceptions during inference are logged and converted to a
        message, not propagated.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."

    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")

    logging.info("Starting voice creation inference...")
    logging.info(f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), Speed: {speed_str} (Level {speed_level})")
    try:
        # 1. Preprocess: creation mode sends no prompt audio or transcript.
        inputs = processor(
            text=text.lower(),
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt",
        ).to(device)

        # Fall back to EOS for padding when the tokenizer has no pad token.
        tokenizer = processor.tokenizer
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # 2. Sample output tokens; max_new_tokens caps runaway generations.
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
            )

        # 3. Decode to audio (no prompt global tokens in creation mode).
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1],
        )

        # Reserve a unique .wav path and close the handle before writing.
        # On Windows, a NamedTemporaryFile cannot be reopened by name while open.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            output_path = tmpfile.name
        sf.write(output_path, output_create["audio"], output_create["sampling_rate"])

        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None
    except Exception as e:
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
# --- Gradio UI ---
def build_ui():
    """Build the Gradio Blocks app with "Voice Clone" and "Voice Creation" tabs.

    The generate buttons are disabled and a warning banner is shown when
    ``MODEL_LOADED`` is False. Clone example rows whose audio files are
    missing on disk are dropped, and the examples widget is skipped if none
    remain. Gradio processes example files while building the UI, so a
    missing path can abort startup.

    Returns:
        The ``gr.Blocks`` instance; not yet launched.
    """

    def _rows_with_existing_audio(rows, audio_columns):
        """Keep rows whose non-empty entries in ``audio_columns`` are existing files."""
        kept = [
            row for row in rows
            if all(not row[col] or Path(row[col]).is_file() for col in audio_columns)
        ]
        dropped = len(rows) - len(kept)
        if dropped:
            logging.warning(f"Skipping {dropped} example(s) whose audio file is missing.")
        return kept

    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS Demo (Transformers)</h1>')
        gr.Markdown(
            "Powered by [DragonLineageAI/Vi-Spark-TTS-0.5B-v2](https://huggingface.co/DragonLineageAI/Vi-Spark-TTS-0.5B-v2). "
            "Choose a tab for Voice Cloning or Voice Creation."
        )
        if not MODEL_LOADED:
            gr.Markdown("## ⚠️ Error: Model failed to load. Please check the Space logs.")

        with gr.Tabs():
            # --- Voice Clone Tab ---
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload Reference Audio or Record"
                )
                gr.Markdown(
                    "Provide a short audio clip (5-20 seconds) of the voice you want to clone. "
                    "Optionally, provide the transcript of that audio for better results, especially if the language is the same as the text you want to synthesize."
                )
                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Prompt Audio File (WAV/MP3)",
                    )
                    prompt_wav_record = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Or Record Prompt Audio",
                    )
                with gr.Row():
                    text_input_clone = gr.Textbox(
                        label="Text to Synthesize",
                        lines=4,
                        placeholder="Enter text here..."
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of Prompt Speech (Optional)",
                        lines=2,
                        placeholder="Enter the transcript of the prompt audio (if available).",
                        info="Recommended for cloning in the same language."
                    )
                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                # Shows success or error messages from the callback.
                status_clone = gr.Textbox(label="Status", interactive=False)
                generate_button_clone = gr.Button("Generate Cloned Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_clone_callback(text, prompt_text, audio_upload, audio_record):
                    """Run cloning; returns (audio_path_or_None, status_message)."""
                    # Prefer the uploaded file; fall back to the recording.
                    prompt_speech = audio_upload if audio_upload else audio_record
                    if not prompt_speech:
                        return None, "Error: Please upload or record a reference audio."
                    output_path, error_msg = run_voice_clone_tts(
                        text,
                        prompt_speech,
                        prompt_text,
                        processor,
                        model,
                        device
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                generate_button_clone.click(
                    voice_clone_callback,
                    inputs=[
                        text_input_clone,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone, status_clone],
                )

                # Columns 2 and 3 hold audio paths, relative to the Space repo root.
                clone_examples = _rows_with_existing_audio(
                    [
                        ["Hello, this is a test of voice cloning.", "I am a sample reference voice.", "examples/sample_prompt.wav", None],
                        ["You can experiment with different voices and texts.", None, None, "examples/sample_record.wav"],
                        ["The quality of the clone depends on the reference audio.", "This is the reference text.", "examples/another_prompt.wav", None],
                    ],
                    audio_columns=(2, 3),
                )
                if clone_examples:
                    gr.Examples(
                        examples=clone_examples,
                        inputs=[text_input_clone, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                        outputs=[audio_output_clone, status_clone],
                        fn=voice_clone_callback,
                        cache_examples=False,
                        label="Clone Examples"
                    )

            # --- Voice Creation Tab ---
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create Your Own Voice Based on the Following Parameters"
                )
                gr.Markdown(
                    "Select gender, adjust pitch and speed to generate a new synthetic voice."
                )
                with gr.Row():
                    with gr.Column(scale=1):
                        gender = gr.Radio(
                            choices=["male", "female"], value="female", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch (1=Lowest, 5=Highest)"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed (1=Slowest, 5=Fastest)"
                        )
                    with gr.Column(scale=2):
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=5,
                            placeholder="Enter text here...",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                audio_output_creation = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                # Shows success or error messages from the callback.
                status_create = gr.Textbox(label="Status", interactive=False)
                create_button = gr.Button("Create New Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_creation_callback(text, gender, pitch_val, speed_val):
                    """Run voice creation; returns (audio_path_or_None, status_message)."""
                    output_path, error_msg = run_voice_creation_tts(
                        text,
                        gender,
                        int(pitch_val),  # sliders may deliver floats
                        int(speed_val),
                        processor,
                        model,
                        device
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                )
                gr.Examples(
                    examples=[
                        ["This is a female voice with average pitch and speed.", "female", 3, 3],
                        ["This is a male voice, speaking quickly with a slightly higher pitch.", "male", 4, 4],
                        ["A deep and slow female voice.", "female", 1, 2],
                        ["A very high-pitched and fast male voice.", "male", 5, 5]
                    ],
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                    fn=voice_creation_callback,
                    cache_examples=False,
                    label="Creation Examples"
                )
    return demo
# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Script entry point (e.g. `python app.py`): build the UI and start the
    # Gradio server with default launch() settings. Keep the variable name
    # `demo`, since Gradio's reload mode looks up an app by that name.
    demo = build_ui()
    demo.launch()