import os
import re
import json
import sys
import time
import threading
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import spaces
from huggingface_hub import login, snapshot_download
# --------- Environment / stability ----------
os.environ.setdefault("FLA_CONV_BACKEND", "torch")  # avoid Triton kernels
os.environ.setdefault("FLA_USE_FAST_OPS", "0")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass
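# Note (assumption): the FLA_* variables above target the flash-linear-attention
# backend used by the model; pinning its conv backend to plain torch and disabling
# fast ops keeps Triton kernel compilation out of the import path.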
from pardi_speech import PardiSpeech, VelocityHeadSamplingParams  # provided by this repo

MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
HF_TOKEN = os.environ.get("HF_TOKEN")

# --------- Global cache (preloaded at startup) ----------
_MODEL = {"pardi": None, "sr": 24000, "err": None, "logs": [], "thread": None}
def _log(msg: str):
    _MODEL["logs"].append(str(msg))
    # cap the log buffer size
    if len(_MODEL["logs"]) > 2000:
        _MODEL["logs"] = _MODEL["logs"][-2000:]
def _env_diag() -> str:
    parts = []
    try:
        parts.append(f"torch={torch.__version__}")
        try:
            import triton  # type: ignore
            parts.append(f"triton={getattr(triton, '__version__', 'unknown')}")
        except Exception:
            parts.append("triton=not_importable")
        parts.append(f"cuda.is_available={torch.cuda.is_available()}")
        if torch.cuda.is_available():
            parts.append(f"cuda.version={torch.version.cuda}")
            try:
                free, total = torch.cuda.mem_get_info()
                parts.append(f"mem_free={free/1e9:.2f}GB/{total/1e9:.2f}GB")
            except Exception:
                pass
    except Exception as e:
        parts.append(f"env_diag_error={e}")
    return " | ".join(parts)
def _normalize_text(s: str, lang_hint: str = "fr") -> str:
    s = (s or "").strip()
    try:
        import re as _re
        from num2words import num2words

        def repl(m):
            try:
                return num2words(int(m.group()), lang=lang_hint)
            except Exception:
                return m.group()

        s = _re.sub(r"\d+", repl, s)
    except Exception:
        pass
    return s
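# Example of the normalization (assuming num2words is installed; otherwise digits
# are passed through unchanged):
#   _normalize_text("J'ai 3 chats", "fr")    -> "J'ai trois chats"
#   _normalize_text("I have 42 cats", "en")  -> "I have forty-two cats"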
def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
    arr = np.asarray(arr)
    if arr.ndim == 2:
        arr = arr.mean(axis=1)
    return arr.astype(np.float32)
def _extract_repo_ids_from_config(config_path: str):
    repo_ids = set()
    preview = None
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        pattern = re.compile(r"^[\w\-]+\/[\w\.\-]+$")  # org/name

        def rec(obj):
            if isinstance(obj, dict):
                for v in obj.values():
                    rec(v)
            elif isinstance(obj, list):
                for v in obj:
                    rec(v)
            elif isinstance(obj, str):
                if pattern.match(obj):
                    repo_ids.add(obj)

        rec(cfg)
        try:
            subset_keys = list(cfg)[:5] if isinstance(cfg, dict) else []
            preview = json.dumps({k: cfg[k] for k in subset_keys}, ensure_ascii=False)[:600]
        except Exception:
            pass
    except Exception:
        pass
    return sorted(repo_ids), preview
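# Example (hypothetical config.json): {"patchvae": {"repo_id": "some-org/patch-vae"}, "sr": 24000}
# -> (["some-org/patch-vae"], '{"patchvae": {"repo_id": "some-org/patch-vae"}, "sr": 24000}')
# Any string that looks like an "org/name" Hub id is collected so nested dependencies
# can be prefetched before the offline load in _prefetch_and_load_cpu.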
def _prefetch_and_load_cpu():
    """Runs in a background thread at Space startup (outside the GPU worker)."""
    try:
        _log("[prefetch] snapshot_download (main)...")
        local_dir = snapshot_download(
            repo_id=MODEL_REPO_ID,
            token=HF_TOKEN,
            local_dir=None,
            local_files_only=False,
        )
        _log(f"[prefetch] main done -> {local_dir}")
        cfg_path = os.path.join(local_dir, "config.json")
        nested, cfg_preview = _extract_repo_ids_from_config(cfg_path)
        if cfg_preview:
            _log(f"[config] preview: {cfg_preview}")
        for rid in nested:
            if rid == MODEL_REPO_ID:
                continue
            _log(f"[prefetch] nested repo: {rid} ...")
            snapshot_download(repo_id=rid, token=HF_TOKEN, local_dir=None, local_files_only=False)
            _log(f"[prefetch] nested repo: {rid} done")
        # Force offline mode during the actual load
        old_off = os.environ.get("HF_HUB_OFFLINE")
        os.environ["HF_HUB_OFFLINE"] = "1"
        try:
            _log("[load] from_pretrained(map_location='cpu')...")
            m = PardiSpeech.from_pretrained(local_dir, map_location="cpu")
            m.eval()
            _MODEL["pardi"] = m
            _MODEL["sr"] = getattr(m, "sampling_rate", 24000)
            _log(f"[load] cpu OK (sr={_MODEL['sr']})")
        finally:
            if old_off is None:
                os.environ.pop("HF_HUB_OFFLINE", None)
            else:
                os.environ["HF_HUB_OFFLINE"] = old_off
    except BaseException as e:
        _MODEL["err"] = e
        _log(f"[EXC@preload] {type(e).__name__}: {e}")
        _log(traceback.format_exc())
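# Rationale: preloading on a background thread at import keeps snapshot_download and
# the CPU-side from_pretrained out of the GPU callback, so the Gradio UI comes up
# immediately and GPU time is only spent on actual synthesis.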
# Start the preload (off the GPU) as soon as the module is imported
if _MODEL["thread"] is None:
    _MODEL["thread"] = threading.Thread(target=_prefetch_and_load_cpu, daemon=True)
    _MODEL["thread"].start()
def _move_to_cuda_if_available(m, logs_acc):
    def L(msg): logs_acc.append(str(msg))
    if torch.cuda.is_available():
        L("[move] moving model to cuda...")
        try:
            m = m.to("cuda")  # type: ignore[attr-defined]
            L("[move] cuda OK")
        except Exception as e:
            L(f"[move] .to('cuda') failed: {e}. Keeping on CPU.")
    else:
        L("[move] cuda not available, keep CPU")
    return m
# --------- UI callback (GPU) ----------
@spaces.GPU  # request a ZeroGPU worker for this callback; without it the Space stays on CPU
def synthesize(
    text: str,
    debug: bool,
    adv_sampling: bool,  # Velocity Head sampling
    ref_audio,
    ref_text: str,
    steps: int,
    cfg: float,
    cfg_ref: float,
    temperature: float,
    max_seq_len: int,
    seed: int,
    lang_hint: str,
):
    logs = []

    def LOG(msg: str):
        logs.append(str(msg))
        joined = "\n".join(logs + _MODEL["logs"][-50:])  # interleave a few preload logs
        if len(joined) > 12000:
            joined = joined[-12000:]
        return joined
    try:
        if HF_TOKEN:
            try:
                login(token=HF_TOKEN)
                yield None, LOG("✅ HF login ok")
            except Exception as e:
                yield None, LOG(f"⚠️ HF login failed: {e}")
        yield None, LOG("[env] " + _env_diag())
        torch.manual_seed(int(seed))
        os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
        # If the model is not ready yet, wait here for up to 180 s
        t0 = time.perf_counter()
        while _MODEL["pardi"] is None and _MODEL["err"] is None:
            elapsed = time.perf_counter() - t0
            yield None, LOG(f"[init] still loading on CPU… {elapsed:.1f}s")
            if elapsed > 180:
                # dump the preload thread's stack for debugging
                tid = _MODEL["thread"].ident if _MODEL["thread"] else None
                if tid is not None:
                    frame = sys._current_frames().get(tid)
                    if frame is not None:
                        stack_txt = "".join(traceback.format_stack(frame))
                        yield None, LOG("[stack-final]\n" + stack_txt)
                raise TimeoutError("Preload timeout (>180s)")
            time.sleep(2.0)
        if _MODEL["err"]:
            raise _MODEL["err"]
        pardi = _MODEL["pardi"]
        sr_out = _MODEL["sr"]
        # Move the model to CUDA if possible
        pardi = _move_to_cuda_if_available(pardi, logs)
        yield None, LOG(f"[init] model ready on {'cuda' if torch.cuda.is_available() else 'cpu'}, sr={sr_out}")
        # ---- Text + optional prefix ----
        txt = _normalize_text(text or "", lang_hint=lang_hint)
        yield None, LOG(f"[text] {txt[:120]}{'...' if len(txt) > 120 else ''}")
        # clamp to safe ranges (tighter than the UI sliders allow)
        steps = int(min(max(1, int(steps)), 16))
        max_seq_len = int(min(max(50, int(max_seq_len)), 600))
        prefix = None
        if ref_audio is not None:
            yield None, LOG("[prefix] encoding reference audio...")
            if isinstance(ref_audio, str):
                wav, sr = sf.read(ref_audio)
            else:
                sr, wav = ref_audio
            # the reference is assumed to already be at the model sampling rate; sr is not used
            wav = _to_mono_float32(wav)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            wav_t = torch.from_numpy(wav).to(device).unsqueeze(0)
            with torch.inference_mode():
                prefix_tokens = pardi.patchvae.encode(wav_t)  # type: ignore[attr-defined]
            prefix = (ref_text or "", prefix_tokens[0])
            yield None, LOG("[prefix] done.")
        yield None, LOG(f"[run] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, "
                        f"T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}")
        # ---- Fast path (same as the notebook) ----
        with torch.inference_mode():
            if adv_sampling:
                try:
                    vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps))
                except TypeError:
                    # retry with temperature if the installed VelocityHeadSamplingParams requires it
                    vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg),
                                                         num_steps=int(steps), temperature=float(temperature))
                wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len),
                                               velocity_head_sampling_params=vparams)
            else:
                wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len))
        wav = wavs[0].detach().cpu().numpy().astype(np.float32)
        yield (sr_out, wav), LOG("[ok] done.")
    except Exception as e:
        tb = traceback.format_exc()
        yield None, LOG(f"[EXC] {type(e).__name__}: {e}\n{tb}")
# --------- UI ----------
def build_demo():
    with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
        gr.Markdown(
            "### Lina-speech (pardi-speech) – TTS demo\n"
            "Generates audio from text, with or without a prefix (reference audio).\n"
            "Fast path by default (same as the notebook)."
        )
        with gr.Row():
            text = gr.Textbox(label="Text to synthesize", lines=4, placeholder="Type your text here…")
        with gr.Accordion("Prefix (optional)", open=False):
            ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Reference audio")
            ref_text = gr.Textbox(label="Prefix text (if known)", placeholder="Transcription of the prefix (optional)")
        with gr.Accordion("Advanced options", open=False):
            with gr.Row():
                steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
                cfg = gr.Slider(0.5, 3.0, value=1.4, step=0.05, label="CFG (guidance)")
                cfg_ref = gr.Slider(0.5, 3.0, value=1.0, step=0.05, label="CFG (ref.)")
            with gr.Row():
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
                max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (audio tokens)")
                seed = gr.Number(value=0, precision=0, label="Seed")
                lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Language (normalization)")
            with gr.Row():
                debug = gr.Checkbox(value=False, label="Debug mode")
                adv_sampling = gr.Checkbox(value=False, label="Advanced sampling (Velocity Head)")
        btn = gr.Button("Synthesize")
        out_audio = gr.Audio(label="Output audio", type="numpy")
        logs_box = gr.Textbox(label="Logs (live)", lines=28)
        demo.queue(default_concurrency_limit=1, max_size=32)
        btn.click(
            fn=synthesize,
            inputs=[text, debug, adv_sampling, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
            outputs=[out_audio, logs_box],
            api_name="synthesize",
        )
    return demo
if __name__ == "__main__":
    build_demo().launch(ssr_mode=False)
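# --------- Example client call (sketch) ----------
# Not executed here. Assuming the Space is deployed, the "synthesize" endpoint
# registered above can be driven with gradio_client; positional arguments follow
# the `inputs` list of btn.click. "user/space-name" is a placeholder Space id,
# and gradio_client returns the output audio as a temporary file path.
#
#   from gradio_client import Client
#
#   client = Client("user/space-name")
#   audio_path, logs = client.predict(
#       "Bonjour, ceci est un test.",  # text
#       False,                         # debug
#       False,                         # adv_sampling (Velocity Head)
#       None,                          # ref_audio (None = no prefix)
#       "",                            # ref_text
#       10,                            # steps
#       1.4,                           # cfg
#       1.0,                           # cfg_ref
#       1.0,                           # temperature
#       300,                           # max_seq_len
#       0,                             # seed
#       "fr",                          # lang_hint
#       api_name="/synthesize",
#   )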