import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from zcodec.models import ZFlowAutoEncoder
class AudioDataset(Dataset):
    """Loads audio files from disk and resamples each one to ``target_sr``."""

    def __init__(self, file_list, target_sr):
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        if wav.ndim > 1:
            # Downmix multi-channel audio to mono (assumes the codec expects mono input).
            wav = wav.mean(axis=1)
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]
        return wav, path
class SafeTensorDataset(Dataset):
    """
    Dataset over previously encoded .st files.

    On __getitem__, opens the safetensor file and returns the full tensor
    stored under ``key`` together with the source path.
    """

    def __init__(
        self,
        file_paths: list[str],
        key: str = "audio_z",
    ):
        self.file_paths = file_paths
        self.key = key

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, str]:
        path = self.file_paths[idx]
        # Open the file and load the full tensor stored under the requested key.
        with safe_open(path, framework="pt") as f:
            tensor = f.get_tensor(self.key)
        return tensor, path
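# Sketch of reading back latents produced by this script (assumes the .st files
# were written by encode_batch below, which stores the tensor under "audio_z",
# and that they live in a hypothetical "out_latents" directory):
#
#     latent_files = sorted(str(p) for p in Path("out_latents").glob("*.st"))
#     latent_ds = SafeTensorDataset(latent_files)
#     z, src_path = latent_ds[0]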
@torch.inference_mode()  # disable autograd bookkeeping during encoding
def encode_batch(model, batch, device, out_dir, save_latent=False):
    wavs, paths = batch
    for wav, path in zip(wavs, paths):
        wav = wav.to(device)
        latent, indices = model.encode(wav)
        # Save either the latent tensor or the indices, both under the "audio_z" key.
        if save_latent:
            to_save = latent.cpu()
        else:
            to_save = indices.cpu()
        out_path = out_dir / (path.stem + ".st")
        save_file({"audio_z": to_save}, str(out_path))
def main():
    parser = argparse.ArgumentParser(
        description="Batch encode audio files with ZFlowAutoEncoder."
    )
    parser.add_argument(
        "file_list",
        type=Path,
        help="Text file listing paths to audio files (one per line)",
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to ZFlowAutoEncoder checkpoint directory"
    )
    parser.add_argument(
        "output_dir", type=Path, help="Directory to save Safetensors latents"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--num_workers", type=int, default=4, help="Num. workers")
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=24000,
        help="Sample rate to resample audio to before encoding; should match the "
        "rate the codec was trained at (24000 is an assumed default)",
    )
    parser.add_argument(
        "--save_latent",
        action="store_true",
        help="Save the latent tensor instead of the indices returned by encode()",
    )
    args = parser.parse_args()
    device = torch.device(args.device)

    # Load model
    zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint)
    zflowae = zflowae.to(device).eval()

    # Prepare dataset and dataloader. The file list points at raw audio, so use
    # AudioDataset here; SafeTensorDataset is only for reading back encoded latents.
    with open(args.file_list, "r") as f:
        file_paths = [Path(line.strip()) for line in f if line.strip()]
    dataset = AudioDataset(file_paths, target_sr=args.sample_rate)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        collate_fn=lambda x: list(zip(*x)),
    )
    args.output_dir.mkdir(parents=True, exist_ok=True)
    # Inference loop
    for batch in tqdm(dataloader):
        try:
            encode_batch(
                zflowae, batch, device, args.output_dir, save_latent=args.save_latent
            )
        except Exception as e:
            print(f"❌ Batch failed: {e}")


if __name__ == "__main__":
    main()