# Copyright (c) 2025 SparkAudio & DragonLineageAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import soundfile as sf
import logging
import gradio as gr
import platform
import numpy as np
from pathlib import Path
from datetime import datetime
import tempfile  # To handle temporary audio files for Gradio

# --- Import Transformers ---
from transformers import AutoProcessor, AutoModel
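# Note (assumption, not part of the original script): on a Hugging Face Space
# these imports expect a requirements.txt listing at least transformers, torch,
# soundfile, gradio, and numpy; soundfile may also need libsndfile on the host.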
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
cache_dir = "model_cache"  # Define a cache directory within the Space

# Mapping from the Gradio sliders (1-5) to the model's expected string values.
# Adjust these strings if the model expects different ones (e.g., "slow", "fast").
LEVELS_MAP_UI = {
    1: "very_low",   # Or "slowest" / "lowest"
    2: "low",        # Or "slow" / "low"
    3: "moderate",   # Or "normal" / "medium"
    4: "high",       # Or "fast" / "high"
    5: "very_high",  # Or "fastest" / "highest"
}
# --- Model Loading ---
def load_model_and_processor(model_id, cache_dir):
    """Loads the processor and model using Transformers."""
    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            # token=api_key,  # Use a token only if necessary, ideally from Space secrets
            cache_dir=cache_dir,
        )
        logging.info("Processor loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading processor: {e}")
        raise

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
            # torch_dtype=torch.float16,  # Optional: uncomment for potential speedup/memory savings if supported
        )
        model.eval()  # Set model to evaluation mode
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise

    # --- Link Model to Processor ---
    # THIS STEP IS CRUCIAL: the processor's decode step needs a model reference.
    processor.model = model
    logging.info("Model reference set in processor.")

    # Sync the sampling rate if necessary
    if hasattr(model.config, 'sample_rate') and processor.sampling_rate != model.config.sample_rate:
        logging.warning(
            f"Processor SR ({processor.sampling_rate}) != Model Config SR "
            f"({model.config.sample_rate}). Updating processor."
        )
        processor.sampling_rate = model.config.sample_rate

    # --- Device Selection ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        # Check for MPS availability specifically (Apple Silicon)
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logging.info(f"Selected device: {device}")

    model.to(device)
    logging.info(f"Model moved to device: {device}")
    return processor, model, device
# --- Load Model Globally (once per Space instance) ---
try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
    MODEL_LOADED = True
except Exception as e:
    MODEL_LOADED = False
    # Define the globals anyway so later references don't raise NameError;
    # the TTS functions check MODEL_LOADED before using them.
    processor, model, device = None, None, None
    logging.error(f"Failed to load model/processor: {e}")
# You might want to display the error in the Gradio UI if loading fails.
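# A minimal sketch of surfacing that error in the UI (LOAD_ERROR is a name
# introduced here for illustration, not part of the original script):
#
#     LOAD_ERROR = str(e)  # set inside the except block above
#     ...
#     if not MODEL_LOADED:  # inside build_ui()
#         gr.Markdown(f"## ⚠️ Error: Model failed to load: {LOAD_ERROR}")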
# --- Core TTS Functions ---
def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
):
    """Performs voice-cloning TTS using Transformers."""
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text:
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."

    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")

    try:
        # Treat an empty or very short transcript as absent
        prompt_text_clean = prompt_text.strip() if prompt_text and len(prompt_text.strip()) >= 2 else None

        # 1. Preprocess using the processor
        inputs = processor(
            text=text.lower(),
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean.lower() if prompt_text_clean else None,
            return_tensors="pt",
        ).to(device)  # Move processor output to the model device

        # Keep the prompt's global tokens if present (important for decoding)
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")

        # 2. Generate with the model
        with torch.no_grad():
            # Use generation parameters consistent with the original pipeline/model card.
            # max_new_tokens is a safeguard; adjust it based on expected output length,
            # or compute it dynamically from the input length if needed.
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,  # Safeguard, might need adjustment
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                # Fall back to EOS if the tokenizer has no PAD token
                pad_token_id=(
                    processor.tokenizer.pad_token_id
                    if processor.tokenizer.pad_token_id is not None
                    else processor.tokenizer.eos_token_id
                ),
            )

        # 3. Decode using the processor
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1],  # Pass the prompt length
        )

        # Save the audio to a temporary file for Gradio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, output_clone["audio"], output_clone["sampling_rate"])
            output_path = tmpfile.name

        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None  # Return the path and no error message
    except Exception as e:
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
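# Example usage outside the UI (a sketch; assumes the model loaded and that
# 'examples/sample_prompt.wav' exists in the repo):
#
#     path, err = run_voice_clone_tts(
#         "Hello, this is a cloning test.",
#         "examples/sample_prompt.wav",
#         None,  # no transcript for the prompt audio
#         processor, model, device,
#     )
#     print(err or f"Wrote {path}")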
def run_voice_creation_tts(
    text,
    gender,
    pitch_level,  # Expecting 1-5
    speed_level,  # Expecting 1-5
    processor,
    model,
    device,
):
    """Performs voice-creation TTS using Transformers."""
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text:
        return None, "Error: Please provide text to synthesize."

    # Map numeric levels to string representations (default to moderate if invalid)
    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")

    logging.info("Starting voice creation inference...")
    logging.info(
        f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), "
        f"Speed: {speed_str} (Level {speed_level})"
    )

    try:
        # 1. Preprocess (no audio or text prompt for creation)
        inputs = processor(
            text=text.lower(),
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt",
        ).to(device)

        # 2. Generate
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,  # Safeguard
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                # Fall back to EOS if the tokenizer has no PAD token
                pad_token_id=(
                    processor.tokenizer.pad_token_id
                    if processor.tokenizer.pad_token_id is not None
                    else processor.tokenizer.eos_token_id
                ),
            )

        # 3. Decode (no prompt global tokens needed here)
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1],  # Pass the prompt length
        )

        # Save the audio to a temporary file for Gradio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, output_create["audio"], output_create["sampling_rate"])
            output_path = tmpfile.name

        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None  # Return the path and no error message
    except Exception as e:
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
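# Example usage outside the UI (a sketch; assumes the model loaded):
#
#     path, err = run_voice_creation_tts(
#         "A deep, slow synthetic voice.",
#         "female",
#         1,  # pitch level (1-5)
#         2,  # speed level (1-5)
#         processor, model, device,
#     )
#     print(err or f"Wrote {path}")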
# --- Gradio UI ---
def build_ui():
    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS Demo (Transformers)</h1>')
        gr.Markdown(
            "Powered by [DragonLineageAI/Vi-Spark-TTS-0.5B-v2](https://huggingface.co/DragonLineageAI/Vi-Spark-TTS-0.5B-v2). "
            "Choose a tab for Voice Cloning or Voice Creation."
        )
        if not MODEL_LOADED:
            gr.Markdown("## ⚠️ Error: Model failed to load. Please check the Space logs.")

        with gr.Tabs():
            # --- Voice Clone Tab ---
            with gr.TabItem("Voice Clone"):
                gr.Markdown("### Upload Reference Audio or Record")
                gr.Markdown(
                    "Provide a short audio clip (5-20 seconds) of the voice you want to clone. "
                    "Optionally, provide the transcript of that audio for better results, "
                    "especially if its language matches the text you want to synthesize."
                )
                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Prompt Audio File (WAV/MP3)",
                    )
                    prompt_wav_record = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Or Record Prompt Audio",
                    )
                with gr.Row():
                    text_input_clone = gr.Textbox(
                        label="Text to Synthesize",
                        lines=4,
                        placeholder="Enter text here...",
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of Prompt Speech (Optional)",
                        lines=2,
                        placeholder="Enter the transcript of the prompt audio (if available).",
                        info="Recommended for cloning in the same language.",
                    )
                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_clone = gr.Textbox(label="Status", interactive=False)  # For status/error messages
                generate_button_clone = gr.Button("Generate Cloned Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_clone_callback(text, prompt_text, audio_upload, audio_record):
                    # Prioritize the uploaded file; fall back to the recording
                    prompt_speech = audio_upload if audio_upload else audio_record
                    if not prompt_speech:
                        # Return None for the audio component and an error message for the status component
                        return None, "Error: Please upload or record a reference audio."
                    # Call the core TTS function
                    output_path, error_msg = run_voice_clone_tts(
                        text,
                        prompt_speech,
                        prompt_text,
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg  # Surface the error in status_clone
                    # Return the audio file path and a success message
                    return output_path, "Audio generated successfully!"

                generate_button_clone.click(
                    voice_clone_callback,
                    inputs=[
                        text_input_clone,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone, status_clone],  # Update both audio and status
                )

                # Examples need actual audio files in an 'examples' directory in the Space repo.
                # Make sure 'examples/sample_prompt.wav' (etc.) exists, or change the paths.
                gr.Examples(
                    examples=[
                        ["Hello, this is a test of voice cloning.", "I am a sample reference voice.", "examples/sample_prompt.wav", None],
                        ["You can experiment with different voices and texts.", None, None, "examples/sample_record.wav"],  # Assumes a recorded sample exists
                        ["The quality of the clone depends on the reference audio.", "This is the reference text.", "examples/another_prompt.wav", None],
                    ],
                    inputs=[text_input_clone, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                    outputs=[audio_output_clone, status_clone],
                    fn=voice_clone_callback,
                    cache_examples=False,  # Disable caching if examples might change, or for demos
                    label="Clone Examples",
                )
            # --- Voice Creation Tab ---
            with gr.TabItem("Voice Creation"):
                gr.Markdown("### Create Your Own Voice Based on the Following Parameters")
                gr.Markdown("Select a gender and adjust pitch and speed to generate a new synthetic voice.")
                with gr.Row():
                    with gr.Column(scale=1):
                        gender = gr.Radio(
                            choices=["male", "female"], value="female", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch (1=Lowest, 5=Highest)"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed (1=Slowest, 5=Fastest)"
                        )
                    with gr.Column(scale=2):
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=5,
                            placeholder="Enter text here...",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                audio_output_creation = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_create = gr.Textbox(label="Status", interactive=False)  # For status/error messages
                create_button = gr.Button("Create New Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_creation_callback(text, gender, pitch_val, speed_val):
                    # Call the core TTS function
                    output_path, error_msg = run_voice_creation_tts(
                        text,
                        gender,
                        int(pitch_val),  # Convert slider value to int
                        int(speed_val),  # Convert slider value to int
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                )
                gr.Examples(
                    examples=[
                        ["This is a female voice with average pitch and speed.", "female", 3, 3],
                        ["This is a male voice, speaking quickly with a slightly higher pitch.", "male", 4, 4],
                        ["A deep and slow female voice.", "female", 1, 2],
                        ["A very high-pitched and fast male voice.", "male", 5, 5],
                    ],
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                    fn=voice_creation_callback,
                    cache_examples=False,
                    label="Creation Examples",
                )
    return demo
# --- Launch the Gradio App ---
if __name__ == "__main__":
    demo = build_ui()
    demo.launch()
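# Optional (a sketch): on Hugging Face Spaces, enabling request queuing can
# help long generations avoid request timeouts; Gradio exposes this via
# Blocks.queue():
#
#     demo = build_ui()
#     demo.queue().launch()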