#!/usr/bin/env python3
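"""
Compare Z-Codec checkpoint configurations by reconstructing a set of audio files
through the full WavVAE -> ZFlowAutoEncoder -> WavVAE pipeline, then generating a
static HTML page that lines up the originals against each model's reconstruction.

Illustrative invocation (the script filename and all paths are placeholders):

    python compare_codecs.py \
        --manifest manifest.json \
        --audio sample1.wav sample2.flac \
        --out codec_compare_out \
        --device cuda

See parse_manifest() for the expected manifest format.
"""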
import argparse
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple

import torch
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample
import torchaudio

# Your libs
from zcodec.models import WavVAE, ZFlowAutoEncoder

# -------------------------
# Data structures
# -------------------------
@dataclass
class DecodeParams:
    num_steps: int = 10
    cfg: float = 2.0


@dataclass
class ModelPairSpec:
    name: str
    wavvae_dir: str
    zflowae_dir: str
    decode: DecodeParams

# -------------------------
# Utilities
# -------------------------
def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
    if path.is_file():
        try:
            with path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            return None
    return None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    """
    Try to read config.json (or a few common fallbacks) from a checkpoint dir.
    Returns {} if nothing could be parsed.
    """
    cand = [
        Path(checkpoint_dir) / "config.json",
        Path(checkpoint_dir) / "config.yaml",  # won't parse YAML here, we only display the path
        Path(checkpoint_dir) / "model_config.json",
    ]
    for p in cand:
        if p.exists():
            if p.suffix == ".json":
                j = load_json_if_exists(p)
                if j is not None:
                    return j
            else:
                # For YAML or unknown formats, just record the filename rather than failing.
                return {"_config_file": str(p)}
    return {}


def sanitize_name(s: str) -> str:
    """Make a string safe to use as a filename."""
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    """
    wav: (channels, samples)
    returns mono float32 in [-1, 1], resampled to target_sr
    """
    if wav.ndim != 2:
        raise ValueError(f"Expected 2D waveform (C, T), got shape {tuple(wav.shape)}")
    # to mono
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    # resample if needed
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    """Save a (C, T) or (T,) waveform as 16-bit PCM WAV."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # (C, T)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    # Clamp to [-1, 1]
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


# -------------------------
# Core inference
# -------------------------
def reconstruct_full_pipeline(
    wav_mono: torch.Tensor,
    sr: int,
    wavvae: WavVAE,
    zflowae: ZFlowAutoEncoder,
    decode_params: DecodeParams,
    device: str,
) -> torch.Tensor:
    """
    Full path: audio -> WavVAE.encode -> ZFlowAE.encode -> ZFlowAE.decode -> WavVAE.decode -> audio_hat
    """
    wav_mono = wav_mono.to(device)
    # WavVAE expects (B, C, T); assume C=1
    x = wav_mono.unsqueeze(0)  # (1, 1, T)
    with torch.no_grad():  # inference only; avoid building the autograd graph
        # Encode to high-framerate latents
        z = wavvae.encode(x)
        # Compress latents
        y = zflowae.encode(z)
        # Decompress
        z_hat = zflowae.decode(
            y, num_steps=decode_params.num_steps, cfg=decode_params.cfg
        )
        # Decode to waveform
        wav_hat = wavvae.decode(z_hat)  # (1, 1, T)
    # Return mono 1D
    return wav_hat.squeeze(0).squeeze(0).detach()


def load_model_pair(spec: ModelPairSpec, device: str):
    """Load a WavVAE/ZFlowAutoEncoder pair and report the WavVAE sample rate."""
    # eval() puts the modules in inference mode (idempotent if already set)
    wavvae = WavVAE.from_pretrained_local(spec.wavvae_dir).to(device).eval()
    zflowae = ZFlowAutoEncoder.from_pretrained_local(spec.zflowae_dir).to(device).eval()
    # try to get the sampling rate from the WavVAE
    target_sr = getattr(wavvae, "sampling_rate", None)
    if target_sr is None:
        # reasonable fallback
        target_sr = 24000
    return wavvae, zflowae, int(target_sr)
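

# Illustrative standalone use of the pieces above (paths are placeholders and the
# checkpoints are assumed to be compatible with from_pretrained_local):
#
#   spec = ModelPairSpec(
#       name="demo",
#       wavvae_dir="/path/to/wavvae_checkpoint",
#       zflowae_dir="/path/to/zflowae_checkpoint",
#       decode=DecodeParams(num_steps=10, cfg=2.0),
#   )
#   wavvae, zflowae, sr = load_model_pair(spec, device="cpu")
#   wav, in_sr = ta_load("example.wav")
#   wav, sr = ensure_mono_and_resample(wav, in_sr, sr)
#   wav_hat = reconstruct_full_pipeline(wav, sr, wavvae, zflowae, spec.decode, "cpu")
#   save_wav(Path("example_recon.wav"), wav_hat, sr)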


def parse_manifest(path: str) -> List[ModelPairSpec]:
    """
    Manifest format (JSON list):
    [
      {
        "name": "zdim32x8",
        "wavvae": "/path/to/WavVAE_framerate100_zdim32/",
        "zflowae": "/path/to/ZFlowAutoEncoder_stride4_zdim32_vae8_.../",
        "decode": {"num_steps": 10, "cfg": 2.0}
      }
    ]
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    out: List[ModelPairSpec] = []
    for item in raw:
        name = item["name"]
        wavvae_dir = item["wavvae"]
        zflowae_dir = item["zflowae"]
        d = item.get("decode", {})
        out.append(
            ModelPairSpec(
                name=name,
                wavvae_dir=wavvae_dir,
                zflowae_dir=zflowae_dir,
                decode=DecodeParams(
                    num_steps=int(d.get("num_steps", 10)),
                    cfg=float(d.get("cfg", 2.0)),
                ),
            )
        )
    return out


# -------------------------
# HTML generation
# -------------------------
def html_escape(s: str) -> str:
    """Minimal HTML escaping for text and attribute values."""
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )


def make_html(
    output_dir: Path,
    audio_files: List[Path],
    models: List[ModelPairSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
    zflow_cfg: Dict[str, Dict[str, Any]],
) -> str:
    """
    Build a static HTML page with a table:
      Row = input audio file
      Col 1 = Original
      Col 2..N = each model reconstruction
    Also shows minimal model config info above the table.
    """

    def player(src_rel: str, controls: bool = True) -> str:
        return f'<audio {"controls" if controls else ""} preload="none" src="{html_escape(src_rel)}"></audio>'

    # Model cards
    model_cards = []
    for spec in models:
        wcfg = wavvae_cfg.get(spec.name, {})
        zcfg = zflow_cfg.get(spec.name, {})
        w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[:1200]
        z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[:1200]
        card = f"""
        <div class="model-card">
          <h3>{html_escape(spec.name)}</h3>
          <p><b>Sample rate</b>: {sr_by_model.get(spec.name, "N/A")} Hz</p>
          <details>
            <summary>WavVAE config</summary>
            <pre>{html_escape(w_short)}</pre>
          </details>
          <details>
            <summary>ZFlowAE config</summary>
            <pre>{html_escape(z_short)}</pre>
          </details>
          <p><b>Decode</b>: num_steps={spec.decode.num_steps}, cfg={spec.decode.cfg}</p>
        </div>
        """
        model_cards.append(card)
    # Table header
    th = "<th>Input</th><th>Original</th>" + "".join(
        f"<th>{html_escape(m.name)}</th>" for m in models
    )

    # Rows
    rows = []
    for af in audio_files:
        base = af.stem
        orig_rel = f"original/{html_escape(af.name)}"
        tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
        for m in models:
            # Must match the filename written in main(), which uses sanitize_name()
            rec_rel = f"recon/{html_escape(m.name)}/{html_escape(sanitize_name(base))}.wav"
            tds.append(f"<td>{player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(tds) + "</tr>")
    # Simple CSS to keep it clean
    css = """
    body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
    h1 { margin-bottom: 0.2rem; }
    .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
    .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
    th { background: #fafafa; position: sticky; top: 0; }
    audio { width: 260px; }
    """

    html = f"""<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>Codec Comparison</title>
<style>{css}</style>
</head>
<body>
<h1>Codec Comparison</h1>
<p>This page compares reconstructions across model checkpoints. Click play in each cell.</p>
<h2>Models</h2>
<div class="cards">
{"".join(model_cards)}
</div>
<h2>Audio</h2>
<table>
<thead><tr>{th}</tr></thead>
<tbody>
{"".join(rows)}
</tbody>
</table>
</body>
</html>
"""
    out = output_dir / "index.html"
    out.write_text(html, encoding="utf-8")
    return str(out)
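

# Output layout produced by main() below (the top-level directory comes from --out):
#
#   <out>/
#     index.html                      comparison page written by make_html()
#     original/<name>.wav             browser-friendly 16-bit copies of the inputs
#     recon/<model_name>/<name>.wav   one reconstruction per (model, input) pair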


# -------------------------
# Main
# -------------------------
def main():
    p = argparse.ArgumentParser(
        description="Compare Z-Codec configurations and generate a static HTML page."
    )
    p.add_argument(
        "--manifest",
        type=str,
        required=True,
        help="JSON file listing model pairs. See docstring in parse_manifest().",
    )
    p.add_argument(
        "--audio", type=str, nargs="+", required=True, help="List of input audio files."
    )
    p.add_argument(
        "--out",
        type=str,
        default="codec_compare_out",
        help="Output directory for reconstructions and HTML.",
    )
    p.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run inference on (cuda or cpu).",
    )
    p.add_argument(
        "--force",
        action="store_true",
        help="Recompute even if target wav already exists.",
    )
    args = p.parse_args()

    device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
    out_dir = Path(args.out)
    orig_dir = out_dir / "original"
    recon_dir = out_dir / "recon"
    orig_dir.mkdir(parents=True, exist_ok=True)
    recon_dir.mkdir(parents=True, exist_ok=True)

    # Parse models
    specs = parse_manifest(args.manifest)
    if not specs:
        print("No models in manifest.", file=sys.stderr)
        sys.exit(1)

    # Load models
    loaded: Dict[str, Dict[str, Any]] = {}
    sr_by_model: Dict[str, int] = {}
    wavvae_cfg: Dict[str, Dict[str, Any]] = {}
    zflow_cfg: Dict[str, Dict[str, Any]] = {}
    for spec in specs:
        print(f"[Load] {spec.name}")
        wavvae, zflowae, target_sr = load_model_pair(spec, device)
        loaded[spec.name] = {"wavvae": wavvae, "zflowae": zflowae, "sr": target_sr}
        sr_by_model[spec.name] = target_sr
        wavvae_cfg[spec.name] = read_config_any(spec.wavvae_dir)
        zflow_cfg[spec.name] = read_config_any(spec.zflowae_dir)

    # Process audio files
    audio_files = [Path(a) for a in args.audio]
    for af in audio_files:
        if not af.exists():
            print(f"[Skip] Missing: {af}", file=sys.stderr)
            continue

        # Re-save the original as a 16-bit mono WAV so the page can play it in any
        # browser, regardless of the input container/format.
        out_orig = orig_dir / af.name
        if args.force or not out_orig.exists():
            wav, sr = ta_load(str(af))
            # make it mono for fair listening (no resampling at this stage)
            wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
            save_wav(out_orig.with_suffix(".wav"), wav_mono, sr)
            # keep the name consistent in the HTML (use .wav)
            af = af.with_suffix(".wav")
        if out_orig.suffix != ".wav":
            # Ensure HTML references the .wav filename
            out_orig = out_orig.with_suffix(".wav")

        # For each model, run the full pipeline and save the reconstruction
        base = af.stem
        # Re-load from disk so every model starts from the same file (the .wav copy above)
        wav0, sr0 = ta_load(str(out_orig if out_orig.exists() else orig_dir / af.name))
        # Make mono only once; resample per-model to each target SR
        if wav0.size(0) > 1:
            wav0 = wav0.mean(dim=0, keepdim=True)
        for spec in specs:
            mname = spec.name
            target_sr = sr_by_model[mname]
            # resample to the model's sample rate
            if sr0 != target_sr:
                wav_mono = ta_resample(wav0, sr0, target_sr)
            else:
                wav_mono = wav0
            # reconstruct
            out_path = recon_dir / mname / f"{sanitize_name(base)}.wav"
            if args.force or not out_path.exists():
                print(f"[Reconstruct] {mname} ← {base}")
                wavvae = loaded[mname]["wavvae"]
                zflowae = loaded[mname]["zflowae"]
                wav_hat = reconstruct_full_pipeline(
                    wav_mono, target_sr, wavvae, zflowae, spec.decode, device
                )
                save_wav(out_path, wav_hat.unsqueeze(0), target_sr)

    # Build HTML from the files actually present in original/ (the .wav copies)
    actual_audio = sorted(orig_dir.glob("*.wav"))
    html_path = make_html(
        out_dir,
        actual_audio,
        specs,
        sr_by_model,
        wavvae_cfg,
        zflow_cfg,
    )
    print(f"\nDone. Open: {html_path}")


if __name__ == "__main__":
    main()