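# The docstring below summarizes the script and gives an illustrative invocation;
# the script filename and paths shown are assumptions, not taken from the repo.
"""Batch-encode a list of audio files into WavVAE latents saved as safetensors.

Example (illustrative):

    python encode_audio.py files.txt /path/to/wavvae_checkpoint latents/ --device cuda
"""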
import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from zcodec.models import WavVAE

class AudioDataset(Dataset):
    """Loads audio files and resamples them to the model's sampling rate."""

    def __init__(self, file_list, target_sr):
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        # Downmix multi-channel audio so the tensor is always [1, T].
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.from_numpy(wav).unsqueeze(0).float()  # shape: [1, T]
        return wav, path

@torch.inference_mode()
def encode_batch(model, batch, device, out_dir):
    wavs, paths = batch
    for wav, path in zip(wavs, paths):
        wav = wav.to(device)
        latent = model.encode(wav).cpu()
        out_path = out_dir / (path.stem + ".st")
        # safetensors requires contiguous tensors.
        save_file({"audio_z": latent.contiguous()}, str(out_path))

def main():
    parser = argparse.ArgumentParser(
        description="Batch encode audio files with WavVAE."
    )
    parser.add_argument(
        "file_list",
        type=Path,
        help="Text file listing paths to audio files (one per line)",
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
    )
    parser.add_argument(
        "output_dir", type=Path, help="Directory to save safetensors latents"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument(
        "--num_workers", type=int, default=3, help="Number of DataLoader workers"
    )
    args = parser.parse_args()
    device = torch.device(args.device)

    # Load model
    wavvae = WavVAE.from_pretrained_local(args.checkpoint)
    wavvae = wavvae.to(device).eval()
    target_sr = wavvae.sampling_rate

    # Prepare dataset and dataloader
    with open(args.file_list, "r") as f:
        file_paths = [Path(line.strip()) for line in f if line.strip()]
    dataset = AudioDataset(file_paths, target_sr)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        # Keep (wav, path) pairs as Python lists rather than stacking tensors,
        # since clips have different lengths.
        collate_fn=lambda x: list(zip(*x)),
    )
    args.output_dir.mkdir(parents=True, exist_ok=True)
    # Inference loop
    for batch in tqdm(dataloader):
        try:
            encode_batch(wavvae, batch, device, args.output_dir)
        except Exception as e:
            print(f"Batch failed: {e}")


if __name__ == "__main__":
    main()