ancv committed on
Commit
862e5be
·
verified ·
1 Parent(s): 9f5a37d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +431 -0
app.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio & DragonLineageAI
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import torch
17
+ import soundfile as sf
18
+ import logging
19
+ import gradio as gr
20
+ import platform
21
+ import numpy as np
22
+ from pathlib import Path
23
+ from datetime import datetime
24
+ import tempfile # To handle temporary audio files for Gradio
25
+
26
+ # --- Import Transformers ---
27
+ from transformers import AutoProcessor, AutoModel
28
+
29
# --- Configuration ---
# Timestamped, level-tagged log lines for the whole application.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hub repository id handed to both AutoProcessor and AutoModel.
model_id = "DragonLineageAI/Vi-Spark-TTS-0.5B-v2"
# Directory passed as `cache_dir` to `from_pretrained` (relative to the working directory).
cache_dir = "model_cache"

# Maps the 1-5 slider positions used in the UI to the level strings that are
# passed to the processor as `pitch` / `speed`.
# NOTE(review): the exact strings the model accepts are not verified in this
# file (alternatives such as "slow"/"fast" were considered) — confirm against
# the model card before changing them.
LEVELS_MAP_UI = {
    1: "very_low",   # lowest pitch / slowest speed
    2: "low",
    3: "moderate",   # also the fallback used for out-of-range levels
    4: "high",
    5: "very_high",  # highest pitch / fastest speed
}
44
+
45
+ # --- Model Loading ---
46
def load_model_and_processor(model_id, cache_dir):
    """Load the processor and model with Transformers and move the model to a device.

    Args:
        model_id: Repository id (or path) passed to ``from_pretrained``.
        cache_dir: Directory passed to ``from_pretrained`` as ``cache_dir``.

    Returns:
        A ``(processor, model, device)`` tuple. The model is in eval mode and
        has been moved to ``device``; ``processor.model`` references the model.

    Raises:
        Exception: Any error raised by ``from_pretrained`` is logged and
            re-raised unchanged.
    """
    # NOTE(security): trust_remote_code=True executes Python code shipped in
    # the model repository. Only point `model_id` at repositories you trust.
    logging.info(f"Loading processor from: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            # Pass `token=...` here (ideally read from secrets) if the repo is gated.
            cache_dir=cache_dir,
        )
    except Exception as e:
        logging.error(f"Error loading processor: {e}")
        raise
    else:
        logging.info("Processor loaded successfully.")

    logging.info(f"Loading model from: {model_id}")
    try:
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            cache_dir=cache_dir,
            # Optional: `torch_dtype=torch.float16` may reduce memory use if supported.
        )
        model.eval()  # Inference only: switch to evaluation mode.
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise
    else:
        logging.info("Model loaded successfully.")

    # The processor needs a reference to the model; the original code marks
    # this step as crucial.
    processor.model = model
    logging.info("Model reference set in processor.")

    # Align the processor's sampling rate with the model config. Both
    # attributes are read defensively so that a missing attribute, or a config
    # value of None, cannot raise or clobber a valid processor setting.
    model_sr = getattr(model.config, "sample_rate", None)
    processor_sr = getattr(processor, "sampling_rate", None)
    if model_sr is not None and processor_sr != model_sr:
        logging.warning(f"Processor SR ({processor_sr}) != Model Config SR ({model_sr}). Updating processor.")
        processor.sampling_rate = model_sr

    # --- Device Selection: CUDA first, then Apple MPS, otherwise CPU ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif platform.system() == "Darwin" and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    logging.info(f"Selected device: {device}")
    model.to(device)
    logging.info(f"Model moved to device: {device}")

    return processor, model, device
99
+
100
# --- Load Model Globally (once per Space instance) ---
# Bind the names up front so that code referencing them after a failed load
# sees None instead of raising NameError.
processor = model = device = None
try:
    processor, model, device = load_model_and_processor(model_id, cache_dir)
except Exception as e:
    # Loading is best-effort at import time: record the failure and let the UI
    # report it (build_ui reads MODEL_LOADED) instead of crashing the app.
    MODEL_LOADED = False
    logging.error(f"Failed to load model/processor: {e}")
