#!/usr/bin/env python3
"""Compare Z-Codec model pairs and generate a static HTML listening page.

For every input audio file, the script stores a mono 16-bit WAV copy under
``<out>/original/``, runs each (WavVAE, ZFlowAutoEncoder) pair listed in the
manifest over it, stores the reconstructions under
``<out>/recon/<model name>/``, and finally writes ``<out>/index.html`` with a
table of audio players (one row per file, one column per model).
"""
import argparse
import html
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from urllib.parse import quote

import torch
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample
import torchaudio

# Your libs
from zcodec.models import WavVAE, ZFlowAutoEncoder


# -------------------------
# Data structures
# -------------------------


@dataclass
class DecodeParams:
    """Keyword arguments forwarded to ``ZFlowAutoEncoder.decode``."""

    # Passed as ``num_steps=`` to ``zflowae.decode``.
    num_steps: int = 10
    # Passed as ``cfg=`` to ``zflowae.decode``
    # (presumably a guidance scale -- confirm against zcodec).
    cfg: float = 2.0


@dataclass
class ModelPairSpec:
    """One manifest entry: a named WavVAE + ZFlowAutoEncoder checkpoint pair."""

    name: str
    wavvae_dir: str
    zflowae_dir: str
    decode: DecodeParams


# -------------------------
# Utilities
# -------------------------


