mehdi999 committed on
Commit 6d29905 · 1 Parent(s): 92ec5fe

Space: preload CPU thread + cache + logs

Files changed (1)
  1. app.py +94 -137
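
What this commit changes: the old _cpu_first_loader and the loader-thread/heartbeat logic inside the @spaces.GPU-decorated synthesize callback are replaced by a module-level cache (_MODEL) that a daemon thread fills when app.py is imported, so the Hub prefetch and the CPU-side from_pretrained run outside the GPU worker; synthesize then only waits for the cache and moves the model to CUDA. Below is a minimal, self-contained sketch of that pattern (illustrative only, not the committed code: _CACHE, load_model_cpu and synthesize_stub are hypothetical stand-ins for _MODEL, the PardiSpeech load and the Gradio callback).

    import threading
    import time
    import traceback

    # module-level cache filled by a background thread at import time
    _CACHE = {"model": None, "err": None, "logs": [], "thread": None}

    def _log(msg: str) -> None:
        _CACHE["logs"].append(str(msg))
        # keep the log buffer bounded, as the commit does
        if len(_CACHE["logs"]) > 2000:
            _CACHE["logs"] = _CACHE["logs"][-2000:]

    def load_model_cpu():
        """Hypothetical stand-in for the real CPU-side model load."""
        time.sleep(1.0)  # pretend to download and load weights
        return object()

    def _preload() -> None:
        try:
            _log("[preload] loading model on CPU...")
            _CACHE["model"] = load_model_cpu()
            _log("[preload] done")
        except BaseException as e:  # record the error instead of killing the thread
            _CACHE["err"] = e
            _log(traceback.format_exc())

    # start exactly one preload thread at import time (outside any GPU worker)
    if _CACHE["thread"] is None:
        _CACHE["thread"] = threading.Thread(target=_preload, daemon=True)
        _CACHE["thread"].start()

    def synthesize_stub(timeout_s: float = 180.0):
        """What the GPU callback then does: wait for the cache, then use the model."""
        t0 = time.perf_counter()
        while _CACHE["model"] is None and _CACHE["err"] is None:
            if time.perf_counter() - t0 > timeout_s:
                raise TimeoutError(f"Preload timeout (>{timeout_s:.0f}s)")
            time.sleep(2.0)
        if _CACHE["err"]:
            raise _CACHE["err"]
        return _CACHE["model"]
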
app.py CHANGED
@@ -1,4 +1,3 @@
-
  import os
  import re
  import json
@@ -14,10 +13,8 @@ import torch
  import spaces
  from huggingface_hub import login, snapshot_download

- # -----------------------
- # Environment hardening
- # -----------------------
- os.environ.setdefault("FLA_CONV_BACKEND", "torch")  # avoid Triton kernels
  os.environ.setdefault("FLA_USE_FAST_OPS", "0")
  os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -26,14 +23,20 @@ try:
  except Exception:
      pass

- from pardi_speech import PardiSpeech, VelocityHeadSamplingParams  # present in this repo

  MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
  HF_TOKEN = os.environ.get("HF_TOKEN")

- # -----------------------
- # Helpers
- # -----------------------
  def _env_diag() -> str:
      parts = []
      try:
@@ -55,10 +58,8 @@ def _env_diag() -> str:
          parts.append(f"env_diag_error={e}")
      return " | ".join(parts)

-
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
      s = (s or "").strip()
-     # optional: expand digits for FR/EN using num2words if available
      try:
          import re as _re
          from num2words import num2words
@@ -72,16 +73,13 @@ def _normalize_text(s: str, lang_hint: str = "fr") -> str:
          pass
      return s

-
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
      arr = np.asarray(arr)
      if arr.ndim == 2:
          arr = arr.mean(axis=1)
      return arr.astype(np.float32)

-
  def _extract_repo_ids_from_config(config_path: str):
-     """Return list of 'org/name' strings found in a JSON config (simple heuristic)."""
      repo_ids = set()
      preview = None
      try:
@@ -90,76 +88,72 @@ def _extract_repo_ids_from_config(config_path: str):
          pattern = re.compile(r"^[\w\-]+\/[\w\.\-]+$")  # org/name
          def rec(obj):
              if isinstance(obj, dict):
-                 for v in obj.values():
-                     rec(v)
              elif isinstance(obj, list):
-                 for v in obj:
-                     rec(v)
              elif isinstance(obj, str):
-                 if pattern.match(obj):
-                     repo_ids.add(obj)
          rec(cfg)
-         # create a small preview to help debugging
          try:
              subset_keys = list(cfg)[:5] if isinstance(cfg, dict) else []
-             preview_obj = {k: cfg[k] for k in subset_keys}
-             preview = json.dumps(preview_obj, ensure_ascii=False)[:600]
          except Exception:
-             preview = None
      except Exception:
          pass
      return sorted(repo_ids), preview

-
- def _cpu_first_loader(log_list):
-     """Prefetch main and nested HF repos, then load on CPU in offline mode."""
-     def L(msg):
-         log_list.append(str(msg))
-
-     # 1) Prefetch main repo to local cache
-     L("[prefetch] snapshot_download (main)...")
-     local_dir = snapshot_download(
-         repo_id=MODEL_REPO_ID,
-         token=HF_TOKEN,
-         local_dir=None,
-         local_files_only=False,
-     )
-     L(f"[prefetch] main done -> {local_dir}")
-
-     # 2) Prefetch nested repos found in config.json
-     cfg_path = os.path.join(local_dir, "config.json")
-     nested, cfg_preview = _extract_repo_ids_from_config(cfg_path)
-     if cfg_preview:
-         L(f"[config] preview: {cfg_preview}")
-     for rid in nested:
-         if rid == MODEL_REPO_ID:
-             continue
-         L(f"[prefetch] nested repo: {rid} ...")
-         snapshot_download(repo_id=rid, token=HF_TOKEN, local_dir=None, local_files_only=False)
-         L(f"[prefetch] nested repo: {rid} done")
-
-     # 3) Force offline for actual load to avoid hidden downloads
-     old_off = os.environ.get("HF_HUB_OFFLINE")
-     os.environ["HF_HUB_OFFLINE"] = "1"
      try:
-         L("[load] from_pretrained(map_location='cpu')...")
-         m = PardiSpeech.from_pretrained(local_dir, map_location="cpu")
-         m.eval()
-         sr = getattr(m, "sampling_rate", 24000)
-         L(f"[load] cpu OK (sr={sr})")
-         return m, sr, None
      except BaseException as e:
-         L(f"[EXC@load] {type(e).__name__}: {e}")
-         return None, None, e
-     finally:
-         if old_off is None:
-             os.environ.pop("HF_HUB_OFFLINE", None)
-         else:
-             os.environ["HF_HUB_OFFLINE"] = old_off

- def _move_to_cuda_if_available(m, log_list):
-     def L(msg): log_list.append(str(msg))
      if torch.cuda.is_available():
          L("[move] moving model to cuda...")
          try:
@@ -171,15 +165,12 @@ def _move_to_cuda_if_available(m, log_list):
          L("[move] cuda not available, keep CPU")
      return m

-
- # -----------------------
- # Main synthesize (generator)
- # -----------------------
  @spaces.GPU(duration=200)
  def synthesize(
      text: str,
      debug: bool,
-     adv_sampling: bool,  # toggle Velocity Head sampling
      ref_audio,
      ref_text: str,
      steps: int,
@@ -193,7 +184,7 @@ def synthesize(
      logs = []
      def LOG(msg: str):
          logs.append(str(msg))
-         joined = "\n".join(logs)
          if len(joined) > 12000:
              joined = joined[-12000:]
          return joined
@@ -210,63 +201,33 @@ def synthesize(
          torch.manual_seed(int(seed))
          os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

-         # --- Loader thread + heartbeat + stack dump ---
-         yield None, LOG("[init] nested-prefetch + CPU-first load...")
-
-         loader_logs = []
-         holder = {"model": None, "sr": 24000, "err": None}
-
-         def loader_run():
-             m, sr, err = _cpu_first_loader(loader_logs)
-             holder["model"] = m
-             holder["sr"] = sr if sr is not None else 24000
-             holder["err"] = err
-
-         t = threading.Thread(target=loader_run, daemon=True)
-         t.start()
-         while t.ident is None:
-             time.sleep(0.01)
-         tid = t.ident
-
-         start = time.perf_counter()
-         last_stack = 0.0
-         while t.is_alive():
-             # stream recent loader logs
-             if loader_logs:
-                 yield None, LOG("\n".join(loader_logs[-10:]))
-             # dump the loader thread stack every ~6s
-             now = time.perf_counter()
-             if now - last_stack > 6.0 and tid is not None:
-                 frame = sys._current_frames().get(tid)
-                 if frame is not None:
-                     stack_txt = "".join(traceback.format_stack(frame)[-25:])
-                     yield None, LOG("[stack] loader thread:\n" + stack_txt)
-                 last_stack = now
-             # timeout ~110s
-             if now - start > 110:
                  if tid is not None:
                      frame = sys._current_frames().get(tid)
                      if frame is not None:
                          stack_txt = "".join(traceback.format_stack(frame))
                          yield None, LOG("[stack-final]\n" + stack_txt)
-                 raise TimeoutError("Model load timeout (exceeded 110s)")
              time.sleep(2.0)

-         # After join: flush final logs
-         if loader_logs:
-             yield None, LOG("\n".join(loader_logs[-20:]))

-         if holder["err"]:
-             raise holder["err"]  # will print stack below
-         pardi = holder["model"]
-         if pardi is None:
-             raise RuntimeError("Loader returned no model")

-         # move to cuda if possible
          pardi = _move_to_cuda_if_available(pardi, logs)
-         yield None, LOG(f"[init] model ready on {'cuda' if torch.cuda.is_available() else 'cpu'}, sr={holder['sr']}")

-         # ---- Text & optional prefix ----
          txt = _normalize_text(text or "", lang_hint=lang_hint)
          yield None, LOG(f"[text] {txt[:120]}{'...' if len(txt) > 120 else ''}")

@@ -288,45 +249,42 @@ def synthesize(
              prefix = (ref_text or "", prefix_tokens[0])
              yield None, LOG("[prefix] done.")

-         yield None, LOG(f"[run] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}")

-         # ---- Fast path by default (as notebook) ----
          with torch.inference_mode():
              if adv_sampling:
                  try:
                      vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps))
                  except TypeError:
-                     vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps), temperature=float(temperature))
-                 wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len), velocity_head_sampling_params=vparams)
              else:
                  wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len))

          wav = wavs[0].detach().cpu().numpy().astype(np.float32)
-         sr_out = getattr(pardi, "sampling_rate", 24000)
          yield (sr_out, wav), LOG("[ok] done.")

      except Exception as e:
          tb = traceback.format_exc()
-         yield None, "\n".join(logs + [f"[EXC] {type(e).__name__}: {e}", tb])

-
- # -----------------------
- # UI
- # -----------------------
  def build_demo():
      with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
          gr.Markdown(
              "### Lina-speech (pardi-speech) – Démo TTS\n"
              "Génère de l'audio à partir de texte, avec ou sans prefix (audio de référence).\n"
-             "Par défaut, la voie rapide (comme dans le notebook) est utilisée. Active 'Sampling avancé' pour Velocity Head."
          )
          with gr.Row():
              text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
-
          with gr.Accordion("Prefix (optionnel)", open=False):
              ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
              ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
-
          with gr.Accordion("Options avancées", open=False):
              with gr.Row():
                  steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
@@ -338,7 +296,7 @@ def build_demo():
                  seed = gr.Number(value=0, precision=0, label="Seed")
                  lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")
          with gr.Row():
-             debug = gr.Checkbox(value=False, label="Mode debug (affiche la stack du loader)")
              adv_sampling = gr.Checkbox(value=False, label="Sampling avancé (Velocity Head)")

          btn = gr.Button("Synthétiser")
@@ -354,6 +312,5 @@ def build_demo():
          )
      return demo

-
  if __name__ == "__main__":
      build_demo().launch(ssr_mode=False)

app.py (new version; added lines marked with +)
  import os
  import re
  import json

  import spaces
  from huggingface_hub import login, snapshot_download

+ # --------- Environnement / stabilité ----------
+ os.environ.setdefault("FLA_CONV_BACKEND", "torch")  # éviter les kernels Triton
  os.environ.setdefault("FLA_USE_FAST_OPS", "0")
  os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
  torch.backends.cuda.matmul.allow_tf32 = True

  except Exception:
      pass

+ from pardi_speech import PardiSpeech, VelocityHeadSamplingParams  # présent dans ce repo

  MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
  HF_TOKEN = os.environ.get("HF_TOKEN")

+ # --------- Cache global (préchargement au démarrage) ----------
+ _MODEL = {"pardi": None, "sr": 24000, "err": None, "logs": [], "thread": None}
+
+ def _log(msg: str):
+     _MODEL["logs"].append(str(msg))
+     # borne la taille
+     if len(_MODEL["logs"]) > 2000:
+         _MODEL["logs"] = _MODEL["logs"][-2000:]
+
  def _env_diag() -> str:
      parts = []
      try:

          parts.append(f"env_diag_error={e}")
      return " | ".join(parts)

  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
      s = (s or "").strip()
      try:
          import re as _re
          from num2words import num2words

          pass
      return s

  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
      arr = np.asarray(arr)
      if arr.ndim == 2:
          arr = arr.mean(axis=1)
      return arr.astype(np.float32)

  def _extract_repo_ids_from_config(config_path: str):
      repo_ids = set()
      preview = None
      try:

          pattern = re.compile(r"^[\w\-]+\/[\w\.\-]+$")  # org/name
          def rec(obj):
              if isinstance(obj, dict):
+                 for v in obj.values(): rec(v)
              elif isinstance(obj, list):
+                 for v in obj: rec(v)
              elif isinstance(obj, str):
+                 if pattern.match(obj): repo_ids.add(obj)
          rec(cfg)
          try:
              subset_keys = list(cfg)[:5] if isinstance(cfg, dict) else []
+             preview = json.dumps({k: cfg[k] for k in subset_keys}, ensure_ascii=False)[:600]
          except Exception:
+             pass
      except Exception:
          pass
      return sorted(repo_ids), preview

+ def _prefetch_and_load_cpu():
+     """Exécuté dans un thread au démarrage du Space (hors worker GPU)."""
      try:
+         _log("[prefetch] snapshot_download (main)...")
+         local_dir = snapshot_download(
+             repo_id=MODEL_REPO_ID,
+             token=HF_TOKEN,
+             local_dir=None,
+             local_files_only=False,
+         )
+         _log(f"[prefetch] main done -> {local_dir}")
+
+         cfg_path = os.path.join(local_dir, "config.json")
+         nested, cfg_preview = _extract_repo_ids_from_config(cfg_path)
+         if cfg_preview:
+             _log(f"[config] preview: {cfg_preview}")
+         for rid in nested:
+             if rid == MODEL_REPO_ID:
+                 continue
+             _log(f"[prefetch] nested repo: {rid} ...")
+             snapshot_download(repo_id=rid, token=HF_TOKEN, local_dir=None, local_files_only=False)
+             _log(f"[prefetch] nested repo: {rid} done")
+
+         # Forcer offline pendant le vrai chargement
+         old_off = os.environ.get("HF_HUB_OFFLINE")
+         os.environ["HF_HUB_OFFLINE"] = "1"
+         try:
+             _log("[load] from_pretrained(map_location='cpu')...")
+             m = PardiSpeech.from_pretrained(local_dir, map_location="cpu")
+             m.eval()
+             _MODEL["pardi"] = m
+             _MODEL["sr"] = getattr(m, "sampling_rate", 24000)
+             _log(f"[load] cpu OK (sr={_MODEL['sr']})")
+         finally:
+             if old_off is None:
+                 os.environ.pop("HF_HUB_OFFLINE", None)
+             else:
+                 os.environ["HF_HUB_OFFLINE"] = old_off
+
      except BaseException as e:
+         _MODEL["err"] = e
+         _log(f"[EXC@preload] {type(e).__name__}: {e}")
+         _log(traceback.format_exc())

+ # Lance le préchargement (hors GPU) dès l'import
+ if _MODEL["thread"] is None:
+     _MODEL["thread"] = threading.Thread(target=_prefetch_and_load_cpu, daemon=True)
+     _MODEL["thread"].start()

+ def _move_to_cuda_if_available(m, logs_acc):
+     def L(msg): logs_acc.append(str(msg))
      if torch.cuda.is_available():
          L("[move] moving model to cuda...")
          try:

          L("[move] cuda not available, keep CPU")
      return m

+ # --------- UI callback (GPU) ----------
  @spaces.GPU(duration=200)
  def synthesize(
      text: str,
      debug: bool,
+     adv_sampling: bool,  # Velocity Head sampling
      ref_audio,
      ref_text: str,
      steps: int,

      logs = []
      def LOG(msg: str):
          logs.append(str(msg))
+         joined = "\n".join(logs + _MODEL["logs"][-50:])  # mêle quelques logs de préchargement
          if len(joined) > 12000:
              joined = joined[-12000:]
          return joined

          torch.manual_seed(int(seed))
          os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

+         # Si le modèle n'est pas encore prêt, on attend jusqu'à 180s max ici
+         t0 = time.perf_counter()
+         while _MODEL["pardi"] is None and _MODEL["err"] is None:
+             elapsed = time.perf_counter() - t0
+             yield None, LOG(f"[init] still loading on CPU… {elapsed:.1f}s")
+             if elapsed > 180:
+                 # dump de la stack du thread de préchargement pour debug
+                 tid = _MODEL["thread"].ident if _MODEL["thread"] else None
                  if tid is not None:
                      frame = sys._current_frames().get(tid)
                      if frame is not None:
                          stack_txt = "".join(traceback.format_stack(frame))
                          yield None, LOG("[stack-final]\n" + stack_txt)
+                 raise TimeoutError("Preload timeout (>180s)")
              time.sleep(2.0)

+         if _MODEL["err"]:
+             raise _MODEL["err"]

+         pardi = _MODEL["pardi"]
+         sr_out = _MODEL["sr"]

+         # Déplacement vers CUDA si possible
          pardi = _move_to_cuda_if_available(pardi, logs)
+         yield None, LOG(f"[init] model ready on {'cuda' if torch.cuda.is_available() else 'cpu'}, sr={sr_out}")

+         # ---- Texte + prefix optionnel ----
          txt = _normalize_text(text or "", lang_hint=lang_hint)
          yield None, LOG(f"[text] {txt[:120]}{'...' if len(txt) > 120 else ''}")

              prefix = (ref_text or "", prefix_tokens[0])
              yield None, LOG("[prefix] done.")

+         yield None, LOG(f"[run] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, "
+                         f"T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}")

+         # ---- Chemin rapide (comme le notebook) ----
          with torch.inference_mode():
              if adv_sampling:
                  try:
                      vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps))
                  except TypeError:
+                     vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg),
+                                                          num_steps=int(steps), temperature=float(temperature))
+                 wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len),
+                                                velocity_head_sampling_params=vparams)
              else:
                  wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len))

          wav = wavs[0].detach().cpu().numpy().astype(np.float32)
          yield (sr_out, wav), LOG("[ok] done.")

      except Exception as e:
          tb = traceback.format_exc()
+         yield None, LOG(f"[EXC] {type(e).__name__}: {e}\n{tb}")

+ # --------- UI ----------
  def build_demo():
      with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
          gr.Markdown(
              "### Lina-speech (pardi-speech) – Démo TTS\n"
              "Génère de l'audio à partir de texte, avec ou sans prefix (audio de référence).\n"
+             "Chemin rapide par défaut (comme le notebook)."
          )
          with gr.Row():
              text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
          with gr.Accordion("Prefix (optionnel)", open=False):
              ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
              ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
          with gr.Accordion("Options avancées", open=False):
              with gr.Row():
                  steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")

                  seed = gr.Number(value=0, precision=0, label="Seed")
                  lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")
          with gr.Row():
+             debug = gr.Checkbox(value=False, label="Mode debug")
              adv_sampling = gr.Checkbox(value=False, label="Sampling avancé (Velocity Head)")

          btn = gr.Button("Synthétiser")

          )
      return demo

  if __name__ == "__main__":
      build_demo().launch(ssr_mode=False)
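
The timeout branch in synthesize still dumps the preload thread's stack, which is the main debugging aid when the CPU load hangs. A standalone sketch of that technique, using only the standard library (busy_worker is a hypothetical stand-in for the preload thread): sys._current_frames() maps thread ids to their current frames, and traceback.format_stack() renders the chosen frame.

    import sys
    import threading
    import time
    import traceback

    def busy_worker() -> None:
        # stands in for a long-running model load
        time.sleep(10.0)

    t = threading.Thread(target=busy_worker, daemon=True)
    t.start()
    time.sleep(0.5)  # give the worker time to reach its sleep

    # fetch the worker's current frame by thread id and print its stack
    frame = sys._current_frames().get(t.ident)
    if frame is not None:
        print("".join(traceback.format_stack(frame)))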