Text-to-Speech · Vietnamese · female · male · voice-cloning
erax committed · verified
Commit a2c3ffd · 1 parent: 0f7c4ef

Upload model/f5tts_wrapper.py with huggingface_hub

Files changed (1)
  1. model/f5tts_wrapper.py +549 -0
model/f5tts_wrapper.py ADDED
@@ -0,0 +1,549 @@
import os
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Optional, Union, List, Tuple, Dict

from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from importlib.resources import files
from pydub import AudioSegment, silence

from f5_tts.model import CFM
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
)
from f5_tts.infer.utils_infer import (
    chunk_text,
    load_vocoder,
    transcribe,
    initialize_asr_pipeline,
)


class F5TTSWrapper:
    """
    A wrapper class for F5-TTS that preprocesses reference audio once
    and allows for repeated TTS generation.
    """

    def __init__(
        self,
        model_name: str = "F5TTS_v1_Base",
        ckpt_path: Optional[str] = None,
        vocab_file: Optional[str] = None,
        vocoder_name: str = "vocos",
        use_local_vocoder: bool = False,
        vocoder_path: Optional[str] = None,
        device: Optional[str] = None,
        hf_cache_dir: Optional[str] = None,
        target_sample_rate: int = 24000,
        n_mel_channels: int = 100,
        hop_length: int = 256,
        win_length: int = 1024,
        n_fft: int = 1024,
        ode_method: str = "euler",
        use_ema: bool = True,
    ):
        """
        Initialize the F5-TTS wrapper with model configuration.

        Args:
            model_name: Name of the F5-TTS model variant (e.g., "F5TTS_v1_Base")
            ckpt_path: Path to the model checkpoint file. If None, will use default path.
            vocab_file: Path to the vocab file. If None, will use default.
            vocoder_name: Name of the vocoder to use ("vocos" or "bigvgan")
            use_local_vocoder: Whether to use a local vocoder or download from HF
            vocoder_path: Path to the local vocoder. Only used if use_local_vocoder is True.
            device: Device to run the model on. If None, will automatically determine.
            hf_cache_dir: Directory to cache HuggingFace models
            target_sample_rate: Target sample rate for audio
            n_mel_channels: Number of mel channels
            hop_length: Hop length for the mel spectrogram
            win_length: Window length for the mel spectrogram
            n_fft: FFT size for the mel spectrogram
            ode_method: ODE method for sampling ("euler" or "midpoint")
            use_ema: Whether to use EMA weights from the checkpoint
        """
        # Set device
        if device is None:
            self.device = (
                "cuda" if torch.cuda.is_available()
                else "xpu" if torch.xpu.is_available()
                else "mps" if torch.backends.mps.is_available()
                else "cpu"
            )
        else:
            self.device = device

        # Audio processing parameters
        self.target_sample_rate = target_sample_rate
        self.n_mel_channels = n_mel_channels
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = n_fft
        self.mel_spec_type = vocoder_name

        # Sampling parameters
        self.ode_method = ode_method

        # Initialize ASR for transcription if needed
        initialize_asr_pipeline(device=self.device)

        # Load model configuration
        if ckpt_path is None:
            repo_name = "F5-TTS"
            ckpt_step = 1250000
            ckpt_type = "safetensors"

            # Adjust for previous models
            if model_name == "F5TTS_Base":
                if vocoder_name == "vocos":
                    ckpt_step = 1200000
                elif vocoder_name == "bigvgan":
                    model_name = "F5TTS_Base_bigvgan"
                    ckpt_type = "pt"
            elif model_name == "E2TTS_Base":
                repo_name = "E2-TTS"
                ckpt_step = 1200000

            ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{model_name}/model_{ckpt_step}.{ckpt_type}"))

        # Load model configuration
        config_path = str(files("f5_tts").joinpath(f"configs/{model_name}.yaml"))
        model_cfg = OmegaConf.load(config_path)
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        # Load tokenizer
        if vocab_file is None:
            vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
        tokenizer_type = "custom"
        self.vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer_type)

        # Create model
        self.model = CFM(
            transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
            mel_spec_kwargs=dict(
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                n_mel_channels=n_mel_channels,
                target_sample_rate=target_sample_rate,
                mel_spec_type=vocoder_name,
            ),
            odeint_kwargs=dict(
                method=ode_method,
            ),
            vocab_char_map=self.vocab_char_map,
        ).to(self.device)

        # Load checkpoint
        dtype = torch.float32 if vocoder_name == "bigvgan" else None
        self._load_checkpoint(self.model, ckpt_path, dtype=dtype, use_ema=use_ema)

        # Load vocoder
        if vocoder_path is None:
            if vocoder_name == "vocos":
                vocoder_path = "../checkpoints/vocos-mel-24khz"
            elif vocoder_name == "bigvgan":
                vocoder_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

        self.vocoder = load_vocoder(
            vocoder_name=vocoder_name,
            is_local=use_local_vocoder,
            local_path=vocoder_path,
            device=self.device,
            hf_cache_dir=hf_cache_dir
        )

        # Storage for reference data
        self.ref_audio_processed = None
        self.ref_text = None
        self.ref_audio_len = None

        # Default inference parameters
        self.target_rms = 0.1
        self.cross_fade_duration = 0.15
        self.nfe_step = 32
        self.cfg_strength = 2.0
        self.sway_sampling_coef = -1.0
        self.speed = 1.0
        self.fix_duration = None

    def _load_checkpoint(self, model, ckpt_path, dtype=None, use_ema=True):
        """
        Load model checkpoint with proper handling of different checkpoint formats.

        Args:
            model: The model to load weights into
            ckpt_path: Path to the checkpoint file
            dtype: Data type for model weights
            use_ema: Whether to use EMA weights from the checkpoint

        Returns:
            Loaded model
        """
        if dtype is None:
            dtype = (
                torch.float16
                if "cuda" in self.device
                and torch.cuda.get_device_properties(self.device).major >= 7
                and not torch.cuda.get_device_name().endswith("[ZLUDA]")
                else torch.float32
            )
        model = model.to(dtype)

        ckpt_type = ckpt_path.split(".")[-1]
        if ckpt_type == "safetensors":
            from safetensors.torch import load_file
            checkpoint = load_file(ckpt_path, device=self.device)
        else:
            checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=True)

        if use_ema:
            if ckpt_type == "safetensors":
                checkpoint = {"ema_model_state_dict": checkpoint}
            checkpoint["model_state_dict"] = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ["initted", "step"]
            }

            # patch for backward compatibility
            for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                if key in checkpoint["model_state_dict"]:
                    del checkpoint["model_state_dict"][key]

            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            if ckpt_type == "safetensors":
                checkpoint = {"model_state_dict": checkpoint}
            model.load_state_dict(checkpoint["model_state_dict"])

        del checkpoint
        torch.cuda.empty_cache()

        return model.to(self.device)

    def preprocess_reference(self, ref_audio_path: str, ref_text: str = "", clip_short: bool = True):
        """
        Preprocess the reference audio and text, storing them for later use.

        Args:
            ref_audio_path: Path to the reference audio file
            ref_text: Text transcript of reference audio. If empty, will auto-transcribe.
            clip_short: Whether to clip long audio to shorter segments

        Returns:
            Tuple of processed audio and text
        """
        print("Converting audio...")
        # Load audio file
        aseg = AudioSegment.from_file(ref_audio_path)

        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
                    print("Audio is over 12s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 12000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
                        print("Audio is over 12s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # 3. if no proper silence found for clipping
            if len(aseg) > 12000:
                aseg = aseg[:12000]
                print("Audio is over 12s, clipping short. (3)")

        # Remove silence edges
        aseg = self._remove_silence_edges(aseg) + AudioSegment.silent(duration=50)

        # Export to temporary file and load as tensor
        import tempfile
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            aseg.export(tmp_file.name, format="wav")
            processed_audio_path = tmp_file.name

        # Transcribe if needed
        if not ref_text.strip():
            print("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(processed_audio_path)
        else:
            print("Using custom reference text...")

        # Ensure ref_text ends with proper punctuation
        if not ref_text.endswith(". ") and not ref_text.endswith("。"):
            if ref_text.endswith("."):
                ref_text += " "
            else:
                ref_text += ". "

        print("\nReference text:", ref_text)

        # Load and process audio
        audio, sr = torchaudio.load(processed_audio_path)
        if audio.shape[0] > 1:  # Convert stereo to mono
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Normalize volume
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < self.target_rms:
            audio = audio * self.target_rms / rms

        # Resample if needed
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)

        # Move to device
        audio = audio.to(self.device)

        # Store reference data
        self.ref_audio_processed = audio
        self.ref_text = ref_text
        self.ref_audio_len = audio.shape[-1] // self.hop_length

        # Remove temporary file
        os.unlink(processed_audio_path)

        return audio, ref_text

    def _remove_silence_edges(self, audio, silence_threshold=-42):
        """
        Remove silence from the start and end of audio.

        Args:
            audio: AudioSegment to process
            silence_threshold: dB threshold to consider as silence

        Returns:
            Processed AudioSegment
        """
        # Remove silence from the start
        non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
        audio = audio[non_silent_start_idx:]

        # Remove silence from the end
        non_silent_end_duration = audio.duration_seconds
        for ms in reversed(audio):
            if ms.dBFS > silence_threshold:
                break
            non_silent_end_duration -= 0.001
        trimmed_audio = audio[: int(non_silent_end_duration * 1000)]

        return trimmed_audio

    def generate(
        self,
        text: str,
        output_path: Optional[str] = None,
        nfe_step: Optional[int] = None,
        cfg_strength: Optional[float] = None,
        sway_sampling_coef: Optional[float] = None,
        speed: Optional[float] = None,
        fix_duration: Optional[float] = None,
        cross_fade_duration: Optional[float] = None,
        return_numpy: bool = False,
        return_spectrogram: bool = False,
    ) -> Union[str, Tuple[np.ndarray, int], Tuple[np.ndarray, int, np.ndarray]]:
        """
        Generate speech for the given text using the stored reference audio.

        Args:
            text: Text to synthesize
            output_path: Path to save the generated audio. If None, won't save.
            nfe_step: Number of function evaluation steps
            cfg_strength: Classifier-free guidance strength
            sway_sampling_coef: Sway sampling coefficient
            speed: Speed of generated audio
            fix_duration: Fixed duration in seconds
            cross_fade_duration: Duration of cross-fade between segments
            return_numpy: If True, returns the audio as a numpy array
            return_spectrogram: If True, also returns the spectrogram

        Returns:
            If output_path provided: path to output file
            If return_numpy=True: tuple of (audio_array, sample_rate)
            If return_spectrogram=True: tuple of (audio_array, sample_rate, spectrogram)
        """
        if self.ref_audio_processed is None or self.ref_text is None:
            raise ValueError("Reference audio not preprocessed. Call preprocess_reference() first.")

        # Use default values if not specified
        nfe_step = nfe_step if nfe_step is not None else self.nfe_step
        cfg_strength = cfg_strength if cfg_strength is not None else self.cfg_strength
        sway_sampling_coef = sway_sampling_coef if sway_sampling_coef is not None else self.sway_sampling_coef
        speed = speed if speed is not None else self.speed
        fix_duration = fix_duration if fix_duration is not None else self.fix_duration
        cross_fade_duration = cross_fade_duration if cross_fade_duration is not None else self.cross_fade_duration

        # Split the input text into batches
        audio_len = self.ref_audio_processed.shape[-1] / self.target_sample_rate
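        # Character budget per chunk: the reference speech rate (text bytes per second)
        # times the seconds left out of a ~22 s combined reference-plus-generation window.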
        max_chars = int(len(self.ref_text.encode("utf-8")) / audio_len * (22 - audio_len))
        text_batches = chunk_text(text, max_chars=max_chars)

        for i, text_batch in enumerate(text_batches):
            print(f"Text batch {i}: {text_batch}")
        print("\n")

        # Generate audio for each batch
        generated_waves = []
        spectrograms = []

        for text_batch in text_batches:
            # Adjust speed for very short texts
            local_speed = speed
            if len(text_batch.encode("utf-8")) < 10:
                local_speed = 0.3

            # Prepare the text
            text_list = [self.ref_text + text_batch]
            final_text_list = convert_char_to_pinyin(text_list)

            # Calculate duration
            if fix_duration is not None:
                duration = int(fix_duration * self.target_sample_rate / self.hop_length)
            else:
                # Calculate duration based on text length
                ref_text_len = len(self.ref_text.encode("utf-8"))
                gen_text_len = len(text_batch.encode("utf-8"))
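                # Total output length in mel frames: reference frames plus the reference
                # length scaled by the generated/reference text-byte ratio, adjusted by local_speed.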
                duration = self.ref_audio_len + int(self.ref_audio_len / ref_text_len * gen_text_len / local_speed)

            # Generate audio
            with torch.inference_mode():
                generated, _ = self.model.sample(
                    cond=self.ref_audio_processed,
                    text=final_text_list,
                    duration=duration,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                )

                # Process the generated mel spectrogram
                generated = generated.to(torch.float32)
                generated = generated[:, self.ref_audio_len:, :]
                generated = generated.permute(0, 2, 1)

                # Convert to audio
                if self.mel_spec_type == "vocos":
                    generated_wave = self.vocoder.decode(generated)
                elif self.mel_spec_type == "bigvgan":
                    generated_wave = self.vocoder(generated)

                # Normalize volume if needed
                rms = torch.sqrt(torch.mean(torch.square(self.ref_audio_processed)))
                if rms < self.target_rms:
                    generated_wave = generated_wave * rms / self.target_rms

                # Convert to numpy and append to list
                generated_wave = generated_wave.squeeze().cpu().numpy()
                generated_waves.append(generated_wave)

                # Store spectrogram if needed
                if return_spectrogram or output_path is not None:
                    spectrograms.append(generated.squeeze().cpu().numpy())

        # Combine all segments
        if generated_waves:
            if cross_fade_duration <= 0:
                # Simply concatenate
                final_wave = np.concatenate(generated_waves)
            else:
                # Cross-fade between segments
                final_wave = generated_waves[0]
                for i in range(1, len(generated_waves)):
                    prev_wave = final_wave
                    next_wave = generated_waves[i]

                    # Calculate cross-fade samples
                    cross_fade_samples = int(cross_fade_duration * self.target_sample_rate)
                    cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

                    if cross_fade_samples <= 0:
                        # No overlap possible, concatenate
                        final_wave = np.concatenate([prev_wave, next_wave])
                        continue

                    # Create cross-fade
                    prev_overlap = prev_wave[-cross_fade_samples:]
                    next_overlap = next_wave[:cross_fade_samples]

                    fade_out = np.linspace(1, 0, cross_fade_samples)
                    fade_in = np.linspace(0, 1, cross_fade_samples)

                    cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

                    final_wave = np.concatenate([
                        prev_wave[:-cross_fade_samples],
                        cross_faded_overlap,
                        next_wave[cross_fade_samples:]
                    ])

            # Combine spectrograms if needed
            if return_spectrogram or output_path is not None:
                combined_spectrogram = np.concatenate(spectrograms, axis=1)

            # Save to file if path provided
            if output_path is not None:
                output_dir = os.path.dirname(output_path)
                if output_dir and not os.path.exists(output_dir):
                    os.makedirs(output_dir)

                # Save audio
                torchaudio.save(output_path,
                                torch.tensor(final_wave).unsqueeze(0),
                                self.target_sample_rate)

                # Save spectrogram if needed
                if return_spectrogram:
                    spectrogram_path = os.path.splitext(output_path)[0] + '_spec.png'
                    self._save_spectrogram(combined_spectrogram, spectrogram_path)

                if not return_numpy:
                    return output_path

            # Return as requested
            if return_spectrogram:
                return final_wave, self.target_sample_rate, combined_spectrogram
            else:
                return final_wave, self.target_sample_rate

        else:
            raise RuntimeError("No audio generated")

    def _save_spectrogram(self, spectrogram, path):
        """Save spectrogram as image"""
        import matplotlib.pyplot as plt
        plt.figure(figsize=(12, 4))
        plt.imshow(spectrogram, origin="lower", aspect="auto")
        plt.colorbar()
        plt.savefig(path)
        plt.close()

    def get_current_audio_length(self):
        """Get the length of the reference audio in seconds"""
        if self.ref_audio_processed is None:
            return 0
        return self.ref_audio_processed.shape[-1] / self.target_sample_rate
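
For reference, a minimal usage sketch of the wrapper defined above; the import path, checkpoint, vocab, reference audio, and output paths are placeholders, not files shipped with this commit.

    # Assumes model/f5tts_wrapper.py is importable as a module named f5tts_wrapper.
    from f5tts_wrapper import F5TTSWrapper

    # Build the wrapper (placeholder paths; vocoder weights are fetched if not local).
    tts = F5TTSWrapper(
        model_name="F5TTS_v1_Base",
        ckpt_path="path/to/model.safetensors",  # your fine-tuned checkpoint
        vocab_file="path/to/vocab.txt",
        vocoder_name="vocos",
    )

    # Preprocess the reference voice once; with ref_text="" the audio is auto-transcribed.
    tts.preprocess_reference(ref_audio_path="path/to/reference.wav", ref_text="")

    # Generate repeatedly against the same reference: save to a file...
    tts.generate("Xin chào, đây là giọng nói được tổng hợp.", output_path="output/hello.wav")

    # ...or get the waveform and sample rate back as numpy data.
    wave, sr = tts.generate("Một câu khác.", return_numpy=True)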