File size: 18,113 Bytes
862e5be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88d50ab
862e5be
dc56bd0
862e5be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88d50ab
862e5be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
# Copyright (c) 2025 SparkAudio & DragonLineageAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import soundfile as sf
import logging
import gradio as gr
import platform
import numpy as np
from pathlib import Path
from datetime import datetime
import tempfile # To handle temporary audio files for Gradio

# --- Import Transformers ---
from transformers import AutoProcessor, AutoModel

# --- Configuration ---
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Model identifier passed to from_pretrained for both the processor and the model.
model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
# Cache directory passed to from_pretrained (kept inside the Space).
cache_dir = "model_cache"

# Gradio slider positions 1..5 mapped, in order, to the level strings handed to
# the processor. Adjust the strings if the model expects different labels
# (e.g. "slow", "fast").
LEVELS_MAP_UI = dict(
    zip(
        range(1, 6),
        (
            "very_low",   # 1
            "low",        # 2
            "moderate",   # 3
            "high",       # 4
            "very_high",  # 5
        ),
    )
)

# --- Model Loading ---
def load_model_and_processor(model_id, cache_dir):
    """Load the processor and model, link them, and place the model on a device.

    Args:
        model_id: Identifier passed to ``from_pretrained`` for both objects.
        cache_dir: Cache directory passed to ``from_pretrained``.

    Returns:
        A ``(processor, model, device)`` tuple.

    Raises:
        Exception: Whatever ``from_pretrained`` (or ``model.eval()``) raises;
            the error is logged and then re-raised unchanged.
    """
    # Both from_pretrained calls take the same options.
    # A token would only be needed for gated/private repos (ideally from secrets).
    # torch_dtype=torch.float16 could be added to the model call if supported.
    hub_kwargs = {"trust_remote_code": True, "cache_dir": cache_dir}

    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(model_id, **hub_kwargs)
    except Exception as e:
        logging.error(f"Error loading processor: {e}")
        raise
    else:
        logging.info("Processor loaded successfully.")

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(model_id, **hub_kwargs)
        model.eval()  # evaluation mode
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise
    else:
        logging.info("Model loaded successfully.")

    # The processor needs a reference to the model; this step is crucial.
    processor.model = model
    logging.info("Model reference set in processor.")

    # When the model config declares a sample rate, it wins over the processor's.
    if hasattr(model.config, 'sample_rate'):
        if processor.sampling_rate != model.config.sample_rate:
            logging.warning(f"Processor SR ({processor.sampling_rate}) != Model Config SR ({model.config.sample_rate}). Updating processor.")
            processor.sampling_rate = model.config.sample_rate

    # Device preference: CUDA, then Apple MPS (macOS only), then CPU.
    if torch.cuda.is_available():
        device_name = "cuda"
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    device = torch.device(device_name)

    logging.info(f"Selected device: {device}")
    model.to(device)
    logging.info(f"Model moved to device: {device}")

    return processor, model, device

# --- Load Model Globally (once per Space instance) ---
# Bind the names up front so they always exist at module level. If loading
# fails they stay None and MODEL_LOADED is False, so later references do not
# raise NameError and the TTS functions can report "Model not loaded".
processor = model = device = None
try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
    MODEL_LOADED = True
except Exception as e:
    MODEL_LOADED = False
    # exc_info keeps the traceback in the logs; the UI shows a load-failure banner.
    logging.error(f"Failed to load model/processor: {e}", exc_info=True)

# --- Core TTS Functions ---