else:
    MODEL_LOADED = True
108
+
109
+ # --- Core TTS Functions ---
110
+
111
def run_voice_clone_tts(
    text,
    prompt_speech_path,
    prompt_text,
    processor,
    model,
    device,
):
    """Synthesize ``text`` using a reference recording as the voice prompt.

    Args:
        text: Text to synthesize; empty or whitespace-only text is rejected.
        prompt_speech_path: Path to the reference audio file.
        prompt_text: Optional transcript of the reference audio. Values that
            are empty or shorter than two characters after stripping are
            passed to the processor as ``None``.
        processor: Processor returned by ``load_model_and_processor``.
        model: Model returned by ``load_model_and_processor``.
        device: Device the processor output is moved to before generation.

    Returns:
        ``(output_path, None)`` on success, where ``output_path`` is a
        temporary ``.wav`` file that is not deleted automatically, or
        ``(None, error_message)`` on failure. Exceptions raised during
        inference are logged and reported through ``error_message``.
    """
    # Checking the arguments first also covers callers that pass None after a
    # failed model load.
    if processor is None or model is None or not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."
    if not prompt_speech_path:
        return None, "Error: Please provide a prompt audio file (upload or record)."

    logging.info("Starting voice cloning inference...")
    logging.info(f"Inputs - Text: '{text}', Prompt Audio: {prompt_speech_path}, Prompt Text: '{prompt_text}'")

    try:
        # Treat a missing or very short transcript as "no transcript".
        stripped_prompt = prompt_text.strip() if prompt_text else ""
        prompt_text_clean = stripped_prompt if len(stripped_prompt) >= 2 else None

        # 1. Preprocess using the processor and move tensors to the model device.
        inputs = processor(
            text=text,
            prompt_speech_path=prompt_speech_path,
            prompt_text=prompt_text_clean,
            return_tensors="pt",
        ).to(device)

        # The prompt's global tokens are not a generate() argument; they are
        # kept aside and handed to decode() afterwards.
        global_tokens_prompt = inputs.pop("global_token_ids_prompt", None)
        if global_tokens_prompt is None:
            logging.warning("global_token_ids_prompt not found in processor output. Decoding might be affected.")

        # Fall back to EOS when the tokenizer defines no PAD token.
        tokenizer = processor.tokenizer
        eos_token_id = tokenizer.eos_token_id
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = eos_token_id

        # 2. Generate using the model (sampling; max_new_tokens is a safeguard).
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
            )

        # 3. Decode using the processor; the prompt length is passed along
        # with the generated ids.
        output_clone = processor.decode(
            generated_ids=output_ids,
            global_token_ids_prompt=global_tokens_prompt,
            input_ids_len=inputs["input_ids"].shape[-1],
        )

        # Reserve a temporary path and close the handle before writing:
        # reopening a still-open NamedTemporaryFile by name fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            output_path = tmpfile.name
        sf.write(output_path, output_clone["audio"], output_clone["sampling_rate"])

        logging.info(f"Voice cloning successful. Audio saved temporarily at: {output_path}")
        return output_path, None

    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        logging.error(f"Error during voice cloning inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
181
+
182
+
183
def run_voice_creation_tts(
    text,
    gender,
    pitch_level,  # Expecting 1-5
    speed_level,  # Expecting 1-5
    processor,
    model,
    device,
):
    """Synthesize ``text`` with a voice described by gender, pitch and speed.

    Args:
        text: Text to synthesize; empty or whitespace-only text is rejected.
        gender: Gender value forwarded to the processor unchanged.
        pitch_level: Level looked up in ``LEVELS_MAP_UI``; unknown values fall
            back to ``"moderate"``.
        speed_level: Level looked up in ``LEVELS_MAP_UI``; unknown values fall
            back to ``"moderate"``.
        processor: Processor returned by ``load_model_and_processor``.
        model: Model returned by ``load_model_and_processor``.
        device: Device the processor output is moved to before generation.

    Returns:
        ``(output_path, None)`` on success, where ``output_path`` is a
        temporary ``.wav`` file that is not deleted automatically, or
        ``(None, error_message)`` on failure. Exceptions raised during
        inference are logged and reported through ``error_message``.
    """
    # Checking the arguments first also covers callers that pass None after a
    # failed model load.
    if processor is None or model is None or not MODEL_LOADED:
        return None, "Error: Model not loaded."
    if not text or not text.strip():
        return None, "Error: Please provide text to synthesize."

    # Map numeric levels to their string representations.
    pitch_str = LEVELS_MAP_UI.get(pitch_level, "moderate")
    speed_str = LEVELS_MAP_UI.get(speed_level, "moderate")

    logging.info("Starting voice creation inference...")
    logging.info(f"Inputs - Text: '{text}', Gender: {gender}, Pitch: {pitch_str} (Level {pitch_level}), Speed: {speed_str} (Level {speed_level})")

    try:
        # 1. Preprocess: no audio or text prompt is supplied for voice creation.
        inputs = processor(
            text=text,
            gender=gender,
            pitch=pitch_str,
            speed=speed_str,
            return_tensors="pt",
        ).to(device)

        # Fall back to EOS when the tokenizer defines no PAD token.
        tokenizer = processor.tokenizer
        eos_token_id = tokenizer.eos_token_id
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = eos_token_id

        # 2. Generate (sampling; max_new_tokens is a safeguard).
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=3000,
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
            )

        # 3. Decode; no prompt global tokens are involved here, only the
        # prompt length is passed along with the generated ids.
        output_create = processor.decode(
            generated_ids=output_ids,
            input_ids_len=inputs["input_ids"].shape[-1],
        )

        # Reserve a temporary path and close the handle before writing:
        # reopening a still-open NamedTemporaryFile by name fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            output_path = tmpfile.name
        sf.write(output_path, output_create["audio"], output_create["sampling_rate"])

        logging.info(f"Voice creation successful. Audio saved temporarily at: {output_path}")
        return output_path, None

    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        logging.error(f"Error during voice creation inference: {e}", exc_info=True)
        return None, f"Error during generation: {e}"
