mehdi999 committed
Commit 9f2e2fc · 1 Parent(s): 2dc4aff

added few things

Files changed (1):
  1. app.py  +127 -168
app.py CHANGED
@@ -1,122 +1,33 @@
  import os
- import time
- import traceback
- import threading
- from concurrent.futures import ThreadPoolExecutor, TimeoutError as FTimeout
-
  import gradio as gr
  import numpy as np
- import soundfile as sf
  import torch
  import spaces

- # ---------- Force safe runtime BEFORE any project imports ----------
  os.environ.setdefault("FLA_CONV_BACKEND", "torch")
  os.environ.setdefault("FLA_USE_FAST_OPS", "0")
- os.environ.setdefault("FLA_DISABLE_TRITON", "1")
- os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
  os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
- os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
- os.environ.setdefault("PYTORCH_NO_CUDA_MEMORY_CACHING", "1")
- os.environ.setdefault("PYTORCH_JIT_DISABLE", "1")
- os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1")
- os.environ.setdefault("NVTX_PROFILE", "0")

  torch.backends.cuda.matmul.allow_tf32 = True
  try:
      torch.set_float32_matmul_precision("high")
  except Exception:
      pass

- from huggingface_hub import login
-
- # Delay project imports until after we install stubs/patches
- def _install_fla_stub_and_instrumentation(LOG):
-     """
-     - Replace SimpleGatedLinearAttention by a safe PyTorch stub
-     - Instrument key constructors to log begin/end
-     """
-     try:
-         import importlib
-
-         # --- FLA stub on SimpleGatedLinearAttention
-         sgm = importlib.import_module("tts.model.simple_gla")
-         import torch.nn as nn
-
-         class SafeSimpleGatedLinearAttention(nn.Module):
-             def __init__(self, *args, **kwargs):
-                 super().__init__()
-                 self.kwargs = dict(kwargs)
-
-             def forward(self, x, past_key_values=None, use_cache: bool = False, **kwargs):
-                 conv_state = None
-                 if use_cache and isinstance(past_key_values, dict):
-                     conv_state = past_key_values.get("conv_state")
-                 return x, conv_state
-
-         sgm.SimpleGatedLinearAttention = SafeSimpleGatedLinearAttention
-         LOG("[patch] SimpleGatedLinearAttention -> Safe stub")
-     except Exception as e:
-         LOG(f"[patch] FLA stub failed: {e}")
-
-     # --- Instrument deeper pieces
-     try:
-         tts_mod = importlib.import_module("tts.tts")
-         _orig_ifc = tts_mod.ARTTSModel.instantiate_from_config
-         def _ifc_verbose(cfg):
-             LOG("[inst] ARTTSModel.instantiate_from_config: begin")
-             o = _orig_ifc(cfg)
-             LOG("[inst] ARTTSModel.instantiate_from_config: end")
-             return o
-         tts_mod.ARTTSModel.instantiate_from_config = staticmethod(_ifc_verbose)  # type: ignore
-         LOG("[patch] ARTTSModel.instantiate_from_config instrumented")
-     except Exception as e:
-         LOG(f"[patch] ARTTSModel patch failed: {e}")
-
-     # Patch constructors that previously appeared in traces
-     try:
-         from codec.models.patchvae.model import PatchVAE
-         _orig_p_init = PatchVAE.__init__
-         def _p_init_verbose(self, *a, **kw):
-             LOG("[inst] PatchVAE.__init__: begin")
-             r = _orig_p_init(self, *a, **kw)
-             LOG("[inst] PatchVAE.__init__: end")
-             return r
-         PatchVAE.__init__ = _p_init_verbose  # type: ignore
-         LOG("[patch] PatchVAE.__init__ instrumented")
-     except Exception as e:
-         LOG(f"[patch] PatchVAE patch failed: {e}")
-
-     try:
-         from codec.models.wavvae.model import WavVAE
-         _orig_w_init = WavVAE.__init__
-         def _w_init_verbose(self, *a, **kw):
-             LOG("[inst] WavVAE.__init__: begin")
-             r = _orig_w_init(self, *a, **kw)
-             LOG("[inst] WavVAE.__init__: end")
-             return r
-         WavVAE.__init__ = _w_init_verbose  # type: ignore
-         LOG("[patch] WavVAE.__init__ instrumented")
-     except Exception as e:
-         LOG(f"[patch] WavVAE patch failed: {e}")
-
-
- def _env_diag() -> str:
-     parts = [f"torch={torch.__version__}"]
-     try:
-         import triton  # type: ignore
-         parts.append(f"triton={getattr(triton, '__version__', 'unknown')}")
-     except Exception:
-         parts.append("triton=not_importable")
-     parts.append(f"cuda.is_available={torch.cuda.is_available()}")
-     if torch.cuda.is_available():
-         parts.append(f"cuda.version={torch.version.cuda}")
-         try:
-             free, total = torch.cuda.mem_get_info()
-             parts.append(f"mem_free={free/1e9:.2f}GB/{total/1e9:.2f}GB")
-         except Exception:
-             pass
-     return " | ".join(parts)


  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
@@ -142,54 +53,71 @@ def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
      return arr.astype(np.float32)


- def _full_thread_dump(LOG, label="stack"):
      try:
-         import faulthandler, io, sys
-         buf = io.StringIO()
-         faulthandler.dump_traceback(file=buf, all_threads=True)
-         LOG(f"[{label}] dump begin")
-         LOG(buf.getvalue()[-2000:])
-         LOG(f"[{label}] dump end")
      except Exception as e:
-         LOG(f"[{label}] dump failed: {e}")
-
-
- def _load_model(LOG):
-     # Apply stub & instrumentation BEFORE imports that build the graph
-     _install_fla_stub_and_instrumentation(LOG)
-
-     # Import model AFTER patches
-     from pardi_speech import PardiSpeech, VelocityHeadSamplingParams as _VHSP  # noqa
-
-     dev = "cuda" if torch.cuda.is_available() else "cpu"
-     LOG(f"[load] PardiSpeech.from_pretrained(repo_id=theodorr/pardi-speech-enfr-forbidden, map_location={dev})…")

-     # Start a watchdog dumper thread for extra detail every 20s
-     stop_evt = threading.Event()
-     def dumper():
-         k = 1
-         while not stop_evt.wait(20.0):
-             _full_thread_dump(LOG, label=f"stack@{20*k}s")
-             k += 1
-     th = threading.Thread(target=dumper, daemon=True)
-     th.start()

-     m = PardiSpeech.from_pretrained("theodorr/pardi-speech-enfr-forbidden", map_location=dev)
      m.eval()
-
-     stop_evt.set()
-     th.join(timeout=1.0)
-
      sr = getattr(m, "sampling_rate", 24000)
-     LOG(f"[load] ready (sr={sr})")
      return m, sr


  @spaces.GPU(duration=200)
  def synthesize(
      text: str,
      debug: bool,
-     adv_sampling: bool,
      ref_audio,
      ref_text: str,
      steps: int,
@@ -200,18 +128,19 @@ def synthesize(
      seed: int,
      lang_hint: str,
  ):
      logs = []
      t0 = time.perf_counter()

      def LOG(msg):
          logs.append(str(msg))
          joined = "\n".join(logs)
-         if len(joined) > 12000:
-             joined = joined[-12000:]
          return joined

      try:
-         HF_TOKEN = os.environ.get("HF_TOKEN")
          if HF_TOKEN:
              try:
                  login(token=HF_TOKEN)
@@ -220,32 +149,40 @@ def synthesize(
                  yield None, LOG(f"⚠️ HF login failed: {e}")

          yield None, LOG("[env] " + _env_diag())
          torch.manual_seed(int(seed))

-         # Load model with watchdog + heartbeats
-         yield None, LOG("[init] loading model…")
-         MAX_WALLTIME_S = 110
          with ThreadPoolExecutor(max_workers=1) as ex:
-             fut = ex.submit(_load_model, LOG)
-             last = time.perf_counter()
              while True:
                  try:
-                     pardi, _sr = fut.result(timeout=2.0)
                      break
                  except FTimeout:
                      now = time.perf_counter()
                      elapsed = now - t0
-                     if now - last >= 2.0:
-                         yield None, LOG(f"[init] still loading… {elapsed:.1f}s")
-                         last = now
                      if elapsed > MAX_WALLTIME_S:
-                         _full_thread_dump(LOG, label="stack@timeout")
                          ex.shutdown(cancel_futures=True)
-                         raise TimeoutError(f"Watchdog: dépassement {elapsed:.1f}s pendant from_pretrained")

          yield None, LOG(f"[init] model ready on {'cuda' if torch.cuda.is_available() else 'cpu'}, sr={_sr}")

-         # ---- Prepare text / prefix ----
          txt = _normalize_text(text, lang_hint=lang_hint)
          yield None, LOG(f"[text] normalized: {txt[:120]}{'…' if len(txt)>120 else ''}")
@@ -265,29 +202,45 @@ def synthesize(
                  import torchaudio
                  if sr != getattr(pardi, "sampling_rate", 24000):
                      wav_t = torchaudio.functional.resample(wav_t, sr, getattr(pardi, "sampling_rate", 24000))
-             except Exception:
-                 LOG("⚠️ torchaudio resample not available")
              wav_t = wav_t.unsqueeze(0)
              with torch.inference_mode():
                  prefix_tokens = pardi.patchvae.encode(wav_t)
              prefix = (ref_text or "", prefix_tokens[0])
              yield None, LOG("[prefix] done.")

-         # ---- Generate ----
          yield None, LOG(f"[run] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}")

          with torch.inference_mode():
              if adv_sampling:
-                 from pardi_speech import VelocityHeadSamplingParams
                  try:
-                     vel_params = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps))
                  except TypeError:
-                     vel_params = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps), temperature=float(temperature))
-                 wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len), velocity_head_sampling_params=vel_params)
              else:
-                 wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len))
-
          wav = wavs[0].detach().cpu().numpy().astype(np.float32)
-         yield (24000, wav), LOG(f"[ok] walltime={time.perf_counter()-t0:.2f}s")

      except Exception as e:
          tb = traceback.format_exc()
@@ -298,11 +251,13 @@ def build_demo():
      with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
          gr.Markdown(
              "## Lina-speech (pardi-speech) – Démo TTS\n"
-             "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence)."
          )

          with gr.Row():
-             text = gr.Textbox(label="Texte à synthétiser", lines=4, value="Bonjour ! Ceci est un test de la démo Lina-speech.", placeholder="Tape ton texte ici…")
              debug = gr.Checkbox(value=False, label="Mode debug (afficher la stacktrace)")
              adv_sampling = gr.Checkbox(value=False, label="Sampling avancé (Velocity Head)")
@@ -323,13 +278,17 @@ def build_demo():

          btn = gr.Button("Synthétiser")
          out_audio = gr.Audio(label="Sortie audio", type="numpy")
-         logs_box = gr.Textbox(label="Logs (live)", lines=20)

          demo.queue(default_concurrency_limit=1, max_size=32)
-         btn.click(fn=synthesize,
-                   inputs=[text, debug, adv_sampling, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
-                   outputs=[out_audio, logs_box],
-                   api_name="synthesize")
      return demo


app.py AFTER (updated file; added lines marked +)

  import os
  import gradio as gr
  import numpy as np
  import torch
+ import soundfile as sf
  import spaces
+ import traceback
+ import time
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError as FTimeout

+ # FLA: forcer les convolutions en backend PyTorch (pas de Triton)
  os.environ.setdefault("FLA_CONV_BACKEND", "torch")
  os.environ.setdefault("FLA_USE_FAST_OPS", "0")
  os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

+ # Meilleure perf FP32 sur GPU compatibles
  torch.backends.cuda.matmul.allow_tf32 = True
  try:
      torch.set_float32_matmul_precision("high")
  except Exception:
      pass

+ from huggingface_hub import login, snapshot_download
+ from pardi_speech import PardiSpeech, VelocityHeadSamplingParams  # présent dans ce repo

+ MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
+ HF_TOKEN = os.environ.get("HF_TOKEN")

+ _pardi = None
+ _sampling_rate = 24000


  def _normalize_text(s: str, lang_hint: str = "fr") -> str:

...
      return arr.astype(np.float32)


+ def _env_diag() -> str:
+     parts = []
      try:
+         parts.append(f"torch={torch.__version__}")
+         try:
+             import triton  # type: ignore
+             parts.append(f"triton={getattr(triton, '__version__', 'unknown')}")
+         except Exception:
+             parts.append("triton=not_importable")
+         parts.append(f"cuda.is_available={torch.cuda.is_available()}")
+         if torch.cuda.is_available():
+             parts.append(f"cuda.version={torch.version.cuda}")
+             try:
+                 free, total = torch.cuda.mem_get_info()
+                 parts.append(f"mem_free={free/1e9:.2f}GB/{total/1e9:.2f}GB")
+             except Exception:
+                 pass
      except Exception as e:
+         parts.append(f"env_diag_error={e}")
+     return " | ".join(parts)


+ def _load_model_cpu_first(log):
+     """
+     Essaye de pré-télécharger puis de charger sur CPU en priorité.
+     Si ça échoue ou dépasse le timeout, on réessaie sur CUDA.
+     """
+     # 1) prefetch repo to local cache (évite les blocages de téléchargement cachés)
+     log("[prefetch] snapshot_download…")
+     local_dir = snapshot_download(
+         repo_id=MODEL_REPO_ID,
+         token=HF_TOKEN,
+         local_dir=None,  # hub cache
+         local_files_only=False,
+         allow_patterns=None,  # tout
+         ignore_patterns=None,
+     )
+     log(f"[prefetch] done -> {local_dir}")
+
+     # 2) CPU load
+     log("[load] from_pretrained(map_location='cpu')…")
+     m = PardiSpeech.from_pretrained(local_dir, map_location='cpu')
      m.eval()
      sr = getattr(m, "sampling_rate", 24000)
+     log(f"[load] cpu OK (sr={sr})")
      return m, sr


+ def _move_to_cuda_if_available(m, log):
+     if torch.cuda.is_available():
+         log("[move] moving model to cuda…")
+         # PardiSpeech expose généralement un .to(device) (via nn.Module)
+         try:
+             m = m.to('cuda')  # type: ignore[attr-defined]
+         except Exception as e:
+             log(f"[move] .to('cuda') failed: {e}. Keeping on CPU.")
+         return m
+     return m
+
+
  @spaces.GPU(duration=200)
  def synthesize(
      text: str,
      debug: bool,
+     adv_sampling: bool,  # toggle "Sampling avancé (Velocity Head)"
      ref_audio,
      ref_text: str,
      steps: int,

...
      seed: int,
      lang_hint: str,
  ):
+     # ---- Generator that streams logs to UI ----
      logs = []
      t0 = time.perf_counter()

      def LOG(msg):
          logs.append(str(msg))
+         # Keep last ~8000 chars
          joined = "\n".join(logs)
+         if len(joined) > 8000:
+             joined = joined[-8000:]
          return joined

      try:
          if HF_TOKEN:
              try:
                  login(token=HF_TOKEN)

...

                  yield None, LOG(f"⚠️ HF login failed: {e}")

          yield None, LOG("[env] " + _env_diag())
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
          torch.manual_seed(int(seed))
+         os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
+
+         # --- CPU-first loader with heartbeats and timeout ---
+         yield None, LOG("[init] prefetch + CPU-first load…")
+         MAX_WALLTIME_S = 110  # UX watchdog

          with ThreadPoolExecutor(max_workers=1) as ex:
+             fut = ex.submit(_load_model_cpu_first, LOG)
+             last_hb = time.perf_counter()
              while True:
                  try:
+                     m, sr = fut.result(timeout=2.0)
+                     pardi = m
+                     _sr = sr
                      break
                  except FTimeout:
                      now = time.perf_counter()
                      elapsed = now - t0
+                     # heartbeat
+                     if now - last_hb >= 2.0:
+                         yield None, LOG(f"[init] still loading on CPU… {elapsed:.1f}s")
+                         last_hb = now
                      if elapsed > MAX_WALLTIME_S:
                          ex.shutdown(cancel_futures=True)
+                         raise TimeoutError(f"Watchdog: dépassement {elapsed:.1f}s pendant le chargement (CPU)")

+         # Move to cuda if possible
+         pardi = _move_to_cuda_if_available(pardi, LOG)
          yield None, LOG(f"[init] model ready on {'cuda' if torch.cuda.is_available() else 'cpu'}, sr={_sr}")

+         # ---- Text & prefix ----
          txt = _normalize_text(text, lang_hint=lang_hint)
          yield None, LOG(f"[text] normalized: {txt[:120]}{'…' if len(txt)>120 else ''}")

...
                  import torchaudio
                  if sr != getattr(pardi, "sampling_rate", 24000):
                      wav_t = torchaudio.functional.resample(wav_t, sr, getattr(pardi, "sampling_rate", 24000))
+             except Exception as _e:
+                 LOG("⚠️ torchaudio not available for resample; using original SR")
              wav_t = wav_t.unsqueeze(0)
              with torch.inference_mode():
                  prefix_tokens = pardi.patchvae.encode(wav_t)
              prefix = (ref_text or "", prefix_tokens[0])
              yield None, LOG("[prefix] done.")

          yield None, LOG(f"[run] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}")
+
+         # ---- FAST PATH by default ----
          with torch.inference_mode():
              if adv_sampling:
+                 yield None, LOG("[run] VelocityHeadSamplingParams enabled…")
                  try:
+                     vel_params = VelocityHeadSamplingParams(
+                         cfg_ref=float(cfg_ref),
+                         cfg=float(cfg),
+                         num_steps=int(steps)
+                     )
                  except TypeError:
+                     vel_params = VelocityHeadSamplingParams(
+                         cfg_ref=float(cfg_ref),
+                         cfg=float(cfg),
+                         num_steps=int(steps),
+                         temperature=float(temperature)
+                     )
+                 wavs, _ = pardi.text_to_speech(
+                     [txt], prefix, max_seq_len=int(max_seq_len),
+                     velocity_head_sampling_params=vel_params
+                 )
              else:
+                 yield None, LOG("[run] fast path (notebook) without VelocityHead…")
+                 wavs, _ = pardi.text_to_speech(
+                     [txt], prefix, max_seq_len=int(max_seq_len)
+                 )
          wav = wavs[0].detach().cpu().numpy().astype(np.float32)
+
+         yield (_sampling_rate, wav), LOG(f"[ok] walltime={time.perf_counter()-t0:.2f}s")

      except Exception as e:
          tb = traceback.format_exc()

...
      with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
          gr.Markdown(
              "## Lina-speech (pardi-speech) – Démo TTS\n"
+             "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\n"
+             "Par défaut, le chemin **rapide** (comme dans le notebook) est utilisé. "
+             "Active **Sampling avancé** pour passer par Velocity Head."
          )

          with gr.Row():
+             text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
              debug = gr.Checkbox(value=False, label="Mode debug (afficher la stacktrace)")
              adv_sampling = gr.Checkbox(value=False, label="Sampling avancé (Velocity Head)")

...

          btn = gr.Button("Synthétiser")
          out_audio = gr.Audio(label="Sortie audio", type="numpy")
+         logs_box = gr.Textbox(label="Logs (live)", lines=18)

          demo.queue(default_concurrency_limit=1, max_size=32)
+
+         # Use generator function: stream logs to UI while running
+         btn.click(
+             fn=synthesize,
+             inputs=[text, debug, adv_sampling, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
+             outputs=[out_audio, logs_box],
+             api_name="synthesize"
+         )
      return demo
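Note: the rewired click handler exposes the demo as a named endpoint (api_name="synthesize"). Below is a minimal, illustrative sketch of driving that endpoint from Python with gradio_client; it is not part of the commit. The Space id is a placeholder, the twelve positional inputs follow the order wired in build_demo(), and the numeric values are examples rather than the app's real slider defaults. Because synthesize is a generator, client.predict() waits for the run to finish and returns the final (audio, logs) pair.

    # Hypothetical client-side call; OWNER/SPACE_NAME is a placeholder.
    from gradio_client import Client

    client = Client("OWNER/SPACE_NAME")

    audio_out, logs = client.predict(
        "Bonjour ! Ceci est un test de la démo Lina-speech.",  # text
        False,   # debug
        False,   # adv_sampling (False -> fast path without Velocity Head)
        None,    # ref_audio (no reference prefix)
        "",      # ref_text
        8,       # steps (illustrative value)
        1.0,     # cfg (illustrative value)
        1.0,     # cfg_ref (illustrative value)
        1.0,     # temperature (illustrative value)
        1000,    # max_seq_len (illustrative value)
        0,       # seed
        "fr",    # lang_hint
        api_name="/synthesize",
    )
    # audio_out is typically a path to the downloaded audio file; logs is the log textbox content.
    print(logs)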