#!/usr/bin/env python3
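"""Compare stacked WavVAE + ZFlowAutoEncoder codecs on a list of audio files and
write a static HTML page with side-by-side original / reconstruction players.

Illustrative invocation (the script filename and input file names here are
placeholders, not part of the original repo):

    python compare_stacks.py --models models.json --audio_manifest audio_files.txt --out compare_stack_out --device cuda

--models points at a JSON manifest (schema documented in parse_models_manifest
below); --audio_manifest is plain text with one audio path per line, where
blank lines and lines starting with "#" are ignored (see read_audio_manifest).
"""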
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torchaudio
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample

from zcodec.models import WavVAE, ZFlowAutoEncoder


# -------------------------
# Data structures
# -------------------------
@dataclass
class DecodeParams:
    num_steps: int = 10
    cfg: float = 2.0


@dataclass
class StackSpec:
    name: str
    wavvae_dir: str
    zflowae_dir: str
    decode: DecodeParams


# -------------------------
# Utilities (same helpers)
# -------------------------
def load_json_if_exists(path: Path):
    if path.is_file():
        try:
            with path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            return None
    return None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    cand = [
        Path(checkpoint_dir) / "config.json",
        Path(checkpoint_dir) / "model_config.json",
        Path(checkpoint_dir) / "config.yaml",
    ]
    for p in cand:
        if p.exists():
            if p.suffix == ".json":
                j = load_json_if_exists(p)
                if j is not None:
                    return j
            else:
                return {"_config_file": str(p)}
    return {}


def sanitize_name(s: str) -> str:
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    if wav.ndim != 2:
        raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    path.parent.mkdir(parents=True, exist_ok=True)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


def read_audio_manifest(txt_path: str) -> List[Path]:
    lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
    return [
        Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#")
    ]


def html_escape(s: str) -> str:
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )


def make_html(
    output_dir: Path,
    audio_files: List[Path],
    specs: List[StackSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
    zflow_cfg: Dict[str, Dict[str, Any]],
) -> str:
    def player(src_rel: str) -> str:
        return f'<audio controls preload="none" src="{html_escape(src_rel)}"></audio>'

    cards = []
    for s in specs:
        wcfg = wavvae_cfg.get(s.name, {})
        zcfg = zflow_cfg.get(s.name, {})
        w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        card = f"""
<div class="model-card">
<h3>{html_escape(s.name)}</h3>
<p><b>Sample rate</b>: {sr_by_model.get(s.name, "N/A")} Hz</p>
<p><b>Decode</b>: steps={s.decode.num_steps}, cfg={s.decode.cfg}</p>
<details><summary>WavVAE config</summary><pre>{html_escape(w_short)}</pre></details>
<details><summary>ZFlowAE config</summary><pre>{html_escape(z_short)}</pre></details>
</div>
"""
        cards.append(card)

    css = """
body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
.cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
.model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
table { border-collapse: collapse; width: 100%; }
th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
th { background: #fafafa; position: sticky; top: 0; }
audio { width: 260px; }
"""

    th = "<th>Input</th><th>Original</th>" + "".join(
        f"<th>{html_escape(s.name)}</th>" for s in specs
    )
    rows = []
    for af in audio_files:
        base = af.stem
        # player() escapes the src attribute, so build raw relative paths here
        # (avoids double-escaping) and mirror the sanitize_name() applied when
        # reconstructions are written to disk.
        orig_rel = f"original/{af.name}"
        tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
        for s in specs:
            rec_rel = f"recon/{s.name}/{sanitize_name(base)}.wav"
            tds.append(f"<td>{player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(tds) + "</tr>")
| html = f"""<!doctype html> | |
| <html> | |
| <head><meta charset="utf-8"/><title>Stacked Codec Comparison</title><style>{css}</style></head> | |
| <body> | |
| <h1>WavVAE + ZFlowAE Comparison</h1> | |
| <div class="cards">{"".join(cards)}</div> | |
| <table> | |
| <thead><tr>{th}</tr></thead> | |
| <tbody>{"".join(rows)}</tbody> | |
| </table> | |
| </body> | |
| </html> | |
| """ | |
| out = output_dir / "index.html" | |
| out.write_text(html, encoding="utf-8") | |
| return str(out) | |


# -------------------------
# Core
# -------------------------
def reconstruct_stack(
    wav_mono: torch.Tensor,
    wavvae: WavVAE,
    zflow: ZFlowAutoEncoder,
    steps: int,
    cfg: float,
    device: str,
) -> torch.Tensor:
    x = wav_mono.to(device)  # (1, T)
    # Inference only: skip autograd bookkeeping for the whole round trip.
    with torch.inference_mode():
        z = wavvae.encode(x)  # high-framerate latents
        y, _ = zflow.encode(z)  # compressed latents
        z_hat = zflow.decode(y, num_steps=steps, cfg=cfg)
        wav_hat = wavvae.decode(z_hat)  # (1, 1, T)
    return wav_hat.squeeze(0).squeeze(0).detach()


def parse_models_manifest(path: str) -> List[StackSpec]:
    """
    JSON list of:
      {
        "name": "...",
        "wavvae": "/path/to/WavVAE_dir",
        "zflowae": "/path/to/ZFlowAE_dir",
        "decode": {"num_steps": 10, "cfg": 2.0}
      }
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    specs = []
    for it in raw:
        d = it.get("decode", {})
        specs.append(
            StackSpec(
                name=it["name"],
                wavvae_dir=it["wavvae"],
                zflowae_dir=it["zflowae"],
                decode=DecodeParams(
                    num_steps=int(d.get("num_steps", 10)), cfg=float(d.get("cfg", 2.0))
                ),
            )
        )
    return specs
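

# Illustrative --models manifest matching parse_models_manifest() above.
# The stack names and checkpoint paths are placeholders, not real models:
#
# [
#   {
#     "name": "stack_a",
#     "wavvae": "/checkpoints/wavvae_a",
#     "zflowae": "/checkpoints/zflowae_a",
#     "decode": {"num_steps": 10, "cfg": 2.0}
#   },
#   {
#     "name": "stack_b",
#     "wavvae": "/checkpoints/wavvae_b",
#     "zflowae": "/checkpoints/zflowae_b",
#     "decode": {"num_steps": 20, "cfg": 1.5}
#   }
# ]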


def main():
    ap = argparse.ArgumentParser(
        description="Compare WavVAE+ZFlowAE stacks and generate a static HTML page."
    )
    ap.add_argument("--models", required=True, help="JSON manifest of stacks.")
    ap.add_argument(
        "--audio_manifest", required=True, help="TXT file: one audio path per line."
    )
    ap.add_argument("--out", default="compare_stack_out")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--force", action="store_true")
    args = ap.parse_args()
| device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu" | |
| out_dir = Path(args.out) | |
| (out_dir / "original").mkdir(parents=True, exist_ok=True) | |
| recon_dir = out_dir / "recon" | |
| recon_dir.mkdir(parents=True, exist_ok=True) | |
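    # On-disk layout produced below and referenced by index.html:
    #   <out>/original/<stem>.wav        mono originals re-saved as 16-bit PCM
    #   <out>/recon/<model>/<stem>.wav   one reconstruction per stack
    #   <out>/index.html                 the comparison page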
    specs = parse_models_manifest(args.models)
    if not specs:
        print("No models.", file=sys.stderr)
        sys.exit(1)

    # load models
    wavvae_by_name: Dict[str, WavVAE] = {}
    zflow_by_name: Dict[str, ZFlowAutoEncoder] = {}
    sr_by_model: Dict[str, int] = {}
    wavvae_cfg: Dict[str, Dict[str, Any]] = {}
    zflow_cfg: Dict[str, Dict[str, Any]] = {}
    for s in specs:
        print(f"[Load] {s.name}")
        w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
        z = ZFlowAutoEncoder.from_pretrained_local(s.zflowae_dir).to(device)
        wavvae_by_name[s.name] = w
        zflow_by_name[s.name] = z
        sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))
        wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)
        zflow_cfg[s.name] = read_config_any(s.zflowae_dir)
    audio_paths = read_audio_manifest(args.audio_manifest)
    actual_audio = []
    for in_path in audio_paths:  # not `ap`, which would shadow the ArgumentParser above
        if not in_path.exists():
            print(f"[Skip missing] {in_path}", file=sys.stderr)
            continue
        wav, sr = ta_load(str(in_path))
        # Downmix to mono but keep the native sample rate for the saved original.
        wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
        out_orig = out_dir / "original" / (in_path.stem + ".wav")
        if args.force or not out_orig.exists():
            save_wav(out_orig, wav_mono, sr)
        actual_audio.append(out_orig)
    for out_orig in actual_audio:
        wav0, sr0 = ta_load(str(out_orig))
        if wav0.size(0) > 1:
            wav0 = wav0.mean(dim=0, keepdim=True)
        for s in specs:
            target_sr = sr_by_model[s.name]
            wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0
            out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav"
            if args.force or not out_path.exists():
                print(f"[Reconstruct] {s.name} ← {out_orig.name}")
                wav_hat = reconstruct_stack(
                    wav_in,
                    wavvae_by_name[s.name],
                    zflow_by_name[s.name],
                    s.decode.num_steps,
                    s.decode.cfg,
                    device,
                )
                save_wav(out_path, wav_hat, target_sr)
    html_path = make_html(
        out_dir, actual_audio, specs, sr_by_model, wavvae_cfg, zflow_cfg
    )
    print(f"Done. Open: {html_path}")


if __name__ == "__main__":
    main()