"""Batch-encode a list of files with a pretrained ZFlowAutoEncoder.

Each input listed in the file list is loaded, passed through
``model.encode`` and the result is written to ``<output_dir>/<stem>.st`` as
a safetensors file under the ``audio_z`` key.
"""

import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from zcodec.models import ZFlowAutoEncoder


class AudioDataset(Dataset):
    """Dataset that reads audio files and resamples them to ``target_sr``.

    Each item is a ``(wav, path)`` pair where ``wav`` is a float32 tensor of
    shape ``[1, T]`` and ``path`` is the entry from ``file_list``.
    """

    def __init__(self, file_list, target_sr):
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        if wav.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel files.
            # Downmix to mono so that resampling runs along the time axis
            # and the result keeps the [1, T] shape promised below.
            wav = wav.mean(axis=1)
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]
        return wav, path


class SafeTensorDataset(Dataset):
    """Dataset that loads one tensor per safetensors file.

    On ``__getitem__`` the file is opened and the full tensor stored under
    ``key`` is read; the item is the ``(tensor, path)`` pair. No filtering
    or cropping is applied.
    """

    def __init__(
        self,
        file_paths: list[str | Path],
        key: str = "audio_z",
    ):
        self.file_paths = file_paths
        self.key = key

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, str | Path]:
        path = self.file_paths[idx]
        # The tensor is fully materialised before the file handle is closed.
        with safe_open(path, framework="pt") as f:
            tensor = f.get_tensor(self.key)
        return tensor, path


def _collate_transpose(samples):
    """Turn a list of ``(item, path)`` samples into ``[items, paths]``.

    Defined at module level (instead of as a lambda) so that it can be
    pickled when DataLoader workers are started with the ``spawn`` or
    ``forkserver`` multiprocessing start methods.
    """
    return list(zip(*samples))


@torch.no_grad()
def encode_batch(model, batch, device, out_dir, save_latent=False):
    """Encode every item of ``batch`` and write one ``.st`` file per item.

    Args:
        model: Object exposing ``encode(x)`` that returns a
            ``(latent, indices)`` pair.
        batch: Pair ``(tensors, paths)`` of equal-length sequences, as
            produced by the collate function used in ``main``.
        device: Device the inputs are moved to before encoding.
        out_dir: Directory (``Path``) receiving ``<path.stem>.st``.
        save_latent: If true, store the latent returned by ``encode``;
            otherwise store the indices.
    """
    wavs, paths = batch
    for wav, path in zip(wavs, paths):
        latent, indices = model.encode(wav.to(device))
        to_save = latent.cpu() if save_latent else indices.cpu()
        # Inputs sharing a stem map to the same output file name.
        out_path = out_dir / (path.stem + ".st")
        save_file({"audio_z": to_save}, str(out_path))


def main():
    """Parse CLI arguments and encode every file from the file list."""
    parser = argparse.ArgumentParser(
        description="Batch encode audio files with ZFlowAutoEncoder."
    )
    parser.add_argument(
        "file_list",
        type=Path,
        help="Text file listing paths to audio files (one per line)",
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to ZFlowAutoEncoder checkpoint directory"
    )
    parser.add_argument(
        "output_dir", type=Path, help="Directory to save Safetensors latents"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--num_workers", type=int, default=4, help="Num. workers")
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load model
    zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint)
    zflowae = zflowae.to(device).eval()

    # Prepare dataset and dataloader; blank lines in the list are skipped.
    with open(args.file_list, "r") as f:
        file_paths = [Path(line.strip()) for line in f if line.strip()]

    # NOTE(review): the CLI help talks about audio files, but the entries are
    # read with SafeTensorDataset (safetensors holding an "audio_z" tensor)
    # and AudioDataset is unused here -- confirm which input is intended.
    dataset = SafeTensorDataset(file_paths)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        collate_fn=_collate_transpose,
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Inference loop. Only failures raised inside encode_batch are caught
    # and reported; errors raised while the DataLoader loads an item
    # propagate out of the loop.
    for batch in tqdm(dataloader):
        try:
            encode_batch(zflowae, batch, device, args.output_dir)
        except Exception as e:
            print(f"❌ Batch failed: {e}")


if __name__ == "__main__":
    main()