# Copyright (c) 2025 SparkAudio & DragonLineageAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gradio demo for Vi-Spark-TTS (voice cloning and voice creation) via Transformers."""

import os
import torch
import soundfile as sf
import logging
import gradio as gr
import platform
import numpy as np
from pathlib import Path
from datetime import datetime
import tempfile  # To handle temporary audio files for Gradio

# --- Import Transformers ---
from transformers import AutoProcessor, AutoModel

# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
cache_dir = "model_cache"  # Define a cache directory within the Space

# Sampling settings shared by both TTS modes, kept in one place so the two
# tabs cannot drift apart. Token ids are added per call by _generation_kwargs.
GENERATION_KWARGS = dict(
    max_new_tokens=3000,  # Safeguard against runaway generation; may need adjustment
    do_sample=True,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
)

# Mapping from Gradio Slider (1-5) to model's expected string values
# Adjust these strings if the model expects different ones (e.g., "slow", "fast")
LEVELS_MAP_UI = {
    1: "very_low",   # Or "slowest" / "lowest"
    2: "low",        # Or "slow" / "low"
    3: "moderate",   # Or "normal" / "medium"
    4: "high",       # Or "fast" / "high"
    5: "very_high",  # Or "fastest" / "highest"
}


# --- Model Loading ---
def load_model_and_processor(model_id, cache_dir):
    """Load the processor and model, link them, and move the model to a device.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the model.
        cache_dir: Directory used by ``from_pretrained`` as its download cache.

    Returns:
        Tuple ``(processor, model, device)``. The model is in eval mode and
        lives on ``device`` (CUDA, then Apple MPS, then CPU).

    Raises:
        Exception: Whatever ``from_pretrained`` raised, re-raised after logging.
    """
    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            # token=api_key,  # Use token only if necessary and ideally from secrets
            cache_dir=cache_dir,
        )
        logging.info("Processor loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading processor: {e}")
        raise

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
            # torch_dtype=torch.float16  # Optional: uncomment for potential speedup/memory saving if supported
        )
        model.eval()  # Set model to evaluation mode
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise

    # --- Link Model to Processor ---
    # The original marks this step as crucial. Presumably the remote-code
    # processor uses the model in decode() (TODO confirm).
    processor.model = model
    logging.info("Model reference set in processor.")

    # Keep the processor's sampling rate in line with the model config.
    # getattr guards processors that do not expose a sampling_rate attribute.
    model_sr = getattr(model.config, "sample_rate", None)
    processor_sr = getattr(processor, "sampling_rate", None)
    if model_sr is not None and processor_sr != model_sr:
        logging.warning(f"Processor SR ({processor_sr}) != Model Config SR ({model_sr}). Updating processor.")
        processor.sampling_rate = model_sr

    # --- Device Selection ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        # Check for MPS availability specifically
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logging.info(f"Selected device: {device}")

    model.to(device)
    logging.info(f"Model moved to device: {device}")

    return processor, model, device


# --- Load Model Globally (once per Space instance) ---
try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
    MODEL_LOADED = True
except Exception as e:
    # Bind the globals anyway. The UI callbacks reference them by name, so
    # without this they would raise NameError before reaching the
    # MODEL_LOADED guard inside run_*_tts.
    processor = model = device = None
    MODEL_LOADED = False
    logging.error(f"Failed to load model/processor: {e}")


# --- Internal helpers ---
def _generation_kwargs(processor):
    """Return the ``model.generate`` keyword arguments shared by both TTS modes.

    ``pad_token_id`` falls back to ``eos_token_id`` when the tokenizer defines
    no pad token. The original inline comment intended this, but the code
    did not implement it.
    """
    tokenizer = processor.tokenizer
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    return dict(
        GENERATION_KWARGS,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )


