#!/usr/bin/env python3
import argparse
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple

import torch
import torchaudio
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample

# Your libs
from zcodec.models import WavVAE, ZFlowAutoEncoder

# -------------------------
# Data structures
# -------------------------


@dataclass
class DecodeParams:
    num_steps: int = 10
    cfg: float = 2.0


@dataclass
class ModelPairSpec:
    name: str
    wavvae_dir: str
    zflowae_dir: str
    decode: DecodeParams


# -------------------------
# Utilities
# -------------------------


def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
    if path.is_file():
        try:
            with path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            return None
    return None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    """
    Try to read config.json (or a few common fallbacks) from a checkpoint dir.
    Returns {} if nothing could be parsed.
    """
    cand = [
        Path(checkpoint_dir) / "config.json",
        Path(checkpoint_dir) / "config.yaml",  # won't parse YAML here, we only display the path
        Path(checkpoint_dir) / "model_config.json",
    ]
    for p in cand:
        if p.exists():
            if p.suffix == ".json":
                j = load_json_if_exists(p)
                if j is not None:
                    return j
            else:
                # For YAML or unknown formats, just record the filename rather than failing.
                return {"_config_file": str(p)}
    return {}


def sanitize_name(s: str) -> str:
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    """
    wav: (channels, samples)
    Returns mono float32 in [-1, 1], resampled to target_sr.
    """
    if wav.ndim != 2:
        raise ValueError(f"Expected 2D waveform (C, T), got shape {tuple(wav.shape)}")
    # Downmix to mono
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    # Resample if needed
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    path.parent.mkdir(parents=True, exist_ok=True)
    # torchaudio.save expects (C, T)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    # Clamp to [-1, 1] before writing 16-bit PCM
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


# -------------------------
# Core inference
# -------------------------


@torch.inference_mode()
def reconstruct_full_pipeline(
    wav_mono: torch.Tensor,
    sr: int,
    wavvae: WavVAE,
    zflowae: ZFlowAutoEncoder,
    decode_params: DecodeParams,
    device: str,
) -> torch.Tensor:
    """
    Full path:
    audio -> WavVAE.encode -> ZFlowAE.encode -> ZFlowAE.decode -> WavVAE.decode -> audio_hat
    """
    wav_mono = wav_mono.to(device)
    # WavVAE expects (B, C, T); assume C=1
    x = wav_mono.unsqueeze(0)  # (1, 1, T)
    # Encode to high-framerate latents
    z = wavvae.encode(x)
    # Compress latents
    y = zflowae.encode(z)
    # Decompress
    z_hat = zflowae.decode(y, num_steps=decode_params.num_steps, cfg=decode_params.cfg)
    # Decode back to a waveform
    wav_hat = wavvae.decode(z_hat)  # (1, 1, T)
    # Return mono 1D
    return wav_hat.squeeze(0).squeeze(0).detach()


def load_model_pair(spec: ModelPairSpec, device: str):
    wavvae = WavVAE.from_pretrained_local(spec.wavvae_dir).to(device)
    zflowae = ZFlowAutoEncoder.from_pretrained_local(spec.zflowae_dir).to(device)
    # Try to get the sampling rate from the WavVAE
    target_sr = getattr(wavvae, "sampling_rate", None)
    if target_sr is None:
        # Reasonable fallback
        target_sr = 24000
    return wavvae, zflowae, int(target_sr)
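
# Usage sketch for the pieces above, kept as a comment so it never runs at
# import time. Paths, file names, and the device string are hypothetical;
# it assumes checkpoints saved in the local layout that
# `from_pretrained_local` expects:
#
#   spec = ModelPairSpec(
#       name="zdim32x8",
#       wavvae_dir="/ckpts/WavVAE_framerate100_zdim32",
#       zflowae_dir="/ckpts/ZFlowAutoEncoder_stride4_zdim32_vae8",
#       decode=DecodeParams(num_steps=10, cfg=2.0),
#   )
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   wavvae, zflowae, target_sr = load_model_pair(spec, device)
#   wav, sr = ta_load("sample.wav")                      # (C, T) at original rate
#   wav, sr = ensure_mono_and_resample(wav, sr, target_sr)
#   wav_hat = reconstruct_full_pipeline(wav, sr, wavvae, zflowae, spec.decode, device)
#   save_wav(Path("out/sample_recon.wav"), wav_hat, sr)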
def parse_manifest(path: str) -> List[ModelPairSpec]:
    """
    Manifest format (JSON list):
    [
      {
        "name": "zdim32x8",
        "wavvae": "/path/to/WavVAE_framerate100_zdim32/",
        "zflowae": "/path/to/ZFlowAutoEncoder_stride4_zdim32_vae8_.../",
        "decode": {"num_steps": 10, "cfg": 2.0}
      }
    ]
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    out: List[ModelPairSpec] = []
    for item in raw:
        name = item["name"]
        wavvae_dir = item["wavvae"]
        zflowae_dir = item["zflowae"]
        d = item.get("decode", {})
        out.append(
            ModelPairSpec(
                name=name,
                wavvae_dir=wavvae_dir,
                zflowae_dir=zflowae_dir,
                decode=DecodeParams(
                    num_steps=int(d.get("num_steps", 10)),
                    cfg=float(d.get("cfg", 2.0)),
                ),
            )
        )
    return out


# -------------------------
# HTML generation
# -------------------------


def html_escape(s: str) -> str:
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )
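
# The entity table above matches the stdlib's html.escape(s, quote=True)
# (modulo &#39; vs &#x27; for the single quote), which could be used instead.
# For example:
#   html_escape('<a href="x">O\'Neill & Sons</a>')
#   -> '&lt;a href=&quot;x&quot;&gt;O&#39;Neill &amp; Sons&lt;/a&gt;'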
def make_html(
    output_dir: Path,
    audio_files: List[Path],
    models: List[ModelPairSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
    zflow_cfg: Dict[str, Dict[str, Any]],
) -> str:
    """
    Build a static HTML page with a table:
      Row      = input audio file
      Col 1    = Original
      Col 2..N = each model reconstruction
    Also shows minimal model config info above the table.
    """

    def player(src_rel: str, controls: bool = True) -> str:
        # Minimal HTML5 audio player for a relative source path
        ctl = " controls" if controls else ""
        return f'<audio{ctl} src="{html_escape(src_rel)}"></audio>'

    # Model cards
    model_cards = []
    for spec in models:
        wcfg = wavvae_cfg.get(spec.name, {})
        zcfg = zflow_cfg.get(spec.name, {})
        w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        card = f"""
<div class="model-card">
  <h3>{html_escape(spec.name)}</h3>
  <div>Sample rate: {sr_by_model.get(spec.name, "N/A")} Hz</div>
  <pre>{html_escape(w_short)}</pre>
  <pre>{html_escape(z_short)}</pre>
  <div>Decode: num_steps={spec.decode.num_steps}, cfg={spec.decode.cfg}</div>
</div>
"""
        model_cards.append(card)

    # Page intro shown above the comparison table
    intro = (
        "<p>This page compares reconstructions across model checkpoints. "
        "Click play in each cell.</p>"
    )
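
    # The table assembly itself is sketched below under an assumed on-disk
    # layout ("orig/<stem>.wav" plus one "<model>/<stem>.wav" folder per model
    # under output_dir, i.e. whatever the driver code writes with save_wav);
    # adjust the relative paths to the actual layout:
    #
    #   rows = []
    #   for audio in audio_files:
    #       stem = sanitize_name(audio.stem)
    #       cells = [player(f"orig/{stem}.wav")]
    #       cells += [player(f"{sanitize_name(m.name)}/{stem}.wav") for m in models]
    #       tds = "".join(f"<td>{c}</td>" for c in cells)
    #       rows.append(f"<tr><th>{html_escape(audio.name)}</th>{tds}</tr>")
    #   head = "".join(
    #       f"<th>{html_escape(h)}</th>"
    #       for h in ["File", "Original", *[m.name for m in models]]
    #   )
    #   table = f"<table><tr>{head}</tr>{''.join(rows)}</table>"
    #   return f"<html><body>{intro}{''.join(model_cards)}{table}</body></html>"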