# pardi-speech/codec/scripts/export_zflowae.py
# Mehdi Lakbar
# Initial demo of Lina-speech (pardi-speech)
# 56cfa73
import argparse
import hashlib
import math
from pathlib import Path

from zcodec.trainers import TrainZFlowAutoEncoder
def hash_checkpoint_file(path, method="sha256", length=8):
    """Return a truncated hex digest of the file at ``path``.

    The file is opened in binary mode and fed to the hasher in 8192-byte
    pieces, so it is never held in memory as a whole.

    Args:
        path: Location of the file to hash.
        method: Algorithm name handed to ``hashlib.new``.
        length: Number of leading hex characters of the digest to keep.

    Returns:
        The first ``length`` characters of the hexadecimal digest.
    """
    digest = hashlib.new(method)
    with open(path, "rb") as stream:
        # iter(callable, sentinel) stops once read() yields b"" at end of file.
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    hex_digest = digest.hexdigest()
    return hex_digest[:length]
def _bottleneck_tag(config):
    """Return the bottleneck component of the default export directory name.

    Yields ``fsq<N>`` (``N`` being the product of ``config.fsq_levels``) when
    ``fsq_levels`` is set, otherwise ``vae<bottleneck_size>`` when
    ``config.vae`` is truthy.

    Raises:
        ValueError: If the config selects neither option, so no tag can be
            derived.
    """
    if config.fsq_levels is not None:
        return f"fsq{math.prod(config.fsq_levels)}"
    if config.vae:
        return f"vae{config.bottleneck_size}"
    raise ValueError(
        "Cannot derive a default output directory: the config has neither "
        "'fsq_levels' nor 'vae' set. Pass --output explicitly."
    )


def _default_output_dir(hparams, checkpoint_path):
    """Build the default export directory from hyperparameters and file hash.

    Args:
        hparams: Hyperparameters of the loaded module; ``hparams.config`` and
            the optional ``hparams.ssl_repa_head`` are read.
        checkpoint_path: Checkpoint file whose truncated hash ends the name.

    Returns:
        A ``Path`` under ``checkpoints/zflowae/pretrained``.
    """
    config = hparams.config
    # ssl_repa_head may be absent from the hyperparameters; treat that as False.
    repa = bool(getattr(hparams, "ssl_repa_head", False))
    bottleneck = _bottleneck_tag(config)
    checkpoint_hash = hash_checkpoint_file(checkpoint_path)
    return Path(
        "checkpoints/zflowae/pretrained/"
        f"ZFlowAutoEncoder_stride{config.latent_stride}"
        f"_zdim{config.latent_dim}_{bottleneck}"
        f"_{config.autoencoder_factory}_repa{repa}_{checkpoint_hash}"
    )


def main():
    """Export a ZFlowAutoEncoder checkpoint to a weights-and-config directory.

    Parses the command line (a positional ``checkpoint`` path and an optional
    ``--output`` directory), loads the module with
    ``TrainZFlowAutoEncoder.load_from_checkpoint``, creates the output
    directory and calls ``save_model_weights_and_config`` on it.

    Without ``--output`` the directory name is derived from the model config
    and a short hash of the checkpoint file. That derivation (including
    reading the whole checkpoint to hash it) is skipped when ``--output`` is
    given, since its result would not be used.

    Raises:
        ValueError: If no ``--output`` is given and the config selects neither
            an FSQ nor a VAE bottleneck, so no default name can be built.
    """
    parser = argparse.ArgumentParser(
        description="Export ZFlowAutoEncoder pretrained checkpoint."
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional output directory (default is based on config)",
    )
    args = parser.parse_args()

    zflowae = TrainZFlowAutoEncoder.load_from_checkpoint(args.checkpoint)

    if args.output is None:
        out_dir = _default_output_dir(zflowae.hparams, args.checkpoint)
    else:
        out_dir = args.output
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save weights and config
    zflowae.save_model_weights_and_config(str(out_dir))
    print(f"✅ Exported model to {out_dir}")
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()