File size: 1,483 Bytes
56cfa73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import hashlib
from pathlib import Path

from zcodec.trainers import TrainWavVAE


def hash_checkpoint_file(path, method="sha256", length=8):
    """Return a short hex fingerprint of the file at *path*.

    The file is opened in binary mode and fed to a ``hashlib`` hasher in
    8192-byte pieces, so the whole file is never held in memory at once.

    Args:
        path: Location of the file to hash; passed straight to ``open``.
        method: Algorithm name handed to ``hashlib.new``.
        length: Number of leading hex characters of the digest to keep.

    Returns:
        The first ``length`` characters of the hexadecimal digest.
    """
    digest = hashlib.new(method)
    with open(path, "rb") as stream:
        # iter() with a sentinel keeps calling read() until it yields b"" (EOF).
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    fingerprint = digest.hexdigest()
    return fingerprint[:length]


def main(argv=None):
    """Export a WavVAE Lightning checkpoint as standalone weights plus config.

    Parses the command line, loads the ``TrainWavVAE`` module from the given
    checkpoint, chooses an output directory, creates it if needed, and calls
    ``save_model_weights_and_config`` on the loaded module with that directory.

    When ``--output`` is omitted, the directory name is derived from the
    model's frame rate, its latent dimension, and a short hash of the
    checkpoint file.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which keeps the
            script entry point behaving as before.
    """
    parser = argparse.ArgumentParser(description="Export WavVAE pretrained checkpoint.")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional output directory (default is based on config)",
    )
    args = parser.parse_args(argv)

    # Load Lightning module
    wavvae = TrainWavVAE.load_from_checkpoint(args.checkpoint)
    config = wavvae.hparams.config

    # Values used to build the default directory name
    frame_rate = wavvae.wavvae.frame_rate
    z_dim = config.latent_dim

    # Determine output directory
    if args.output is None:
        # The hash is only part of the auto-generated name, so the checkpoint
        # file is re-read only when that name is actually needed.
        checkpoint_hash = hash_checkpoint_file(args.checkpoint)
        out_dir = Path(
            f"checkpoints/wavvae/pretrained/WavVAE_framerate{frame_rate}_zdim{z_dim}_{checkpoint_hash}"
        )
    else:
        out_dir = args.output
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save weights and config
    wavvae.save_model_weights_and_config(str(out_dir))
    print(f"✅ Exported model to {out_dir}")


# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()