def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
):
    """Synthesize ``text`` using a reference recording as the voice prompt.

    Args:
        text: Text to synthesize; it is lower-cased before being handed to
            the processor. Empty or whitespace-only text is rejected.
        prompt_speech_path: Path of the reference audio file.
        prompt_text: Optional transcript of the reference audio. It is
            replaced by ``None`` when empty or shorter than two characters
            after stripping; otherwise it is stripped and lower-cased.
        processor: Processor object; it is called to build the inputs and
            must expose ``tokenizer`` and ``decode``.
        model: Model object exposing ``generate``.
        device: Device the processor output is moved to.

    Returns:
        ``(output_path, None)`` on success, where ``output_path`` is a
        temporary ``.wav`` file that is not deleted automatically, or
        ``(None, error_message)`` when validation or generation fails.
        Exceptions raised during inference are logged and returned as the
        error message instead of propagating.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."

    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")

    try:
        # Use the transcript only when it has at least two characters after
        # stripping; the length check happens before lower-casing.
        stripped_prompt = prompt_text.strip() if prompt_text else ""
        prompt_text_clean = stripped_prompt.lower() if len(stripped_prompt) >= 2 else None

        # 1. Preprocess using Processor
        inputs = processor(
            text=text.lower(),
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean,
            return_tensors="pt"
        ).to(device)  # Move processor output to model device

        # The prompt's global tokens are removed from the generate() inputs
        # and passed to decode() instead.
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")

        # Fall back to EOS when the tokenizer defines no PAD token.
        eos_token_id = processor.tokenizer.eos_token_id
        pad_token_id = processor.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = eos_token_id

        # 2. Generate using Model
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,  # Safeguard, might need adjustment
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id
            )

        # 3. Decode using Processor
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1]  # Pass prompt length
        )

        # Reserve a temporary file name for Gradio, then close the handle
        # before writing: reopening a still-open NamedTemporaryFile by name
        # fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            output_path = tmpfile.name
        sf.write(output_path, output_clone["audio"], output_clone["sampling_rate"])

        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None  # Return path and no error message

    except Exception as e:
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"


def run_voice_creation_tts(
    text,
    gender,
    pitch_level, # Expecting 1-5
    speed_level, # Expecting 1-5
    processor,
    model,
    device,
):
    """Synthesize ``text`` with a voice described by gender, pitch and speed.

    Args:
        text: Text to synthesize; it is lower-cased before being handed to
            the processor. Empty or whitespace-only text is rejected.
        gender: Gender value forwarded unchanged to the processor.
        pitch_level: Key into ``LEVELS_MAP_UI`` (1-5); unknown keys fall
            back to ``"moderate"``.
        speed_level: Key into ``LEVELS_MAP_UI`` (1-5); unknown keys fall
            back to ``"moderate"``.
        processor: Processor object; it is called to build the inputs and
            must expose ``tokenizer`` and ``decode``.
        model: Model object exposing ``generate``.
        device: Device the processor output is moved to.

    Returns:
        ``(output_path, None)`` on success, where ``output_path`` is a
        temporary ``.wav`` file that is not deleted automatically, or
        ``(None, error_message)`` when validation or generation fails.
        Exceptions raised during inference are logged and returned as the
        error message instead of propagating.
    """
    if not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."

    # Map numeric levels to string representations (default: "moderate").
    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")

    logging.info("Starting voice creation inference...")
    logging.info(f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), Speed: {speed_str} (Level {speed_level})")

    try:
        # 1. Preprocess (no audio or text prompt is passed for voice creation)
        inputs = processor(
            text=text.lower(),
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt"
        ).to(device)

        # Fall back to EOS when the tokenizer defines no PAD token.
        eos_token_id = processor.tokenizer.eos_token_id
        pad_token_id = processor.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = eos_token_id

        # 2. Generate
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,  # Safeguard
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id
            )

        # 3. Decode (no prompt global tokens are passed here)
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1]  # Pass prompt length
        )

        # Reserve a temporary file name for Gradio, then close the handle
        # before writing: reopening a still-open NamedTemporaryFile by name
        # fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            output_path = tmpfile.name
        sf.write(output_path, output_create["audio"], output_create["sampling_rate"])

        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None  # Return path and no error message

    except Exception as e:
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"


# --- Gradio UI ---
def build_ui():
    """Build the Gradio Blocks app with "Voice Clone" and "Voice Creation" tabs.

    Each tab wires a button (disabled when ``MODEL_LOADED`` is false) to a
    callback that returns ``(audio_path_or_None, status_message)`` for the
    tab's audio and status components.

    Returns:
        The constructed ``gr.Blocks`` instance; the caller launches it.
    """
    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS Demo (Transformers)</h1>')
        gr.Markdown(
            "Powered by [DragonLineageAI/Vi-Spark-TTS-0.5B-v2](https://huggingface.co/DragonLineageAI/Vi-Spark-TTS-0.5B-v2). "
            "Choose a tab for Voice Cloning or Voice Creation."
        )

        # Surface a load failure in the page itself, not only in the logs.
        if not MODEL_LOADED:
            gr.Markdown("## ⚠️ Error: Model failed to load. Please check the Space logs.")

        with gr.Tabs():
            # --- Voice Clone Tab ---
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload Reference Audio or Record"
                )
                gr.Markdown(
                    "Provide a short audio clip (5-20 seconds) of the voice you want to clone. "
                    "Optionally, provide the transcript of that audio for better results, especially if the language is the same as the text you want to synthesize."
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Prompt Audio File (WAV/MP3)",
                    )
                    prompt_wav_record = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Or Record Prompt Audio",
                    )

                with gr.Row():
                    text_input_clone = gr.Textbox(
                        label="Text to Synthesize",
                        lines=4,
                        placeholder="Enter text here..."
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of Prompt Speech (Optional)",
                        lines=2,
                        placeholder="Enter the transcript of the prompt audio (if available).",
                        info="Recommended for cloning in the same language."
                    )

                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_clone = gr.Textbox(label="Status", interactive=False)  # For status/error messages

                generate_button_clone = gr.Button("Generate Cloned Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_clone_callback(text, prompt_text, audio_upload, audio_record):
                    """Run voice cloning and return ``(audio_path, status_message)``."""
                    # Prioritize uploaded file, fallback to recorded file
                    prompt_speech = audio_upload if audio_upload else audio_record
                    if not prompt_speech:
                        # None clears the audio component; the message goes to the status box
                        return None, "Error: Please upload or record a reference audio."

                    # The module-level processor/model/device are only valid when
                    # loading succeeded, so stop before referencing them.
                    if not MODEL_LOADED:
                        return None, "Error: Model not loaded."

                    # Call the core TTS function
                    output_path, error_msg = run_voice_clone_tts(
                        text,
                        prompt_speech,
                        prompt_text,
                        processor,
                        model,
                        device
                    )
                    if error_msg:
                        return None, error_msg  # Return error message to status_clone
                    # Return the audio file path and a success message
                    return output_path, "Audio generated successfully!"


                generate_button_clone.click(
                    voice_clone_callback,
                    inputs=[
                        text_input_clone,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone, status_clone],  # Update both audio and status
                )

                # Examples need actual audio files in an 'examples' directory in your Space repo
                # Make sure 'examples/sample_prompt.wav' exists or change the path
                gr.Examples(
                    examples=[
                        ["Hello, this is a test of voice cloning.", "I am a sample reference voice.", "examples/sample_prompt.wav", None],
                        ["You can experiment with different voices and texts.", None, None, "examples/sample_record.wav"],  # Assuming a recorded sample exists
                        ["The quality of the clone depends on the reference audio.", "This is the reference text.", "examples/another_prompt.wav", None]
                    ],
                    inputs=[text_input_clone, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                    outputs=[audio_output_clone, status_clone],
                    fn=voice_clone_callback,
                    cache_examples=False,  # Disable caching if examples might change or for demos
                    label="Clone Examples"
                )


            # --- Voice Creation Tab ---
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create Your Own Voice Based on the Following Parameters"
                )
                gr.Markdown(
                    "Select gender, adjust pitch and speed to generate a new synthetic voice."
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        gender = gr.Radio(
                            choices=["male", "female"], value="female", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch (1=Lowest, 5=Highest)"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed (1=Slowest, 5=Fastest)"
                        )
                    with gr.Column(scale=2):
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=5,
                            placeholder="Enter text here...",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )

                audio_output_creation = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                status_create = gr.Textbox(label="Status", interactive=False)  # For status/error messages

                create_button = gr.Button("Create New Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_creation_callback(text, gender, pitch_val, speed_val):
                    """Run voice creation and return ``(audio_path, status_message)``."""
                    # The module-level processor/model/device are only valid when
                    # loading succeeded, so stop before referencing them.
                    if not MODEL_LOADED:
                        return None, "Error: Model not loaded."

                    # Call the core TTS function
                    output_path, error_msg = run_voice_creation_tts(
                        text,
                        gender,
                        int(pitch_val),  # Convert slider value to int
                        int(speed_val),  # Convert slider value to int
                        processor,
                        model,
                        device
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                )

                gr.Examples(
                    examples=[
                        ["This is a female voice with average pitch and speed.", "female", 3, 3],
                        ["This is a male voice, speaking quickly with a slightly higher pitch.", "male", 4, 4],
                        ["A deep and slow female voice.", "female", 1, 2],
                        ["A very high-pitched and fast male voice.", "male", 5, 5]
                    ],
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                    fn=voice_creation_callback,
                    cache_examples=False,
                    label="Creation Examples"
                )
    return demo

# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Build the Blocks app and start the Gradio server with default launch options.
    # This only runs when the file is executed as a script, not when it is imported.
    demo = build_ui()
    demo.launch()