#!/usr/bin/env python3
"""Compare WavVAE + ZFlowAE codec stacks and render a static HTML report.

For every stack in a JSON models manifest and every audio file in a text
manifest, this script:

* writes a mono copy of each input under ``<out>/original/``;
* encodes/decodes it through WavVAE -> ZFlowAE -> WavVAE and writes the
  result under ``<out>/recon/<stack>/``;
* writes ``<out>/index.html`` with side-by-side audio players.
"""

import argparse
import html
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import quote

import torch
import torchaudio
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample

from zcodec.models import WavVAE, ZFlowAutoEncoder

# Sample rate assumed when a WavVAE does not expose ``sampling_rate``.
DEFAULT_SAMPLING_RATE = 24000
# Maximum number of characters of each model config shown in the HTML cards.
CONFIG_PREVIEW_CHARS = 1200


# -------------------------
# Data structures
# -------------------------
@dataclass
class DecodeParams:
    """Sampling parameters forwarded to ``ZFlowAutoEncoder.decode``."""

    num_steps: int = 10  # passed as ``num_steps=``
    cfg: float = 2.0  # passed as ``cfg=``


@dataclass
class StackSpec:
    """One WavVAE + ZFlowAE stack to evaluate, as read from the models manifest."""

    name: str  # display name; sanitized, it also names the recon sub-directory
    wavvae_dir: str  # checkpoint dir for WavVAE.from_pretrained_local
    zflowae_dir: str  # checkpoint dir for ZFlowAutoEncoder.from_pretrained_local
    decode: DecodeParams


# -------------------------
# Utilities
# -------------------------
def load_json_if_exists(path: Path) -> Optional[Any]:
    """Return the parsed JSON content of *path*, or ``None``.

    ``None`` is returned when the file is missing, cannot be read, or is not
    valid UTF-8 JSON. This is deliberately best-effort so that one bad config
    file does not abort the whole comparison run.
    """
    if not path.is_file():
        return None
    try:
        with path.open("r", encoding="utf-8") as fh:
            return json.load(fh)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    """Best-effort lookup of a model config inside *checkpoint_dir*.

    Candidates are tried in order: ``config.json``, ``model_config.json``,
    ``config.yaml``. The first JSON candidate that parses to an object is
    returned. YAML is not parsed; only its path is recorded under
    ``"_config_file"``. Returns ``{}`` when nothing usable is found.
    """
    root = Path(checkpoint_dir)
    for filename in ("config.json", "model_config.json", "config.yaml"):
        candidate = root / filename
        if not candidate.exists():
            continue
        if candidate.suffix == ".json":
            parsed = load_json_if_exists(candidate)
            if isinstance(parsed, dict):
                return parsed
        else:
            return {"_config_file": str(candidate)}
    return {}


def sanitize_name(s: str) -> str:
    """Make *s* safe to use as a single path component.

    Every character that is not alphanumeric or one of ``-_.`` becomes ``_``.
    Results that are empty or made only of dots (``""``, ``"."``, ``".."``)
    become ``"_"``, so they cannot resolve to the current or parent directory.
    """
    cleaned = "".join(c if c.isalnum() or c in "-_." else "_" for c in s)
    return cleaned if cleaned.strip(".") else "_"


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    """Downmix a ``(C, T)`` waveform to ``(1, T)`` and resample it to *target_sr*.

    Returns the float32 waveform and its (possibly new) sample rate.

    Raises:
        ValueError: if *wav* is not two-dimensional.
    """
    if wav.ndim != 2:
        raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int) -> None:
    """Write *wav* to *path* as 16-bit PCM, creating parent directories.

    A 1-D tensor is treated as a single channel. Samples are clamped to
    [-1, 1] before writing.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


def read_audio_manifest(txt_path: str) -> List[Path]:
    """Read one audio path per line, skipping blank lines and ``#`` comments."""
    paths: List[Path] = []
    for raw in Path(txt_path).read_text(encoding="utf-8").splitlines():
        entry = raw.strip()
        if entry and not entry.startswith("#"):
            paths.append(Path(entry))
    return paths


