Spaces: Running on Zero
| import argparse | |
| import hashlib | |
| from pathlib import Path | |
| from zcodec.trainers import TrainZFlowAutoEncoder | |
def hash_checkpoint_file(path, method="sha256", length=8):
    """Return a short hexadecimal fingerprint of the file at ``path``.

    The file is read in binary mode, 8192 bytes at a time, and fed into a
    ``hashlib`` object created by name (``method``). Only the first
    ``length`` hex characters of the final digest are returned.
    """
    digest = hashlib.new(method)
    with open(path, "rb") as stream:
        # iter() with a b"" sentinel stops exactly when read() hits EOF.
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    full_hex = digest.hexdigest()
    return full_hex[:length]
def _bottleneck_tag(config):
    """Return the bottleneck label used in the default export directory name.

    ``"fsq<N>"`` where N is the product of ``config.fsq_levels`` (checked
    first), ``"vae<bottleneck_size>"`` when ``config.vae`` is truthy, or
    ``None`` when neither applies.
    """
    import math

    if config.fsq_levels is not None:
        return f"fsq{math.prod(config.fsq_levels)}"
    if config.vae:
        return f"vae{config.bottleneck_size}"
    return None


def main():
    """Export a trained ZFlowAutoEncoder from a Lightning checkpoint.

    Loads the checkpoint given on the command line and writes its weights and
    config via ``save_model_weights_and_config`` into ``--output``. If
    ``--output`` is omitted, the directory is
    ``checkpoints/zflowae/pretrained/ZFlowAutoEncoder_stride…_zdim…_<bottleneck>_<factory>_repa<bool>_<hash>``,
    built from the checkpoint's config, its ``ssl_repa_head`` hparam, and a
    short hash of the checkpoint file.

    Exits via ``parser.error`` when no ``--output`` is given and the config
    sets neither ``fsq_levels`` nor ``vae``, because no default name can be
    derived in that case.
    """
    parser = argparse.ArgumentParser(
        description="Export ZFlowAutoEncoder pretrained checkpoint."
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional output directory (default is based on config)",
    )
    args = parser.parse_args()

    zflowae = TrainZFlowAutoEncoder.load_from_checkpoint(args.checkpoint)

    if args.output is not None:
        out_dir = args.output
    else:
        config = zflowae.hparams.config
        bottleneck = _bottleneck_tag(config)
        if bottleneck is None:
            # Without a recognised bottleneck there is nothing to put in the
            # name; fail clearly instead of with an unbound-variable error.
            parser.error(
                "cannot derive a default output directory: config sets neither "
                "fsq_levels nor vae; pass --output explicitly"
            )
        # hparams may not define ssl_repa_head at all; absence means disabled.
        repa = bool(getattr(zflowae.hparams, "ssl_repa_head", False))
        checkpoint_hash = hash_checkpoint_file(args.checkpoint)
        out_dir = Path(
            "checkpoints/zflowae/pretrained/"
            f"ZFlowAutoEncoder_stride{config.latent_stride}_zdim{config.latent_dim}"
            f"_{bottleneck}_{config.autoencoder_factory}_repa{repa}_{checkpoint_hash}"
        )

    out_dir.mkdir(parents=True, exist_ok=True)
    # Save weights and config
    zflowae.save_model_weights_and_config(str(out_dir))
    print(f"✅ Exported model to {out_dir}")
# Run the exporter only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()