"""Command-line script that exports a ZFlowAutoEncoder Lightning checkpoint.

The checkpoint is loaded through ``TrainZFlowAutoEncoder`` and handed to its
``save_model_weights_and_config`` method, targeting either a user-supplied
directory or a directory whose name is derived from the model config.
"""

import argparse
import hashlib
import math
from pathlib import Path

from zcodec.trainers import TrainZFlowAutoEncoder


def hash_checkpoint_file(path, method="sha256", length=8):
    """Return the first ``length`` hex characters of the file's digest.

    Args:
        path: Path of the file to hash; opened in binary mode.
        method: Algorithm name passed to ``hashlib.new``.
        length: Number of leading hex digits to keep.

    Returns:
        The truncated hexadecimal digest as a string.
    """
    digest = hashlib.new(method)
    with open(path, "rb") as f:
        # Read in fixed-size chunks so the whole file is never held in memory.
        while chunk := f.read(8192):
            digest.update(chunk)
    return digest.hexdigest()[:length]


def _bottleneck_tag(config):
    """Return the bottleneck label used in the default output directory name.

    Yields ``fsq<N>`` (N being the product of ``config.fsq_levels``) when
    ``fsq_levels`` is set, otherwise ``vae<bottleneck_size>`` when
    ``config.vae`` is truthy.

    Raises:
        ValueError: If the config selects neither bottleneck, in which case
            no directory name can be derived.
    """
    if config.fsq_levels is not None:
        return f"fsq{math.prod(config.fsq_levels)}"
    if config.vae:
        return f"vae{config.bottleneck_size}"
    raise ValueError(
        "Cannot derive a default output directory: the config sets neither "
        "'fsq_levels' nor 'vae'. Pass --output explicitly."
    )


def _default_output_dir(zflowae, checkpoint_path):
    """Build the config-derived output directory for a loaded checkpoint.

    The name combines the latent stride, latent dimension, bottleneck label,
    autoencoder factory, the ``ssl_repa_head`` flag and a short hash of the
    checkpoint file.
    """
    config = zflowae.hparams.config
    # Checkpoints without an ``ssl_repa_head`` hyper-parameter count as False.
    repa = bool(getattr(zflowae.hparams, "ssl_repa_head", False))
    stride = config.latent_stride
    zdim = config.latent_dim
    ae_factory = config.autoencoder_factory
    bottleneck = _bottleneck_tag(config)
    checkpoint_hash = hash_checkpoint_file(checkpoint_path)
    return Path(
        f"checkpoints/zflowae/pretrained/ZFlowAutoEncoder_stride{stride}_zdim{zdim}_{bottleneck}_{ae_factory}_repa{repa}_{checkpoint_hash}"
    )


def main():
    """Parse the command line, load the checkpoint and export it."""
    parser = argparse.ArgumentParser(
        description="Export ZFlowAutoEncoder pretrained checkpoint."
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional output directory (default is based on config)",
    )
    args = parser.parse_args()

    zflowae = TrainZFlowAutoEncoder.load_from_checkpoint(args.checkpoint)

    # Only hash the checkpoint and read the naming fields when they are
    # actually needed; an explicit --output makes them irrelevant.
    if args.output is None:
        out_dir = _default_output_dir(zflowae, args.checkpoint)
    else:
        out_dir = args.output
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save weights and config
    zflowae.save_model_weights_and_config(str(out_dir))

    print(f"✅ Exported model to {out_dir}")


if __name__ == "__main__":
    main()