def _save_audio_to_temp_wav(audio, sampling_rate):
    """Write ``audio`` to a new temporary ``.wav`` file and return its path.

    The file is created and closed before ``sf.write`` opens it by name,
    because reopening a still-open ``NamedTemporaryFile`` fails on Windows.
    The file is intentionally left on disk so Gradio can serve it by path.
    If writing fails, the empty file is removed and the error re-raised.
    """
    # NOTE(review): the type of processor.decode()["audio"] isn't visible
    # here. A torch tensor is converted to a CPU float32 numpy array so that
    # soundfile can write it; numpy input passes through unchanged.
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().float().numpy()

    fd, output_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(output_path, audio, sampling_rate)
    except Exception:
        os.remove(output_path)
        raise
    return output_path


# --- Core TTS Functions ---
def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
):
    """Synthesize ``text`` in the voice of a reference recording.

    Args:
        text: Text to synthesize (lower-cased before preprocessing).
        prompt_speech_path: Filesystem path of the reference audio clip.
        prompt_text: Optional transcript of the reference audio. Empty or
            1-character values are treated as "no transcript".
        processor, model, device: Objects returned by load_model_and_processor.

    Returns:
        ``(wav_path, None)`` on success, or ``(None, error_message)`` on failure.
        Exceptions during inference are logged and reported, not raised.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."

    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")

    try:
        # Use the transcript only when it has at least 2 non-blank characters.
        stripped_prompt = prompt_text.strip() if prompt_text else ""
        prompt_text_clean = stripped_prompt.lower() if len(stripped_prompt) >= 2 else None

        # 1. Preprocess using Processor, then move tensors to the model device.
        inputs = processor(
            text=text.lower(),
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean,
            return_tensors="pt",
        ).to(device)

        # The prompt's global tokens go to decode(), not generate(), so they
        # are removed from the generate() inputs here.
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")

        # 2. Generate using Model
        with torch.no_grad():
            output_ids = model.generate(**inputs, **_generation_kwargs(processor))

        # 3. Decode using Processor; input_ids_len lets it skip the prompt part.
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1],
        )

        output_path = _save_audio_to_temp_wav(output_clone["audio"], output_clone["sampling_rate"])
        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None
    except Exception as e:
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"


def run_voice_creation_tts(
    text,
    gender,
    pitch_level,  # Expecting 1-5
    speed_level,  # Expecting 1-5
    processor,
    model,
    device,
):
    """Synthesize ``text`` with a new voice described by gender, pitch and speed.

    Args:
        text: Text to synthesize (lower-cased before preprocessing).
        gender: Gender string forwarded to the processor (the UI offers
            "male" / "female").
        pitch_level, speed_level: Integers 1-5, mapped through LEVELS_MAP_UI.
            Out-of-range values fall back to "moderate".
        processor, model, device: Objects returned by load_model_and_processor.

    Returns:
        ``(wav_path, None)`` on success, or ``(None, error_message)`` on failure.
        Exceptions during inference are logged and reported, not raised.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."

    # Map numeric levels to string representations (default: moderate)
    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")

    logging.info("Starting voice creation inference...")
    logging.info(f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), Speed: {speed_str} (Level {speed_level})")

    try:
        # 1. Preprocess (no audio/text prompt in creation mode)
        inputs = processor(
            text=text.lower(),
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt",
        ).to(device)

        # 2. Generate
        with torch.no_grad():
            output_ids = model.generate(**inputs, **_generation_kwargs(processor))

        # 3. Decode (no prompt global tokens in creation mode)
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1],
        )

        output_path = _save_audio_to_temp_wav(output_create["audio"], output_create["sampling_rate"])
        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None
    except Exception as e:
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"


