# Copyright (c) 2025 SparkAudio & DragonLineageAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import soundfile as sf
import logging
import gradio as gr
import platform
import numpy as np
from pathlib import Path
from datetime import datetime
import tempfile  # To handle temporary audio files for Gradio

# --- Import Transformers ---
from transformers import AutoProcessor, AutoModel

# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
cache_dir = "model_cache"  # Define a cache directory within the Space

# Mapping from Gradio Slider (1-5) to the model's expected string values.
# Adjust these strings if the model expects different ones (e.g., "slow", "fast").
LEVELS_MAP_UI = {
    1: "very_low",   # Or "slowest" / "lowest"
    2: "low",        # Or "slow" / "low"
    3: "moderate",   # Or "normal" / "medium"
    4: "high",       # Or "fast" / "high"
    5: "very_high",  # Or "fastest" / "highest"
}


# --- Model Loading ---
def load_model_and_processor(model_id, cache_dir):
    """Loads the Processor and Model using Transformers."""
    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            # token=api_key,  # Use a token only if necessary, and ideally from secrets
            cache_dir=cache_dir,
        )
        logging.info("Processor loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading processor: {e}")
        raise

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
            # torch_dtype=torch.float16,  # Optional: uncomment for potential speedup/memory savings if supported
        )
        model.eval()  # Set model to evaluation mode
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise

    # --- Link Model to Processor ---
    # THIS STEP IS CRUCIAL
    processor.model = model
    logging.info("Model reference set in processor.")

    # Sync sampling rate if necessary
    if hasattr(model.config, 'sample_rate') and processor.sampling_rate != model.config.sample_rate:
        logging.warning(
            f"Processor SR ({processor.sampling_rate}) != Model Config SR ({model.config.sample_rate}). "
            "Updating processor."
        )
        processor.sampling_rate = model.config.sample_rate

    # --- Device Selection ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        # Check for MPS availability specifically
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logging.info(f"Selected device: {device}")

    model.to(device)
    logging.info(f"Model moved to device: {device}")

    return processor, model, device


# --- Load Model Globally (once per Space instance) ---
try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
    MODEL_LOADED = True
except Exception as e:
    MODEL_LOADED = False
    logging.error(f"Failed to load model/processor: {e}")
    # You might want to display an error in the Gradio UI if loading fails
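
# A minimal sketch (an assumption, not part of the original app) of how the load
# failure could be surfaced in the Gradio UI: raising gr.Error from an event handler
# shows a visible error message in the interface. The helper name below is
# hypothetical and is not wired into the UI here.
def _require_model():
    if not MODEL_LOADED:
        raise gr.Error("Model failed to load. Check the Space logs for details.")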

# --- Core TTS Functions ---
def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
):
    """Performs voice cloning TTS using Transformers."""
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text:
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."

    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")

    try:
        # Ensure prompt_text is None if empty/short, otherwise use it
        prompt_text_clean = None if not prompt_text or len(prompt_text.strip()) < 2 else prompt_text.strip()

        # 1. Preprocess using Processor
        inputs = processor(
            text=text.lower(),
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean.lower() if prompt_text_clean else prompt_text_clean,
            return_tensors="pt",
        ).to(device)  # Move processor output to model device

        # Store prompt global tokens if present (important for decoding)
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")

        # 2. Generate using Model
        with torch.no_grad():
            # Use generate parameters consistent with the original pipeline/model card.
            # Adjust max_new_tokens based on expected output length vs input length;
            # a fixed large value might be okay, or calculate dynamically if needed.
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,  # Safeguard, might need adjustment
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id,  # Use EOS if PAD is None
            )

        # 3. Decode using Processor
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1],  # Pass prompt length
        )

        # Save audio to a temporary file for Gradio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, output_clone["audio"], output_clone["sampling_rate"])
            output_path = tmpfile.name

        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None  # Return path and no error message

    except Exception as e:
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
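
# Hypothetical usage sketch (not part of the original app): calling the cloning
# function directly, outside the Gradio UI, with a local reference clip. The helper
# name, argument name, and sample text are assumptions for illustration only.
def _example_voice_clone(reference_wav_path):
    audio_path, error = run_voice_clone_tts(
        text="xin chào, đây là giọng nói thử nghiệm.",
        prompt_speech_path=reference_wav_path,
        prompt_text=None,  # Optional transcript of the reference clip, if known
        processor=processor,
        model=model,
        device=device,
    )
    if error:
        logging.error(error)
    return audio_path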

def run_voice_creation_tts(
    text,
    gender,
    pitch_level,  # Expecting 1-5
    speed_level,  # Expecting 1-5
    processor,
    model,
    device,
):
    """Performs voice creation TTS using Transformers."""
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text:
        return None, "Error: Please provide text to synthesize."

    # Map numeric levels to string representations
    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")  # Default to moderate if invalid
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")  # Default to moderate if invalid

    logging.info("Starting voice creation inference...")
    logging.info(f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), Speed: {speed_str} (Level {speed_level})")

    try:
        # 1. Preprocess
        inputs = processor(
            text=text.lower(),
            # prompt_speech_path=None,  # No audio prompt for creation
            # prompt_text=None,         # No text prompt for creation
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt",
        ).to(device)

        # 2. Generate
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,  # Safeguard
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id,
            )

        # 3. Decode (no prompt global tokens needed here)
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1],  # Pass prompt length
        )

        # Save audio to a temporary file for Gradio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            sf.write(tmpfile.name, output_create["audio"], output_create["sampling_rate"])
            output_path = tmpfile.name

        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None  # Return path and no error message

    except Exception as e:
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"


# --- Gradio UI ---
def build_ui():
    with gr.Blocks() as demo:
        gr.HTML('