def html_escape(s: str) -> str:
    """Escape ``& < > " '`` so *s* is safe in HTML text and quoted attributes."""
    return html.escape(s, quote=True)


def _player(src_rel: str) -> str:
    """Return an ``<audio>`` element for the page-relative path *src_rel*.

    The path is percent-encoded first, so that spaces, ``#`` and ``?`` in file
    names do not break the URL. It is then HTML-escaped for the attribute.
    """
    src = html_escape(quote(src_rel))
    return f'<audio controls preload="none" src="{src}"></audio>'


def _config_preview(cfg: Dict[str, Any], limit: int = CONFIG_PREVIEW_CHARS) -> str:
    """Pretty-print *cfg* as JSON, truncated to *limit* characters."""
    text = json.dumps(cfg if cfg else {"_": "no JSON config found"}, indent=2)
    if len(text) > limit:
        text = text[:limit] + "\n… (truncated)"
    return text


def _model_card(
    spec: StackSpec,
    sample_rate: Any,
    wcfg: Dict[str, Any],
    zcfg: Dict[str, Any],
) -> str:
    """Render the summary card (name, sample rate, decode params, configs) for one stack."""
    return f"""<div class="model-card">
<h3>{html_escape(spec.name)}</h3>
<div>Sample rate: {sample_rate} Hz</div>
<div>Decode: steps={spec.decode.num_steps}, cfg={spec.decode.cfg}</div>
<details><summary>WavVAE config</summary><pre>{html_escape(_config_preview(wcfg))}</pre></details>
<details><summary>ZFlowAE config</summary><pre>{html_escape(_config_preview(zcfg))}</pre></details>
</div>
"""


