# pardi-speech / codec/scripts/export_wavvae.py
# Mehdi Lakbar
# Initial demo of Lina-speech (pardi-speech)
# 56cfa73
import argparse
import hashlib
from pathlib import Path
from zcodec.trainers import TrainWavVAE
def hash_checkpoint_file(path, method="sha256", length=8, *, chunk_size=1 << 20):
    """Return a short hex digest identifying the contents of a file.

    The file is streamed in fixed-size chunks, so arbitrarily large
    checkpoints are never loaded into memory at once.

    Args:
        path: Path (``str`` or ``os.PathLike``) of the file to hash.
        method: Any algorithm name accepted by ``hashlib.new``
            (e.g. ``"sha256"``, ``"md5"``).
        length: Number of leading hex characters of the digest to keep;
            ``None`` keeps the full digest.
        chunk_size: Bytes read per iteration. Affects only speed and memory
            use, never the resulting digest.

    Returns:
        The first ``length`` characters of the hexadecimal digest.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``method`` is not
            supported by ``hashlib``.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size <= 0:
        # f.read(0) returns b"" immediately, which would silently hash an
        # empty stream instead of the file contents.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    h = hashlib.new(method)
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            h.update(chunk)
    return h.hexdigest()[:length]
def main(argv=None):
    """Export a trained WavVAE Lightning checkpoint as standalone weights + config.

    Loads ``TrainWavVAE`` from the given ``.ckpt`` file and calls its
    ``save_model_weights_and_config`` method with the output directory.
    When ``--output`` is omitted, the directory is
    ``checkpoints/wavvae/pretrained/WavVAE_framerate{frame_rate}_zdim{z_dim}_{hash}``
    (relative to the current working directory), where ``hash`` is a short
    SHA-256 prefix of the checkpoint file.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` lets ``argparse`` read ``sys.argv[1:]``.

    Raises:
        SystemExit: If the arguments are invalid or the checkpoint file does
            not exist (raised by ``argparse``).
    """
    parser = argparse.ArgumentParser(description="Export WavVAE pretrained checkpoint.")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional output directory (default is based on config)",
    )
    args = parser.parse_args(argv)

    # Fail fast with a clear CLI error rather than a traceback from deep
    # inside the (slow) checkpoint loader.
    if not args.checkpoint.is_file():
        parser.error(f"checkpoint file not found: {args.checkpoint}")

    # Load Lightning module
    wavvae = TrainWavVAE.load_from_checkpoint(args.checkpoint)

    # Determine output directory
    if args.output is None:
        # Frame rate, latent size and file hash are only needed to name the
        # default directory; hashing reads the whole checkpoint, so skip it
        # when the user supplied --output.
        frame_rate = wavvae.wavvae.frame_rate
        z_dim = wavvae.hparams.config.latent_dim
        checkpoint_hash = hash_checkpoint_file(args.checkpoint)
        out_dir = Path(
            f"checkpoints/wavvae/pretrained/WavVAE_framerate{frame_rate}_zdim{z_dim}_{checkpoint_hash}"
        )
    else:
        out_dir = args.output
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save weights and config
    wavvae.save_model_weights_and_config(str(out_dir))
    print(f"✅ Exported model to {out_dir}")
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()