Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,483 Bytes
56cfa73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import argparse
import hashlib
from pathlib import Path
from zcodec.trainers import TrainWavVAE
def hash_checkpoint_file(path, method="sha256", length=8):
    """Return a short hexadecimal fingerprint of the file at ``path``.

    The file is read in binary mode, 8192 bytes at a time, so the whole file
    is never held in memory at once.

    Args:
        path: File to fingerprint (anything accepted by ``open``).
        method: Hash algorithm name passed to ``hashlib.new``.
        length: How many leading hex characters of the digest to keep;
            ``None`` keeps the full digest.

    Returns:
        The first ``length`` characters of the hex digest.
    """
    digest = hashlib.new(method)
    with open(path, "rb") as stream:
        # iter() with a b"" sentinel stops at end-of-file, exactly like a
        # "read until empty" loop.
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    full_hex = digest.hexdigest()
    return full_hex[:length]
def main():
    """Export a trained WavVAE Lightning checkpoint as standalone weights + config.

    Command line:
        checkpoint        Path to the Lightning ``.ckpt`` file to export.
        --output DIR      Destination directory. When omitted, a directory under
                          ``checkpoints/wavvae/pretrained/`` (relative to the
                          current working directory) is named from the model's
                          frame rate, ``config.latent_dim`` and a short hash of
                          the checkpoint file.
        --map-location    Device the checkpoint tensors are loaded onto
                          (default: ``cpu``), so a checkpoint trained on GPU can
                          be exported on a machine without one.

    The destination directory is created if needed, then the loaded module's
    ``save_model_weights_and_config`` is asked to write into it.
    """
    parser = argparse.ArgumentParser(description="Export WavVAE pretrained checkpoint.")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional output directory (default is based on config)",
    )
    parser.add_argument(
        "--map-location",
        default="cpu",
        help="Device to load checkpoint tensors onto, e.g. 'cpu' or 'cuda:0' (default: cpu)",
    )
    args = parser.parse_args()

    # Fail fast with a clean CLI error instead of a traceback from deep inside
    # the checkpoint loader (the file is also opened directly for hashing).
    if not args.checkpoint.is_file():
        parser.error(f"checkpoint not found: {args.checkpoint}")

    # Load the Lightning module. Without an explicit map_location, tensors saved
    # from GPU training are restored onto their original device, which fails on
    # CPU-only machines.
    wavvae = TrainWavVAE.load_from_checkpoint(
        args.checkpoint, map_location=args.map_location
    )

    if args.output is None:
        # Only derive the default name (and re-read the checkpoint to hash it)
        # when the user did not choose a directory.
        config = wavvae.hparams.config
        frame_rate = wavvae.wavvae.frame_rate
        z_dim = config.latent_dim
        checkpoint_hash = hash_checkpoint_file(args.checkpoint)
        out_dir = Path("checkpoints/wavvae/pretrained") / (
            f"WavVAE_framerate{frame_rate}_zdim{z_dim}_{checkpoint_hash}"
        )
    else:
        out_dir = args.output
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save weights and config
    wavvae.save_model_weights_and_config(str(out_dir))
    print(f"✅ Exported model to {out_dir}")
if __name__ == "__main__":
    # Run the exporter only when this file is executed as a script, not when
    # it is imported as a module.
    main()
|