# --- Gradio UI ---
def build_ui():
    """Build the two-tab (Voice Clone / Voice Creation) Gradio Blocks app.

    Returns:
        The ``gr.Blocks`` instance. Generate buttons are disabled when the
        model failed to load.
    """
    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS Demo (Transformers)</h1>')
        gr.Markdown(
            "Powered by [DragonLineageAI/Vi-Spark-TTS-0.5B-v2](https://huggingface.co/DragonLineageAI/Vi-Spark-TTS-0.5B-v2). "
            "Choose a tab for Voice Cloning or Voice Creation."
        )

        if not MODEL_LOADED:
            gr.Markdown("## ⚠️ Error: Model failed to load. Please check the Space logs.")

        with gr.Tabs():
            # --- Voice Clone Tab ---
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload Reference Audio or Record"
                )
                gr.Markdown(
                    "Provide a short audio clip (5-20 seconds) of the voice you want to clone. "
                    "Optionally, provide the transcript of that audio for better results, especially if the language is the same as the text you want to synthesize."
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Prompt Audio File (WAV/MP3)",
                    )
                    prompt_wav_record = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Or Record Prompt Audio",
                    )

                with gr.Row():
                    text_input_clone = gr.Textbox(
                        label="Text to Synthesize",
                        lines=4,
                        placeholder="Enter text here...",
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of Prompt Speech (Optional)",
                        lines=2,
                        placeholder="Enter the transcript of the prompt audio (if available).",
                        info="Recommended for cloning in the same language.",
                    )

                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_clone = gr.Textbox(label="Status", interactive=False)  # For status/error messages

                generate_button_clone = gr.Button("Generate Cloned Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_clone_callback(text, prompt_text, audio_upload, audio_record):
                    """Run cloning with the uploaded clip, else the recorded one; return (audio, status)."""
                    prompt_speech = audio_upload if audio_upload else audio_record
                    if not prompt_speech:
                        return None, "Error: Please upload or record a reference audio."

                    output_path, error_msg = run_voice_clone_tts(
                        text,
                        prompt_speech,
                        prompt_text,
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                generate_button_clone.click(
                    voice_clone_callback,
                    inputs=[
                        text_input_clone,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone, status_clone],  # Update both audio and status
                )

                # Examples need actual audio files in an 'examples' directory in your Space repo.
                # Make sure the referenced files exist or change the paths.
                gr.Examples(
                    examples=[
                        ["Hello, this is a test of voice cloning.", "I am a sample reference voice.", "examples/sample_prompt.wav", None],
                        ["You can experiment with different voices and texts.", None, None, "examples/sample_record.wav"],
                        ["The quality of the clone depends on the reference audio.", "This is the reference text.", "examples/another_prompt.wav", None],
                    ],
                    inputs=[text_input_clone, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                    outputs=[audio_output_clone, status_clone],
                    fn=voice_clone_callback,
                    cache_examples=False,  # Disable caching if examples might change or for demos
                    label="Clone Examples",
                )

            # --- Voice Creation Tab ---
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create Your Own Voice Based on the Following Parameters"
                )
                gr.Markdown(
                    "Select gender, adjust pitch and speed to generate a new synthetic voice."
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        gender = gr.Radio(
                            choices=["male", "female"], value="female", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch (1=Lowest, 5=Highest)"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed (1=Slowest, 5=Fastest)"
                        )
                    with gr.Column(scale=2):
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=5,
                            placeholder="Enter text here...",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )

                audio_output_creation = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_create = gr.Textbox(label="Status", interactive=False)  # For status/error messages
                create_button = gr.Button("Create New Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_creation_callback(text, gender, pitch_val, speed_val):
                    """Run voice creation with slider values cast to int; return (audio, status)."""
                    output_path, error_msg = run_voice_creation_tts(
                        text,
                        gender,
                        int(pitch_val),  # Slider values may arrive as floats
                        int(speed_val),
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                )

                gr.Examples(
                    examples=[
                        ["This is a female voice with average pitch and speed.", "female", 3, 3],
                        ["This is a male voice, speaking quickly with a slightly higher pitch.", "male", 4, 4],
                        ["A deep and slow female voice.", "female", 1, 2],
                        ["A very high-pitched and fast male voice.", "male", 5, 5],
                    ],
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                    fn=voice_creation_callback,
                    cache_examples=False,
                    label="Creation Examples",
                )

    return demo


# --- Launch the Gradio App ---
if __name__ == "__main__":
    demo = build_ui()
    demo.launch()