import argparse
import hashlib
from pathlib import Path


def hash_checkpoint_file(path, method="sha256", length=8, chunk_size=8192):
    """Return a (possibly truncated) hex digest of the file at *path*.

    The file is read in binary chunks, so large checkpoints are never held in
    memory all at once.

    Args:
        path: Path (``str`` or ``Path``) of the file to hash.
        method: Algorithm name accepted by ``hashlib.new`` (e.g. ``"sha256"``).
        length: Number of leading hex characters to keep; ``None`` keeps the
            full digest.
        chunk_size: Number of bytes read per iteration; must be positive.

    Returns:
        The first ``length`` characters of the file's hex digest.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    # f.read(0) returns b"" (the loop would hash nothing) and f.read(-1)
    # reads the whole file at once, so only positive sizes are accepted.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    h = hashlib.new(method)
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            h.update(chunk)
    return h.hexdigest()[:length]


def main():
    """Export a WavVAE Lightning checkpoint's weights and config.

    Command-line arguments:
        checkpoint: Path to the Lightning ``.ckpt`` file.
        --output: Destination directory. If omitted, it defaults to
            ``checkpoints/wavvae/pretrained/WavVAE_framerate{frame_rate}_zdim{z_dim}_{hash}``,
            where ``hash`` is the first 8 hex chars of the checkpoint's SHA-256.
    """
    parser = argparse.ArgumentParser(description="Export WavVAE pretrained checkpoint.")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional output directory (default is based on config)",
    )
    args = parser.parse_args()

    # Report a bad path as a normal CLI usage error instead of a traceback
    # from deep inside the checkpoint loader.
    if not args.checkpoint.is_file():
        parser.error(f"checkpoint not found or not a file: {args.checkpoint}")

    # Deferred import: `--help` and argument errors should not pay the cost
    # of importing the training stack.
    from zcodec.trainers import TrainWavVAE

    # Load Lightning module
    wavvae = TrainWavVAE.load_from_checkpoint(args.checkpoint)
    config = wavvae.hparams.config

    # Frame rate and latent dimension identify the exported variant.
    frame_rate = wavvae.wavvae.frame_rate
    z_dim = config.latent_dim

    # Determine output directory
    if args.output is None:
        # Hash only when it is used: hashing re-reads the whole checkpoint.
        checkpoint_hash = hash_checkpoint_file(args.checkpoint)
        out_dir = Path(
            f"checkpoints/wavvae/pretrained/WavVAE_framerate{frame_rate}_zdim{z_dim}_{checkpoint_hash}"
        )
    else:
        out_dir = args.output
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save weights and config
    wavvae.save_model_weights_and_config(str(out_dir))
    print(f"✅ Exported model to {out_dir}")


if __name__ == "__main__":
    main()