# pardi-speech / codec / scripts / infer_wavvae_audiocite.py
# Author: Mehdi Lakbar
# Initial demo of Lina-speech (pardi-speech) — commit 56cfa73
import argparse
from pathlib import Path
import librosa
import soundfile as sf
import torch
from datasets import load_dataset, load_from_disk
from zcodec.models import WavVAE
def load_and_resample(path, target_sr):
    """Load an audio file and return it as a mono float tensor at ``target_sr``.

    Parameters
    ----------
    path : str or Path
        Path to an audio file readable by ``soundfile``.
    target_sr : int
        Desired sampling rate in Hz.

    Returns
    -------
    torch.Tensor
        Waveform of shape [1, T], dtype float32.
    """
    wav, sr = sf.read(str(path))
    # soundfile returns multi-channel audio as (frames, channels); downmix to
    # mono first so that librosa.resample operates on the time axis (it
    # resamples along the last axis, which would otherwise be the channels).
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    wav = torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]
    return wav
def main():
    """CLI entry point.

    Encodes every audio file referenced by a Hugging Face dataset column with
    a pretrained WavVAE, stores the latents in a new ``audio_z`` column, and
    saves the augmented dataset to ``<dataset>_with_latents``.
    """
    parser = argparse.ArgumentParser(
        description="Encode HF dataset audio with WavVAE using map() (non-batched)."
    )
    parser.add_argument("dataset", type=str, help="Path or HF hub ID of dataset")
    parser.add_argument("path_column", type=str, help="Column name with wav file paths")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
    )
    parser.add_argument(
        "--split", type=str, default=None, help="Dataset split (if loading from hub)"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument(
        "--num_proc", type=int, default=1, help="Number of processes for map()"
    )
    args = parser.parse_args()
    device = torch.device(args.device)

    # Load the model once; eval() disables training-only layers so encoding
    # is deterministic. Resample targets the model's own sampling rate.
    wavvae = WavVAE.from_pretrained_local(args.checkpoint).to(device).eval()
    target_sr = wavvae.sampling_rate

    # Load dataset: an existing local path wins over a hub ID.
    if Path(args.dataset).exists():
        ds = load_from_disk(args.dataset)
    else:
        ds = load_dataset(args.dataset, split=args.split or "train")

    # Keep only clips longer than one second.
    # NOTE(review): assumes the dataset has a "duration" column in seconds —
    # confirm against the dataset schema.
    ds = ds.filter(lambda duration: duration > 1.0, input_columns="duration")

    # Mapping function (non-batched).
    # NOTE(review): with --num_proc > 1 this closure (holding a possibly-CUDA
    # model) must be pickled into workers — confirm it survives that.
    @torch.no_grad()
    def encode_example(example):
        wav = load_and_resample(example[args.path_column], target_sr).to(device)
        example["audio_z"] = wavvae.encode(wav).cpu().numpy()
        return example

    # Apply map without batching.
    ds = ds.map(
        encode_example,
        num_proc=args.num_proc,
    )

    # Save dataset with the new column. BUGFIX: the original built the output
    # path as Path(...) + "_with_latents", which raises TypeError — pathlib
    # paths do not support `+`; concatenate the dataset string instead.
    ds.save_to_disk(args.dataset + "_with_latents")


if __name__ == "__main__":
    main()