"""Batch-encode audio into WavVAE latents with Hugging Face `datasets`.

The input column may hold decoded audio (default), a single audio file path
per row (--from_file), or a list of file paths per row (--from_files).
"""

import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch
import torchaudio
from torchaudio.functional import resample
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from datasets import Audio, load_dataset, load_from_disk
from zcodec.models import WavVAE


# Standalone file-list helpers; not used by the `datasets` pipeline in main().
class AudioDataset(Dataset):
    def __init__(self, file_list, target_sr):
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]
        return wav, path


@torch.no_grad()
def encode_batch(model, batch, device, out_dir):
    """Encode a batch of (wav, path) pairs and write one .st file per input."""
    wavs, paths = batch
    for wav, path in zip(wavs, paths):
        wav = wav.to(device)
        latent = model.encode(wav).cpu()
        out_path = out_dir / (path.stem + ".st")
        save_file({"audio_z": latent}, str(out_path))


def main():
    parser = argparse.ArgumentParser(
        description="Batch encode audio files with WavVAE."
    )
    parser.add_argument(
        "input_dataset",
        type=Path,
        help="Dataset to encode: a hub id (--from_hub) or a save_to_disk directory",
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
    )
    parser.add_argument(
        "output_dataset", type=Path, help="Directory to save the encoded dataset"
    )
    parser.add_argument("--in_column", type=str, default="audio")
    parser.add_argument("--out_column", type=str, default="audio_z")
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--split", type=str, default=None)
    parser.add_argument(
        "--num_workers", type=int, default=1, help="Number of DataLoader workers"
    )
    parser.add_argument(
        "--num_shards", type=int, default=1, help="Total number of dataset shards"
    )
    parser.add_argument(
        "--shard_index", type=int, default=0, help="Index of the shard to process"
    )
    parser.add_argument(
        "--from_file",
        action="store_true",
        help="Treat the input column as a single audio file path per row",
    )
    parser.add_argument(
        "--from_files",
        action="store_true",
        help="Treat the input column as a list of audio file paths per row",
    )
    parser.add_argument(
        "--no_resample",
        action="store_true",
        help="Skip resampling loaded audio to the model sampling rate",
    )
    parser.add_argument("--from_hub", action="store_true")
    parser.add_argument("--file_prefix", type=str, default=None)
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load model
    wavvae = WavVAE.from_pretrained_local(args.checkpoint)
    wavvae = wavvae.to(device).eval()
    target_sr = wavvae.sampling_rate

    # Prepare dataset
    if args.from_hub:
        dataset = load_dataset(str(args.input_dataset), split=args.split)
    else:
        dataset = load_from_disk(str(args.input_dataset))
        if args.split is not None:
            dataset = dataset[args.split]

    if args.num_shards > 1:
        dataset = dataset.shard(num_shards=args.num_shards, index=args.shard_index)

    if args.from_file:

        def map_fn(audio_file_path):
            if args.file_prefix is not None:
                audio_file_path = args.file_prefix + "/" + audio_file_path
            wav, sr = torchaudio.load(audio_file_path)
            wav = wav.mean(dim=0, keepdim=True)  # downmix to mono
            if not args.no_resample:
                wav = resample(wav, sr, target_sr)
            with torch.inference_mode():
                latent = wavvae.encode(wav.to(device))
            return {args.out_column: latent.cpu()}

        dataset = dataset.map(map_fn, input_columns=args.in_column)
    elif args.from_files:

        def map_fn(audio_file_paths):
            if args.file_prefix is not None:
                audio_file_paths = [args.file_prefix + "/" + x for x in audio_file_paths]
            # Load every file, downmix each to mono, then concatenate along
            # time. Assumes all files in a row share the first file's rate.
            wav, sr = torchaudio.load(audio_file_paths[0])
            wavs = [wav.mean(dim=0, keepdim=True)]
            wavs += [
                torchaudio.load(x)[0].mean(dim=0, keepdim=True)
                for x in audio_file_paths[1:]
            ]
            wav = torch.cat(wavs, dim=1)
            if not args.no_resample:
                wav = resample(wav, sr, target_sr)
            with torch.inference_mode():
                latent = wavvae.encode(wav.to(device))
            return {args.out_column: latent.cpu()}

        dataset = dataset.map(map_fn, input_columns=args.in_column)
    else:
        # Decoded-audio column: let `datasets` resample while decoding.
        dataset = dataset.cast_column(args.in_column, Audio(sampling_rate=target_sr))
        dataset = dataset.with_format("torch", columns=[args.in_column])

        def map_fn(audio):
            with torch.inference_mode():
                wav = audio["array"].unsqueeze(0).float().to(device)  # [1, T]
                latent = wavvae.encode(wav)
            return {args.out_column: latent.cpu()}

        dataset = dataset.map(
            map_fn, input_columns=args.in_column, remove_columns=args.in_column
        )

    dataset.save_to_disk(args.output_dataset)


if __name__ == "__main__":
    main()
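

# Example invocations, assuming the flags defined above; the script name,
# dataset ids, and paths below are hypothetical:
#
#   # Decoded-audio column from the hub, train split only:
#   python encode_wavvae.py my-org/speech ckpts/wavvae out/speech_z \
#       --from_hub --split train
#
#   # File-path column from an on-disk dataset, processing shard 3 of 8:
#   python encode_wavvae.py data/speech ckpts/wavvae out/speech_z.3 \
#       --from_file --file_prefix /data/audio --num_shards 8 --shard_index 3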