Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import hashlib | |
| from pathlib import Path | |
| from zcodec.trainers import TrainWavVAE | |
def hash_checkpoint_file(path, method="sha256", length=8):
    """Return a short hex fingerprint of the file at *path*.

    The file is read in 8 KiB pieces, so the whole file is never held in
    memory at once.

    Args:
        path: File to hash (``str`` or ``pathlib.Path``).
        method: Algorithm name passed to ``hashlib.new`` (e.g. ``"sha256"``).
        length: Number of leading hex characters of the digest to keep.

    Returns:
        The first ``length`` characters of the file's hex digest.
    """
    digest = hashlib.new(method)
    with open(path, "rb") as stream:
        # ``read`` returns b"" at EOF, which ends the two-argument ``iter``.
        for piece in iter(lambda: stream.read(8192), b""):
            digest.update(piece)
    full_hex = digest.hexdigest()
    return full_hex[:length]
def main():
    """Export a WavVAE Lightning checkpoint as pretrained weights + config.

    Command-line arguments:
        checkpoint: Path to the Lightning ``.ckpt`` file to export.
        --output: Optional destination directory. When omitted, the export
            goes to ``checkpoints/wavvae/pretrained/
            WavVAE_framerate{frame_rate}_zdim{z_dim}_{hash}``, where
            ``frame_rate`` comes from the loaded model, ``z_dim`` is the
            config's ``latent_dim`` and ``hash`` is the first 8 hex
            characters of the checkpoint file's SHA-256 digest.

    The destination directory is created if missing, then the module's
    ``save_model_weights_and_config`` writes the export into it.
    Exits with a usage error if the checkpoint path is not an existing file.
    """
    parser = argparse.ArgumentParser(description="Export WavVAE pretrained checkpoint.")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help=(
            "Optional output directory (default: checkpoints/wavvae/pretrained/"
            "WavVAE_framerate<frame_rate>_zdim<z_dim>_<checkpoint_hash>)"
        ),
    )
    args = parser.parse_args()

    # Report a bad path as a clean CLI usage error instead of letting
    # checkpoint loading raise with a traceback.
    if not args.checkpoint.is_file():
        parser.error(f"checkpoint file not found: {args.checkpoint}")

    # Restore the Lightning training module from the checkpoint.
    wavvae = TrainWavVAE.load_from_checkpoint(args.checkpoint)
    config = wavvae.hparams.config

    # Frame rate of the inner ``wavvae`` model and latent size from config;
    # both are encoded in the default export directory name.
    frame_rate = wavvae.wavvae.frame_rate
    z_dim = config.latent_dim

    if args.output is None:
        # Hashing reads the whole checkpoint again, so only do it when the
        # hash is actually needed to build the default directory name.
        checkpoint_hash = hash_checkpoint_file(args.checkpoint)
        out_dir = Path(
            "checkpoints/wavvae/pretrained/"
            f"WavVAE_framerate{frame_rate}_zdim{z_dim}_{checkpoint_hash}"
        )
    else:
        out_dir = args.output
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save weights and config into the destination directory.
    wavvae.save_model_weights_and_config(str(out_dir))
    # NOTE(review): the leading character of this message looks like a
    # mis-encoded emoji (possibly a check mark) — confirm the intended text.
    print(f"β Exported model to {out_dir}")


if __name__ == "__main__":
    main()