247
+
248
+
249
+ # --- Gradio UI ---
250
def build_ui():
    """Build the Gradio Blocks app with "Voice Clone" and "Voice Creation" tabs.

    The callbacks use the module-level ``processor``, ``model`` and ``device``
    and report problems through the status textbox of their tab. When
    ``MODEL_LOADED`` is false a warning is shown and both generate buttons are
    created non-interactive.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for launching it.
    """
    # NOTE(review): the nesting of widgets inside rows/columns was
    # reconstructed from a listing without indentation — confirm the layout.
    with gr.Blocks() as demo:
        gr.HTML('<h1 style="text-align: center;">Spark-TTS Demo (Transformers)</h1>')
        gr.Markdown(
            "Powered by [DragonLineageAI/Vi-Spark-TTS-0.5B-v2](https://huggingface.co/DragonLineageAI/Vi-Spark-TTS-0.5B-v2). "
            "Choose a tab for Voice Cloning or Voice Creation."
        )

        if not MODEL_LOADED:
            gr.Markdown("## ⚠️ Error: Model failed to load. Please check the Space logs.")

        with gr.Tabs():
            # --- Voice Clone Tab ---
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload Reference Audio or Record"
                )
                gr.Markdown(
                    "Provide a short audio clip (5-20 seconds) of the voice you want to clone. "
                    "Optionally, provide the transcript of that audio for better results, especially if the language is the same as the text you want to synthesize."
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources=["upload"],
                        type="filepath",
                        label="Upload Prompt Audio File (WAV/MP3)",
                    )
                    prompt_wav_record = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Or Record Prompt Audio",
                    )

                with gr.Row():
                    text_input_clone = gr.Textbox(
                        label="Text to Synthesize",
                        lines=4,
                        placeholder="Enter text here...",
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of Prompt Speech (Optional)",
                        lines=2,
                        placeholder="Enter the transcript of the prompt audio (if available).",
                        info="Recommended for cloning in the same language.",
                    )

                audio_output_clone = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                # Receives status and error messages from the callback.
                status_clone = gr.Textbox(label="Status", interactive=False)

                generate_button_clone = gr.Button("Generate Cloned Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_clone_callback(text, prompt_text, audio_upload, audio_record):
                    """Return ``(audio_path_or_None, status_message)`` for the clone tab."""
                    # Prioritize the uploaded file, fall back to the recording.
                    prompt_speech = audio_upload if audio_upload else audio_record
                    if not prompt_speech:
                        return None, "Error: Please upload or record a reference audio."

                    # Do not touch the module-level model objects when loading failed.
                    if not MODEL_LOADED:
                        return None, "Error: Model not loaded."

                    output_path, error_msg = run_voice_clone_tts(
                        text,
                        prompt_speech,
                        prompt_text,
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                generate_button_clone.click(
                    voice_clone_callback,
                    inputs=[
                        text_input_clone,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone, status_clone],
                )

                # Each row is [text, prompt transcript, uploaded audio, recorded audio].
                # The audio files are expected in an 'examples' directory of the Space repo.
                clone_examples = [
                    ["Hello, this is a test of voice cloning.", "I am a sample reference voice.", "examples/sample_prompt.wav", None],
                    ["You can experiment with different voices and texts.", None, None, "examples/sample_record.wav"],
                    ["The quality of the clone depends on the reference audio.", "This is the reference text.", "examples/another_prompt.wav", None],
                ]
                # Only offer examples whose audio files are actually present, so
                # that a missing sample cannot break the example gallery.
                available_clone_examples = [
                    row
                    for row in clone_examples
                    if all(Path(audio).is_file() for audio in row[2:] if audio)
                ]
                if available_clone_examples:
                    gr.Examples(
                        examples=available_clone_examples,
                        inputs=[text_input_clone, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                        outputs=[audio_output_clone, status_clone],
                        fn=voice_clone_callback,
                        cache_examples=False,
                        label="Clone Examples",
                    )
                else:
                    logging.warning("No clone example audio files found; skipping clone examples.")

            # --- Voice Creation Tab ---
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create Your Own Voice Based on the Following Parameters"
                )
                gr.Markdown(
                    "Select gender, adjust pitch and speed to generate a new synthetic voice."
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        gender = gr.Radio(
                            choices=["male", "female"], value="female", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch (1=Lowest, 5=Highest)"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed (1=Slowest, 5=Fastest)"
                        )
                    with gr.Column(scale=2):
                        text_input_creation = gr.Textbox(
                            label="Text to Synthesize",
                            lines=5,
                            placeholder="Enter text here...",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )

                audio_output_creation = gr.Audio(
                    label="Generated Audio",
                    autoplay=False,
                )
                # Receives status and error messages from the callback.
                status_create = gr.Textbox(label="Status", interactive=False)

                create_button = gr.Button("Create New Voice", variant="primary", interactive=MODEL_LOADED)

                def voice_creation_callback(text, gender, pitch_val, speed_val):
                    """Return ``(audio_path_or_None, status_message)`` for the creation tab."""
                    # Do not touch the module-level model objects when loading failed.
                    if not MODEL_LOADED:
                        return None, "Error: Model not loaded."

                    output_path, error_msg = run_voice_creation_tts(
                        text,
                        gender,
                        int(pitch_val),  # Slider values are converted to int levels.
                        int(speed_val),
                        processor,
                        model,
                        device,
                    )
                    if error_msg:
                        return None, error_msg
                    return output_path, "Audio generated successfully!"

                create_button.click(
                    voice_creation_callback,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                )

                gr.Examples(
                    examples=[
                        ["This is a female voice with average pitch and speed.", "female", 3, 3],
                        ["This is a male voice, speaking quickly with a slightly higher pitch.", "male", 4, 4],
                        ["A deep and slow female voice.", "female", 1, 2],
                        ["A very high-pitched and fast male voice.", "male", 5, 5],
                    ],
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation, status_create],
                    fn=voice_creation_callback,
                    cache_examples=False,
                    label="Creation Examples",
                )
    return demo
427
+
428
+ # --- Launch the Gradio App ---
429
if __name__ == "__main__":
    # Build the Blocks app and start the Gradio server with default settings.
    demo = build_ui()
    demo.launch()