Upload model/f5tts_wrapper.py with huggingface_hub
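The commit title says the file was pushed with huggingface_hub. As a point of reference only, a minimal upload along those lines could look like the sketch below; the repo id is a placeholder, not the actual destination repo.

# Sketch of an upload with the huggingface_hub client (repo id is a placeholder).
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="model/f5tts_wrapper.py",  # local file to push
    path_in_repo="model/f5tts_wrapper.py",     # destination path inside the repo
    repo_id="your-username/your-repo",         # placeholder, not the real repo id
    commit_message="Upload model/f5tts_wrapper.py with huggingface_hub",
)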
model/f5tts_wrapper.py
ADDED
@@ -0,0 +1,549 @@
import os
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Optional, Union, List, Tuple, Dict

from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from importlib.resources import files
from pydub import AudioSegment, silence

from f5_tts.model import CFM
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
)
from f5_tts.infer.utils_infer import (
    chunk_text,
    load_vocoder,
    transcribe,
    initialize_asr_pipeline,
)


class F5TTSWrapper:
    """
    A wrapper class for F5-TTS that preprocesses reference audio once
    and allows for repeated TTS generation.
    """

    def __init__(
        self,
        model_name: str = "F5TTS_v1_Base",
        ckpt_path: Optional[str] = None,
        vocab_file: Optional[str] = None,
        vocoder_name: str = "vocos",
        use_local_vocoder: bool = False,
        vocoder_path: Optional[str] = None,
        device: Optional[str] = None,
        hf_cache_dir: Optional[str] = None,
        target_sample_rate: int = 24000,
        n_mel_channels: int = 100,
        hop_length: int = 256,
        win_length: int = 1024,
        n_fft: int = 1024,
        ode_method: str = "euler",
        use_ema: bool = True,
    ):
        """
        Initialize the F5-TTS wrapper with model configuration.

        Args:
            model_name: Name of the F5-TTS model variant (e.g., "F5TTS_v1_Base")
            ckpt_path: Path to the model checkpoint file. If None, will use default path.
            vocab_file: Path to the vocab file. If None, will use default.
            vocoder_name: Name of the vocoder to use ("vocos" or "bigvgan")
            use_local_vocoder: Whether to use a local vocoder or download from HF
            vocoder_path: Path to the local vocoder. Only used if use_local_vocoder is True.
            device: Device to run the model on. If None, will automatically determine.
            hf_cache_dir: Directory to cache HuggingFace models
            target_sample_rate: Target sample rate for audio
            n_mel_channels: Number of mel channels
            hop_length: Hop length for the mel spectrogram
            win_length: Window length for the mel spectrogram
            n_fft: FFT size for the mel spectrogram
            ode_method: ODE method for sampling ("euler" or "midpoint")
            use_ema: Whether to use EMA weights from the checkpoint
        """
        # Set device
        if device is None:
            self.device = (
                "cuda" if torch.cuda.is_available()
                else "xpu" if torch.xpu.is_available()
                else "mps" if torch.backends.mps.is_available()
                else "cpu"
            )
        else:
            self.device = device

        # Audio processing parameters
        self.target_sample_rate = target_sample_rate
        self.n_mel_channels = n_mel_channels
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = n_fft
        self.mel_spec_type = vocoder_name

        # Sampling parameters
        self.ode_method = ode_method

        # Initialize ASR for transcription if needed
        initialize_asr_pipeline(device=self.device)

        # Load model configuration
        if ckpt_path is None:
            repo_name = "F5-TTS"
            ckpt_step = 1250000
            ckpt_type = "safetensors"

            # Adjust for previous models
            if model_name == "F5TTS_Base":
                if vocoder_name == "vocos":
                    ckpt_step = 1200000
                elif vocoder_name == "bigvgan":
                    model_name = "F5TTS_Base_bigvgan"
                    ckpt_type = "pt"
            elif model_name == "E2TTS_Base":
                repo_name = "E2-TTS"
                ckpt_step = 1200000

            ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{model_name}/model_{ckpt_step}.{ckpt_type}"))

        # Load model configuration
        config_path = str(files("f5_tts").joinpath(f"configs/{model_name}.yaml"))
        model_cfg = OmegaConf.load(config_path)
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        # Load tokenizer
        if vocab_file is None:
            vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
        tokenizer_type = "custom"
        self.vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer_type)

        # Create model
        self.model = CFM(
            transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
            mel_spec_kwargs=dict(
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                n_mel_channels=n_mel_channels,
                target_sample_rate=target_sample_rate,
                mel_spec_type=vocoder_name,
            ),
            odeint_kwargs=dict(
                method=ode_method,
            ),
            vocab_char_map=self.vocab_char_map,
        ).to(self.device)

        # Load checkpoint
        dtype = torch.float32 if vocoder_name == "bigvgan" else None
        self._load_checkpoint(self.model, ckpt_path, dtype=dtype, use_ema=use_ema)

        # Load vocoder
        if vocoder_path is None:
            if vocoder_name == "vocos":
                vocoder_path = "../checkpoints/vocos-mel-24khz"
            elif vocoder_name == "bigvgan":
                vocoder_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

        self.vocoder = load_vocoder(
            vocoder_name=vocoder_name,
            is_local=use_local_vocoder,
            local_path=vocoder_path,
            device=self.device,
            hf_cache_dir=hf_cache_dir
        )

        # Storage for reference data
        self.ref_audio_processed = None
        self.ref_text = None
        self.ref_audio_len = None

        # Default inference parameters
        self.target_rms = 0.1
        self.cross_fade_duration = 0.15
        self.nfe_step = 32
        self.cfg_strength = 2.0
        self.sway_sampling_coef = -1.0
        self.speed = 1.0
        self.fix_duration = None

    def _load_checkpoint(self, model, ckpt_path, dtype=None, use_ema=True):
        """
        Load model checkpoint with proper handling of different checkpoint formats.

        Args:
            model: The model to load weights into
            ckpt_path: Path to the checkpoint file
            dtype: Data type for model weights
            use_ema: Whether to use EMA weights from the checkpoint

        Returns:
            Loaded model
        """
        if dtype is None:
            dtype = (
                torch.float16
                if "cuda" in self.device
                and torch.cuda.get_device_properties(self.device).major >= 7
                and not torch.cuda.get_device_name().endswith("[ZLUDA]")
                else torch.float32
            )
        model = model.to(dtype)

        ckpt_type = ckpt_path.split(".")[-1]
        if ckpt_type == "safetensors":
            from safetensors.torch import load_file
            checkpoint = load_file(ckpt_path, device=self.device)
        else:
            checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=True)

        if use_ema:
            if ckpt_type == "safetensors":
                checkpoint = {"ema_model_state_dict": checkpoint}
            checkpoint["model_state_dict"] = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ["initted", "step"]
            }

            # patch for backward compatibility
            for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                if key in checkpoint["model_state_dict"]:
                    del checkpoint["model_state_dict"][key]

            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            if ckpt_type == "safetensors":
                checkpoint = {"model_state_dict": checkpoint}
            model.load_state_dict(checkpoint["model_state_dict"])

        del checkpoint
        torch.cuda.empty_cache()

        return model.to(self.device)

    def preprocess_reference(self, ref_audio_path: str, ref_text: str = "", clip_short: bool = True):
        """
        Preprocess the reference audio and text, storing them for later use.

        Args:
            ref_audio_path: Path to the reference audio file
            ref_text: Text transcript of reference audio. If empty, will auto-transcribe.
            clip_short: Whether to clip long audio to shorter segments

        Returns:
            Tuple of processed audio and text
        """
        print("Converting audio...")
        # Load audio file
        aseg = AudioSegment.from_file(ref_audio_path)

        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
                    print("Audio is over 12s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 12000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
                        print("Audio is over 12s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # 3. if no proper silence found for clipping
            if len(aseg) > 12000:
                aseg = aseg[:12000]
                print("Audio is over 12s, clipping short. (3)")

        # Remove silence edges
        aseg = self._remove_silence_edges(aseg) + AudioSegment.silent(duration=50)

        # Export to temporary file and load as tensor
        import tempfile
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            aseg.export(tmp_file.name, format="wav")
            processed_audio_path = tmp_file.name

        # Transcribe if needed
        if not ref_text.strip():
            print("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(processed_audio_path)
        else:
            print("Using custom reference text...")

        # Ensure ref_text ends with proper punctuation
        if not ref_text.endswith(". ") and not ref_text.endswith("。"):
            if ref_text.endswith("."):
                ref_text += " "
            else:
                ref_text += ". "

        print("\nReference text:", ref_text)

        # Load and process audio
        audio, sr = torchaudio.load(processed_audio_path)
        if audio.shape[0] > 1:  # Convert stereo to mono
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Normalize volume
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < self.target_rms:
            audio = audio * self.target_rms / rms

        # Resample if needed
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)

        # Move to device
        audio = audio.to(self.device)

        # Store reference data
        self.ref_audio_processed = audio
        self.ref_text = ref_text
        self.ref_audio_len = audio.shape[-1] // self.hop_length

        # Remove temporary file
        os.unlink(processed_audio_path)

        return audio, ref_text

    def _remove_silence_edges(self, audio, silence_threshold=-42):
        """
        Remove silence from the start and end of audio.

        Args:
            audio: AudioSegment to process
            silence_threshold: dB threshold to consider as silence

        Returns:
            Processed AudioSegment
        """
        # Remove silence from the start
        non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
        audio = audio[non_silent_start_idx:]

        # Remove silence from the end
        non_silent_end_duration = audio.duration_seconds
        for ms in reversed(audio):
            if ms.dBFS > silence_threshold:
                break
            non_silent_end_duration -= 0.001
        trimmed_audio = audio[: int(non_silent_end_duration * 1000)]

        return trimmed_audio

    def generate(
        self,
        text: str,
        output_path: Optional[str] = None,
        nfe_step: Optional[int] = None,
        cfg_strength: Optional[float] = None,
        sway_sampling_coef: Optional[float] = None,
        speed: Optional[float] = None,
        fix_duration: Optional[float] = None,
        cross_fade_duration: Optional[float] = None,
        return_numpy: bool = False,
        return_spectrogram: bool = False,
    ) -> Union[str, Tuple[np.ndarray, int], Tuple[np.ndarray, int, np.ndarray]]:
        """
        Generate speech for the given text using the stored reference audio.

        Args:
            text: Text to synthesize
            output_path: Path to save the generated audio. If None, won't save.
            nfe_step: Number of function evaluation steps
            cfg_strength: Classifier-free guidance strength
            sway_sampling_coef: Sway sampling coefficient
            speed: Speed of generated audio
            fix_duration: Fixed duration in seconds
            cross_fade_duration: Duration of cross-fade between segments
            return_numpy: If True, returns the audio as a numpy array
            return_spectrogram: If True, also returns the spectrogram

        Returns:
            If output_path provided: path to output file
            If return_numpy=True: tuple of (audio_array, sample_rate)
            If return_spectrogram=True: tuple of (audio_array, sample_rate, spectrogram)
        """
        if self.ref_audio_processed is None or self.ref_text is None:
            raise ValueError("Reference audio not preprocessed. Call preprocess_reference() first.")

        # Use default values if not specified
        nfe_step = nfe_step if nfe_step is not None else self.nfe_step
        cfg_strength = cfg_strength if cfg_strength is not None else self.cfg_strength
        sway_sampling_coef = sway_sampling_coef if sway_sampling_coef is not None else self.sway_sampling_coef
        speed = speed if speed is not None else self.speed
        fix_duration = fix_duration if fix_duration is not None else self.fix_duration
        cross_fade_duration = cross_fade_duration if cross_fade_duration is not None else self.cross_fade_duration

        # Split the input text into batches
        audio_len = self.ref_audio_processed.shape[-1] / self.target_sample_rate
        max_chars = int(len(self.ref_text.encode("utf-8")) / audio_len * (22 - audio_len))
        text_batches = chunk_text(text, max_chars=max_chars)

        for i, text_batch in enumerate(text_batches):
            print(f"Text batch {i}: {text_batch}")
        print("\n")

        # Generate audio for each batch
        generated_waves = []
        spectrograms = []

        for text_batch in text_batches:
            # Adjust speed for very short texts
            local_speed = speed
            if len(text_batch.encode("utf-8")) < 10:
                local_speed = 0.3

            # Prepare the text
            text_list = [self.ref_text + text_batch]
            final_text_list = convert_char_to_pinyin(text_list)

            # Calculate duration
            if fix_duration is not None:
                duration = int(fix_duration * self.target_sample_rate / self.hop_length)
            else:
                # Calculate duration based on text length
                ref_text_len = len(self.ref_text.encode("utf-8"))
                gen_text_len = len(text_batch.encode("utf-8"))
                duration = self.ref_audio_len + int(self.ref_audio_len / ref_text_len * gen_text_len / local_speed)

            # Generate audio
            with torch.inference_mode():
                generated, _ = self.model.sample(
                    cond=self.ref_audio_processed,
                    text=final_text_list,
                    duration=duration,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                )

                # Process the generated mel spectrogram
                generated = generated.to(torch.float32)
                generated = generated[:, self.ref_audio_len:, :]
                generated = generated.permute(0, 2, 1)

                # Convert to audio
                if self.mel_spec_type == "vocos":
                    generated_wave = self.vocoder.decode(generated)
                elif self.mel_spec_type == "bigvgan":
                    generated_wave = self.vocoder(generated)

                # Normalize volume if needed
                rms = torch.sqrt(torch.mean(torch.square(self.ref_audio_processed)))
                if rms < self.target_rms:
                    generated_wave = generated_wave * rms / self.target_rms

                # Convert to numpy and append to list
                generated_wave = generated_wave.squeeze().cpu().numpy()
                generated_waves.append(generated_wave)

                # Store spectrogram if needed
                if return_spectrogram or output_path is not None:
                    spectrograms.append(generated.squeeze().cpu().numpy())

        # Combine all segments
        if generated_waves:
            if cross_fade_duration <= 0:
                # Simply concatenate
                final_wave = np.concatenate(generated_waves)
            else:
                # Cross-fade between segments
                final_wave = generated_waves[0]
                for i in range(1, len(generated_waves)):
                    prev_wave = final_wave
                    next_wave = generated_waves[i]

                    # Calculate cross-fade samples
                    cross_fade_samples = int(cross_fade_duration * self.target_sample_rate)
                    cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

                    if cross_fade_samples <= 0:
                        # No overlap possible, concatenate
                        final_wave = np.concatenate([prev_wave, next_wave])
                        continue

                    # Create cross-fade
                    prev_overlap = prev_wave[-cross_fade_samples:]
                    next_overlap = next_wave[:cross_fade_samples]

                    fade_out = np.linspace(1, 0, cross_fade_samples)
                    fade_in = np.linspace(0, 1, cross_fade_samples)

                    cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

                    final_wave = np.concatenate([
                        prev_wave[:-cross_fade_samples],
                        cross_faded_overlap,
                        next_wave[cross_fade_samples:]
                    ])

            # Combine spectrograms if needed
            if return_spectrogram or output_path is not None:
                combined_spectrogram = np.concatenate(spectrograms, axis=1)

            # Save to file if path provided
            if output_path is not None:
                output_dir = os.path.dirname(output_path)
                if output_dir and not os.path.exists(output_dir):
                    os.makedirs(output_dir)

                # Save audio
                torchaudio.save(output_path,
                                torch.tensor(final_wave).unsqueeze(0),
                                self.target_sample_rate)

                # Save spectrogram if needed
                if return_spectrogram:
                    spectrogram_path = os.path.splitext(output_path)[0] + '_spec.png'
                    self._save_spectrogram(combined_spectrogram, spectrogram_path)

                if not return_numpy:
                    return output_path

            # Return as requested
            if return_spectrogram:
                return final_wave, self.target_sample_rate, combined_spectrogram
            else:
                return final_wave, self.target_sample_rate

        else:
            raise RuntimeError("No audio generated")

    def _save_spectrogram(self, spectrogram, path):
        """Save spectrogram as image"""
        import matplotlib.pyplot as plt
        plt.figure(figsize=(12, 4))
        plt.imshow(spectrogram, origin="lower", aspect="auto")
        plt.colorbar()
        plt.savefig(path)
        plt.close()

    def get_current_audio_length(self):
        """Get the length of the reference audio in seconds"""
        if self.ref_audio_processed is None:
            return 0
        return self.ref_audio_processed.shape[-1] / self.target_sample_rate
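A short usage sketch of F5TTSWrapper as defined above; the audio path, transcript, and output file are placeholders, and the F5-TTS dependencies are assumed to be installed.

# Usage sketch (paths and text are placeholders).
from model.f5tts_wrapper import F5TTSWrapper

tts = F5TTSWrapper(model_name="F5TTS_v1_Base", vocoder_name="vocos")

# Preprocess the reference clip once; it is transcribed automatically if ref_text is empty.
tts.preprocess_reference("ref.wav", ref_text="This is the reference transcript.")

# Generate repeatedly against the same preprocessed reference.
tts.generate("Hello from the wrapped F5-TTS model.", output_path="out.wav")
wave, sr = tts.generate("A second utterance.", return_numpy=True)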