def make_html(
    output_dir: Path,
    audio_files: List[Path],
    specs: List[StackSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
    zflow_cfg: Dict[str, Dict[str, Any]],
) -> str:
    """Write ``<output_dir>/index.html`` and return its path as a string.

    Each row of the table holds one input. Its original is expected at
    ``original/<file name>``. Its reconstruction for each stack is expected at
    ``recon/<sanitize_name(stack)>/<sanitize_name(stem)>.wav``, the same
    layout that ``main`` writes.
    """
    cards = "".join(
        _model_card(
            s,
            sr_by_model.get(s.name, "N/A"),
            wavvae_cfg.get(s.name, {}),
            zflow_cfg.get(s.name, {}),
        )
        for s in specs
    )

    css = """
body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
.cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
.model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
table { border-collapse: collapse; width: 100%; }
th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
th { background: #fafafa; position: sticky; top: 0; }
audio { width: 260px; }
"""

    header = "<th>Input</th><th>Original</th>" + "".join(
        f"<th>{html_escape(s.name)}</th>" for s in specs
    )

    rows = []
    for af in audio_files:
        base = af.stem
        cells = [
            f"<td>{html_escape(base)}</td>",
            f"<td>{_player('original/' + af.name)}</td>",
        ]
        for s in specs:
            rec_rel = f"recon/{sanitize_name(s.name)}/{sanitize_name(base)}.wav"
            cells.append(f"<td>{_player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(cells) + "</tr>")

    page = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Stacked Codec Comparison</title>
<style>{css}</style>
</head>
<body>
<h1>WavVAE + ZFlowAE Comparison</h1>
<div class="cards">
{cards}
</div>
<table>
<thead><tr>{header}</tr></thead>
<tbody>
{"".join(rows)}
</tbody>
</table>
</body>
</html>
"""
    out = output_dir / "index.html"
    out.write_text(page, encoding="utf-8")
    return str(out)


# -------------------------
# Core
# -------------------------
@torch.inference_mode()
def reconstruct_stack(
    wav_mono: torch.Tensor,
    wavvae: WavVAE,
    zflow: ZFlowAutoEncoder,
    steps: int,
    cfg: float,
    device: str,
) -> torch.Tensor:
    """Encode and decode *wav_mono* through the WavVAE -> ZFlowAE stack.

    The waveform is encoded to WavVAE latents, compressed by the ZFlowAE
    encoder, decoded with ``num_steps=steps`` and ``cfg=cfg``, and finally
    decoded back to audio by WavVAE. The result stays on *device* with the
    two leading singleton dims squeezed away.
    """
    x = wav_mono.to(device)  # assumed (1, T) — TODO confirm against WavVAE.encode
    z = wavvae.encode(x)  # high-framerate latents
    y, _ = zflow.encode(z)  # compressed latents
    z_hat = zflow.decode(y, num_steps=steps, cfg=cfg)
    wav_hat = wavvae.decode(z_hat)  # assumed (1, 1, T)
    return wav_hat.squeeze(0).squeeze(0).detach()


def parse_models_manifest(path: str) -> List[StackSpec]:
    """Parse the stacks manifest at *path*.

    The file is a JSON list of::

        {
          "name": "...",
          "wavvae": "/path/to/WavVAE_dir",
          "zflowae": "/path/to/ZFlowAE_dir",
          "decode": {"num_steps": 10, "cfg": 2.0}
        }

    ``decode`` and each of its keys are optional; ``DecodeParams`` defaults
    fill in whatever is missing.

    Raises:
        ValueError: if the top level is not a list, or if two names map to
            the same output directory after ``sanitize_name``.
        KeyError: if an entry lacks ``name``, ``wavvae`` or ``zflowae``.
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError(f"{path}: expected a JSON list of stack entries")

    defaults = DecodeParams()
    specs: List[StackSpec] = []
    seen: Dict[str, str] = {}
    for it in raw:
        name = it["name"]
        key = sanitize_name(name)
        if key in seen:
            # Both stacks would write into recon/<key>/ and overwrite each other.
            raise ValueError(
                f"{path}: stack name {name!r} collides with {seen[key]!r}"
            )
        seen[key] = name
        d = it.get("decode") or {}  # tolerate a missing or null "decode"
        specs.append(
            StackSpec(
                name=name,
                wavvae_dir=it["wavvae"],
                zflowae_dir=it["zflowae"],
                decode=DecodeParams(
                    num_steps=int(d.get("num_steps", defaults.num_steps)),
                    cfg=float(d.get("cfg", defaults.cfg)),
                ),
            )
        )
    return specs


def _resolve_device(requested: str) -> str:
    """Return *requested*, or ``"cpu"`` when a CUDA device is requested but unavailable."""
    if requested.startswith("cuda") and not torch.cuda.is_available():
        print(
            f"[Warn] {requested} requested but CUDA is unavailable; using cpu",
            file=sys.stderr,
        )
        return "cpu"
    return requested


def _prepare_originals(
    audio_paths: List[Path], originals_dir: Path, force: bool
) -> List[Path]:
    """Write a mono copy of every existing input into *originals_dir*.

    Missing inputs are skipped. So are inputs whose sanitized stem repeats an
    earlier one, because they would overwrite each other on disk. An input is
    only decoded when its copy is absent or *force* is set. Returns the paths
    of the prepared copies, in manifest order.
    """
    prepared: List[Path] = []
    seen: Dict[str, Path] = {}
    for src in audio_paths:
        if not src.is_file():
            print(f"[Skip missing] {src}", file=sys.stderr)
            continue
        stem = sanitize_name(src.stem)
        if stem in seen:
            print(
                f"[Skip duplicate name] {src} (same output name as {seen[stem]})",
                file=sys.stderr,
            )
            continue
        seen[stem] = src
        dst = originals_dir / f"{stem}.wav"
        if force or not dst.exists():
            wav, sr = ta_load(str(src))
            # Downmix only: the original sample rate is preserved.
            wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
            save_wav(dst, wav_mono, sr)
        prepared.append(dst)
    return prepared


def _reconstruct_all(
    originals: List[Path],
    specs: List[StackSpec],
    wavvae_by_name: Dict[str, WavVAE],
    zflow_by_name: Dict[str, ZFlowAutoEncoder],
    sr_by_model: Dict[str, int],
    recon_dir: Path,
    device: str,
    force: bool,
) -> None:
    """Reconstruct every original with every stack.

    Outputs go to ``recon_dir/<stack>/<stem>.wav``. Outputs that already exist
    are skipped unless *force* is set.
    """
    for orig_path in originals:
        out_stem = sanitize_name(orig_path.stem)
        pending = [
            s
            for s in specs
            if force
            or not (recon_dir / sanitize_name(s.name) / f"{out_stem}.wav").exists()
        ]
        if not pending:
            continue  # nothing to do; avoid decoding the file at all
        wav0, sr0 = ta_load(str(orig_path))
        for s in pending:
            target_sr = sr_by_model[s.name]
            wav_in, _ = ensure_mono_and_resample(wav0, sr0, target_sr)
            out_path = recon_dir / sanitize_name(s.name) / f"{out_stem}.wav"
            print(f"[Reconstruct] {s.name} ← {orig_path.name}")
            wav_hat = reconstruct_stack(
                wav_in,
                wavvae_by_name[s.name],
                zflow_by_name[s.name],
                s.decode.num_steps,
                s.decode.cfg,
                device,
            )
            save_wav(out_path, wav_hat, target_sr)


def main() -> None:
    """Command-line entry point: reconstruct all inputs and write the HTML report."""
    parser = argparse.ArgumentParser(
        description="Compare WavVAE+ZFlowAE stacks and generate a static HTML page."
    )
    parser.add_argument("--models", required=True, help="JSON manifest of stacks.")
    parser.add_argument(
        "--audio_manifest", required=True, help="TXT file: one audio path per line."
    )
    parser.add_argument("--out", default="compare_stack_out")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    device = _resolve_device(args.device)

    # Parse both manifests before loading any model, so bad input fails fast.
    specs = parse_models_manifest(args.models)
    if not specs:
        print("No models.", file=sys.stderr)
        sys.exit(1)
    audio_paths = read_audio_manifest(args.audio_manifest)

    out_dir = Path(args.out)
    originals_dir = out_dir / "original"
    recon_dir = out_dir / "recon"
    originals_dir.mkdir(parents=True, exist_ok=True)
    recon_dir.mkdir(parents=True, exist_ok=True)

    # Load models
    wavvae_by_name: Dict[str, WavVAE] = {}
    zflow_by_name: Dict[str, ZFlowAutoEncoder] = {}
    sr_by_model: Dict[str, int] = {}
    wavvae_cfg: Dict[str, Dict[str, Any]] = {}
    zflow_cfg: Dict[str, Dict[str, Any]] = {}
    for s in specs:
        print(f"[Load] {s.name}")
        w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
        z = ZFlowAutoEncoder.from_pretrained_local(s.zflowae_dir).to(device)
        # Inference only: disable dropout and other train-mode behaviour.
        # Assumes both are torch.nn.Module subclasses (they support .to(device)).
        w.eval()
        z.eval()
        wavvae_by_name[s.name] = w
        zflow_by_name[s.name] = z
        sr_by_model[s.name] = int(getattr(w, "sampling_rate", DEFAULT_SAMPLING_RATE))
        wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)
        zflow_cfg[s.name] = read_config_any(s.zflowae_dir)

    actual_audio = _prepare_originals(audio_paths, originals_dir, args.force)
    if not actual_audio:
        print("[Warn] no usable audio files in manifest", file=sys.stderr)

    _reconstruct_all(
        actual_audio,
        specs,
        wavvae_by_name,
        zflow_by_name,
        sr_by_model,
        recon_dir,
        device,
        args.force,
    )

    html_path = make_html(
        out_dir, actual_audio, specs, sr_by_model, wavvae_cfg, zflow_cfg
    )
    print(f"Done. Open: {html_path}")


if __name__ == "__main__":
    main()