def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
    """Return the parsed JSON content of *path*, or ``None``.

    ``None`` is returned when *path* is not a regular file, cannot be read,
    or does not contain valid UTF-8 JSON. This is deliberately best-effort:
    a broken config must not abort the comparison run.
    """
    if not path.is_file():
        return None
    try:
        with path.open("r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    """Read a config file from a checkpoint directory for display purposes.

    Candidates are tried in order: ``config.json``, ``config.yaml``,
    ``model_config.json``. A JSON candidate that parses is returned as-is.
    A JSON candidate that fails to parse is skipped. A non-JSON candidate is
    not parsed; ``{"_config_file": <path>}`` is returned instead so the page
    can at least show where the config lives.

    Returns:
        The parsed config, the placeholder dict, or ``{}`` if nothing usable
        was found.
    """
    root = Path(checkpoint_dir)
    candidates = [
        root / "config.json",
        root / "config.yaml",  # not parsed here, only its path is displayed
        root / "model_config.json",
    ]
    for candidate in candidates:
        if not candidate.exists():
            continue
        if candidate.suffix != ".json":
            # For YAML or unknown, just show filename rather than failing
            return {"_config_file": str(candidate)}
        parsed = load_json_if_exists(candidate)
        if parsed is not None:
            return parsed
    return {}


def sanitize_name(s: str) -> str:
    """Replace every character that is not alphanumeric or in ``-_.`` by ``_``."""
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    """Down-mix a ``(channels, samples)`` waveform to mono and resample it.

    Args:
        wav: 2D waveform tensor laid out as ``(C, T)``.
        sr: Sample rate of *wav*.
        target_sr: Desired sample rate.

    Returns:
        ``(mono_wav, sample_rate)`` where ``mono_wav`` is float32 with shape
        ``(1, T')`` and ``sample_rate`` equals *target_sr* if resampling
        happened, otherwise *sr* (the two are equal in that case).

    Raises:
        ValueError: If *wav* is not 2-dimensional.
    """
    if wav.ndim != 2:
        raise ValueError(f"Expected 2D waveform (C, T), got shape {tuple(wav.shape)}")
    # to mono
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    # resample if needed
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    """Write *wav* to *path* as 16-bit signed PCM, creating parent directories.

    A 1D tensor is treated as a single channel. Samples are clamped to
    ``[-1, 1]`` before being handed to ``torchaudio.save``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # torchaudio.save is given a (C, T) tensor
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    # Clamp to [-1,1]
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


# -------------------------
# Core inference
# -------------------------


@torch.inference_mode()
def reconstruct_full_pipeline(
    wav_mono: torch.Tensor,
    sr: int,
    wavvae: WavVAE,
    zflowae: ZFlowAutoEncoder,
    decode_params: DecodeParams,
    device: str,
) -> torch.Tensor:
    """Run audio through the full encode/decode chain of one model pair.

    Full path:
        audio -> WavVAE.encode -> ZFlowAE.encode -> ZFlowAE.decode
              -> WavVAE.decode -> audio_hat

    Args:
        wav_mono: Mono waveform; a batch dimension is prepended before
            encoding (assumed ``(1, T)`` so that the model sees ``(1, 1, T)``).
        sr: Sample rate of *wav_mono*. Currently unused; kept for interface
            stability.
        wavvae: Waveform VAE providing ``encode`` / ``decode``.
        zflowae: Latent autoencoder providing ``encode`` / ``decode``.
        decode_params: ``num_steps`` and ``cfg`` forwarded to
            ``zflowae.decode``.
        device: Device the input is moved to before inference.

    Returns:
        The reconstruction with its two leading dimensions squeezed away.
    """
    wav_mono = wav_mono.to(device)
    # WavVAE expects (B, C, T); assume C=1
    x = wav_mono.unsqueeze(0)  # (1, 1, T)
    # Encode to high-framerate latents
    z = wavvae.encode(x)
    # Compress latents
    y = zflowae.encode(z)
    # Decompress
    z_hat = zflowae.decode(y, num_steps=decode_params.num_steps, cfg=decode_params.cfg)
    # Decode to waveform
    wav_hat = wavvae.decode(z_hat)  # (1, 1, T)
    # Return mono 1D
    return wav_hat.squeeze(0).squeeze(0).detach()


def load_model_pair(spec: ModelPairSpec, device: str):
    """Load both checkpoints of *spec* onto *device*.

    Returns:
        ``(wavvae, zflowae, target_sr)``. ``target_sr`` is the WavVAE's
        ``sampling_rate`` attribute when present and not ``None``, otherwise
        the fallback 24000.
    """
    wavvae = WavVAE.from_pretrained_local(spec.wavvae_dir).to(device)
    zflowae = ZFlowAutoEncoder.from_pretrained_local(spec.zflowae_dir).to(device)
    # try to get sampling rate from WavVAE
    target_sr = getattr(wavvae, "sampling_rate", None)
    if target_sr is None:
        # reasonable fallback
        target_sr = 24000
    return wavvae, zflowae, int(target_sr)


def parse_manifest(path: str) -> List[ModelPairSpec]:
    """Parse the JSON manifest listing the model pairs to compare.

    Manifest format (JSON list):
    [
      {
        "name": "zdim32x8",
        "wavvae": "/path/to/WavVAE_framerate100_zdim32/",
        "zflowae": "/path/to/ZFlowAutoEncoder_stride4_zdim32_vae8_.../",
        "decode": {"num_steps": 10, "cfg": 2.0}
      }
    ]

    ``decode`` is optional; missing keys default to ``num_steps=10`` and
    ``cfg=2.0``.

    Raises:
        ValueError: If the top-level JSON value is not a list.
        KeyError: If an entry lacks ``name``, ``wavvae`` or ``zflowae``.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    if not isinstance(raw, list):
        raise ValueError(
            f"Manifest {path!r} must contain a JSON list, got {type(raw).__name__}"
        )
    out: List[ModelPairSpec] = []
    for item in raw:
        d = item.get("decode", {})
        out.append(
            ModelPairSpec(
                name=item["name"],
                wavvae_dir=item["wavvae"],
                zflowae_dir=item["zflowae"],
                decode=DecodeParams(
                    num_steps=int(d.get("num_steps", 10)),
                    cfg=float(d.get("cfg", 2.0)),
                ),
            )
        )
    return out


# -------------------------
# HTML generation
# -------------------------


def html_escape(s: str) -> str:
    """Escape ``&``, ``<``, ``>``, ``"`` and ``'`` for HTML text and attributes."""
    return html.escape(s, quote=True)


def make_html(
    output_dir: Path,
    audio_files: List[Path],
    models: List[ModelPairSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
    zflow_cfg: Dict[str, Dict[str, Any]],
) -> str:
    """Build the static comparison page and write it to ``index.html``.

    The page has a table:
        Row = input audio file
        Col 1 = file name, Col 2 = Original
        Col 3..N = each model reconstruction
    and shows minimal model config info (truncated to 1200 characters per
    config) above the table.

    Audio sources are relative to *output_dir*: ``original/<file name>`` and
    ``recon/<model name>/<sanitized stem>.wav``. The reconstruction link uses
    ``sanitize_name`` on the stem because that is how ``main`` names the
    reconstruction files on disk.

    Returns:
        The path of the written ``index.html`` as a string.
    """

    def player(src_rel: str, controls: bool = True) -> str:
        # URL-quote the path first, then HTML-escape it for the attribute.
        src = html_escape(quote(src_rel))
        controls_attr = " controls" if controls else ""
        return f'<audio{controls_attr} preload="none" src="{src}"></audio>'

    def short_json(cfg: Dict[str, Any]) -> str:
        # Keep the cards compact: pretty-print and cut at 1200 characters.
        shown = cfg if cfg else {"_": "no JSON config found"}
        return json.dumps(shown, indent=2)[:1200]

    # Model cards
    model_cards = []
    for spec in models:
        w_short = short_json(wavvae_cfg.get(spec.name, {}))
        z_short = short_json(zflow_cfg.get(spec.name, {}))
        sample_rate = sr_by_model.get(spec.name, "N/A")
        card = f"""
<div class="model-card">
  <h3>{html_escape(spec.name)}</h3>
  <p><b>Sample rate:</b> {sample_rate} Hz</p>
  <details>
    <summary>WavVAE config</summary>
    <pre>{html_escape(w_short)}</pre>
  </details>
  <details>
    <summary>ZFlowAE config</summary>
    <pre>{html_escape(z_short)}</pre>
  </details>
  <p><b>Decode:</b> num_steps={spec.decode.num_steps}, cfg={spec.decode.cfg}</p>
</div>
"""
        model_cards.append(card)

    # Table header
    th = "<th>Input</th><th>Original</th>" + "".join(
        f"<th>{html_escape(m.name)}</th>" for m in models
    )

    # Rows
    rows = []
    for af in audio_files:
        base = af.stem
        tds = [
            f"<td>{html_escape(base)}</td>",
            f"<td>{player(f'original/{af.name}')}</td>",
        ]
        for m in models:
            rec_rel = f"recon/{m.name}/{sanitize_name(base)}.wav"
            tds.append(f"<td>{player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(tds) + "</tr>")

    # Simple CSS to keep it clean
    css = """
    body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
    h1 { margin-bottom: 0.2rem; }
    .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
    .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
    th { background: #fafafa; position: sticky; top: 0; }
    audio { width: 260px; }
    """

    cards_html = "".join(model_cards)
    rows_html = "".join(rows)
    page = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<title>Codec Comparison</title>
<style>{css}</style>
</head>
<body>
<h1>Codec Comparison</h1>
<p>This page compares reconstructions across model checkpoints. Click play in each cell.</p>
<h2>Models</h2>
<div class="cards">
{cards_html}
</div>
<h2>Audio</h2>
<table>
<thead><tr>{th}</tr></thead>
<tbody>
{rows_html}
</tbody>
</table>
</body>
</html>
"""
    out = output_dir / "index.html"
    out.write_text(page, encoding="utf-8")
    return str(out)


# -------------------------
# Main
# -------------------------


def _resolve_device(requested: str) -> str:
    """Return *requested*, falling back to ``"cpu"`` when CUDA is unavailable."""
    if requested.startswith("cuda") and not torch.cuda.is_available():
        print(
            f"[Warn] CUDA requested ({requested}) but not available; using cpu.",
            file=sys.stderr,
        )
        return "cpu"
    return requested


def main():
    """Command-line entry point: reconstruct all inputs and write the page."""
    p = argparse.ArgumentParser(
        description="Compare Z-Codec configurations and generate a static HTML page."
    )
    p.add_argument(
        "--manifest",
        type=str,
        required=True,
        help="JSON file listing model pairs. See docstring in parse_manifest().",
    )
    p.add_argument(
        "--audio", type=str, nargs="+", required=True, help="List of input audio files."
    )
    p.add_argument(
        "--out",
        type=str,
        default="codec_compare_out",
        help="Output directory for reconstructions and HTML.",
    )
    p.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run inference on (cuda or cpu).",
    )
    p.add_argument(
        "--force",
        action="store_true",
        help="Recompute even if target wav already exists.",
    )
    args = p.parse_args()

    device = _resolve_device(args.device)
    out_dir = Path(args.out)
    orig_dir = out_dir / "original"
    recon_dir = out_dir / "recon"
    orig_dir.mkdir(parents=True, exist_ok=True)
    recon_dir.mkdir(parents=True, exist_ok=True)

    # Parse models
    specs = parse_manifest(args.manifest)
    if not specs:
        print("No models in manifest.", file=sys.stderr)
        sys.exit(1)

    # Load models
    loaded: Dict[str, Dict[str, Any]] = {}
    sr_by_model: Dict[str, int] = {}
    wavvae_cfg: Dict[str, Dict[str, Any]] = {}
    zflow_cfg: Dict[str, Dict[str, Any]] = {}
    for spec in specs:
        print(f"[Load] {spec.name}")
        wavvae, zflowae, target_sr = load_model_pair(spec, device)
        loaded[spec.name] = {"wavvae": wavvae, "zflowae": zflowae, "sr": target_sr}
        sr_by_model[spec.name] = target_sr
        wavvae_cfg[spec.name] = read_config_any(spec.wavvae_dir)
        zflow_cfg[spec.name] = read_config_any(spec.zflowae_dir)

    # Process audio files
    for af in (Path(a) for a in args.audio):
        if not af.exists():
            print(f"[Skip] Missing: {af}", file=sys.stderr)
            continue

        # The original is always stored as a mono 16-bit WAV at its native
        # sample rate so that browsers can play it. The cache check must use
        # the .wav name, otherwise non-wav inputs are re-encoded on every run.
        out_orig = orig_dir / f"{af.stem}.wav"
        if args.force or not out_orig.exists():
            wav, sr = ta_load(str(af))
            # make it mono for fair listening
            wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
            save_wav(out_orig, wav_mono, sr)

        base = af.stem
        # Re-load from disk to ensure a consistent start-point for all models
        # (the stored original .wav in the output folder).
        wav0, sr0 = ta_load(str(out_orig))
        # Make mono only once; resample per-model to each target SR
        if wav0.size(0) > 1:
            wav0 = wav0.mean(dim=0, keepdim=True)

        for spec in specs:
            mname = spec.name
            target_sr = sr_by_model[mname]
            out_path = recon_dir / mname / f"{sanitize_name(base)}.wav"
            if not args.force and out_path.exists():
                continue
            # resample to model's SR
            if sr0 != target_sr:
                wav_mono = ta_resample(wav0, sr0, target_sr)
            else:
                wav_mono = wav0
            print(f"[Reconstruct] {mname} ← {base}")
            wav_hat = reconstruct_full_pipeline(
                wav_mono,
                target_sr,
                loaded[mname]["wavvae"],
                loaded[mname]["zflowae"],
                spec.decode,
                device,
            )
            save_wav(out_path, wav_hat.unsqueeze(0), target_sr)

    # Build HTML
    # Rebuild the list of files actually present in original/ (use .wav names)
    actual_audio = sorted(orig_dir.glob("*.wav"))
    html_path = make_html(
        out_dir,
        actual_audio,
        specs,
        sr_by_model,
        wavvae_cfg,
        zflow_cfg,
    )
    print(f"\nDone. Open: {html_path}")


if __name__ == "__main__":
    main()