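The script below loads a Hugging Face dataset (from a local directory or the hub), resamples each referenced wav file to the WavVAE's sampling rate, encodes it to a latent with a non-batched `map()`, and saves the dataset back to disk with a new `audio_z` column.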
```python
import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch
from datasets import load_dataset, load_from_disk

from zcodec.models import WavVAE


def load_and_resample(path, target_sr):
    """Read a wav file and resample it to the model's sampling rate."""
    wav, sr = sf.read(str(path))
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    wav = torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]
    return wav


def main():
    parser = argparse.ArgumentParser(
        description="Encode HF dataset audio with WavVAE using map() (non-batched)."
    )
    parser.add_argument("dataset", type=str, help="Path or HF hub ID of dataset")
    parser.add_argument("path_column", type=str, help="Column name with wav file paths")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
    )
    parser.add_argument(
        "--split", type=str, default=None, help="Dataset split (if loading from hub)"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument(
        "--num_proc", type=int, default=1, help="Number of processes for map()"
    )
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load model
    wavvae = WavVAE.from_pretrained_local(args.checkpoint).to(device).eval()
    target_sr = wavvae.sampling_rate

    # Load dataset: local directory first, otherwise treat it as a hub ID
    if Path(args.dataset).exists():
        ds = load_from_disk(args.dataset)
    else:
        ds = load_dataset(args.dataset, split=args.split or "train")

    # Keep only clips longer than one second (assumes a "duration" column)
    ds = ds.filter(lambda x: x > 1.0, input_columns="duration")

    # Mapping function (non-batched): encode one example at a time
    def encode_example(example):
        wav = load_and_resample(example[args.path_column], target_sr).to(device)
        with torch.no_grad():  # no gradients needed for inference
            latent = wavvae.encode(wav).cpu().numpy()
        example["audio_z"] = latent
        return example

    # Apply map without batching
    ds = ds.map(
        encode_example,
        num_proc=args.num_proc,
    )

    # Save dataset with new column (args.dataset is a str; Path + str would raise)
    ds.save_to_disk(args.dataset + "_with_latents")


if __name__ == "__main__":
    main()
```
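Once the script has run, the encoded dataset can be read back with `load_from_disk`. A minimal sketch, assuming the input was a local dataset directory named `my_dataset` (so the output landed in `my_dataset_with_latents`; both names are illustrative):

```python
from datasets import load_from_disk

# The script writes the result next to the input name with a "_with_latents" suffix
ds = load_from_disk("my_dataset_with_latents")

# Each row now carries the WavVAE latent under "audio_z"
print(ds.column_names)
print(len(ds[0]["audio_z"]))  # stored as nested lists; shape depends on the latent layout
```

One caveat worth noting: `datasets.map()` with `num_proc > 1` forks worker processes, and a model already moved to a CUDA device generally does not survive forking, so multi-process encoding is safest with `--device cpu`.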