# pardi-speech/codec/scripts/infer_wavvae.py
# Mehdi Lakbar — commit 56cfa73: "Initial demo of Lina-speech (pardi-speech)"
import argparse
from pathlib import Path
import librosa
import soundfile as sf
import torch
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from zcodec.models import WavVAE
class AudioDataset(Dataset):
    """Map-style dataset yielding ``(waveform, path)`` pairs for a list of audio files.

    Each file is read with soundfile, downmixed to mono if it has several
    channels, resampled to ``target_sr`` when its native rate differs, and
    returned as a float32 tensor of shape ``[1, T]``.
    """

    def __init__(self, file_list, target_sr):
        """
        Args:
            file_list: sequence of audio file paths; each item is returned
                unchanged as the second element of its sample.
            target_sr: sampling rate every waveform is resampled to.
        """
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _to_mono(wav):
        """Average the channels of a ``(T, C)`` array; return 1-D input unchanged.

        ``soundfile.read`` returns multi-channel audio as ``(frames, channels)``.
        Without downmixing, ``librosa.resample`` (which operates on the last
        axis) would resample across channels instead of time, and the final
        tensor would not have the expected ``[1, T]`` shape.
        """
        if wav.ndim > 1:
            return wav.mean(axis=1)
        return wav

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        # NOTE(review): downmixing assumes the encoder expects mono input, as
        # implied by the [1, T] shape below — confirm against WavVAE.
        wav = self._to_mono(wav)
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]
        return wav, path
@torch.no_grad()
def encode_batch(model, batch, device, out_dir):
    """Encode each waveform in *batch* and save its latent to a safetensors file.

    Runs under ``torch.no_grad()`` so no autograd graph is recorded.

    Args:
        model: object whose ``encode`` method takes a waveform tensor on
            *device*; its result is moved to CPU and saved. Assumes
            ``encode`` returns a single tensor — TODO confirm for WavVAE.
        batch: ``(wavs, paths)`` pair of parallel sequences, e.g. as produced
            by the DataLoader collate function.
        device: device each waveform is moved to before encoding.
        out_dir: ``pathlib.Path`` directory; ``<path.stem>.st`` is written
            there with the latent stored under the key ``"audio_z"``.
    """
    wavs, paths = batch
    for wav, path in zip(wavs, paths):
        latent = model.encode(wav.to(device))
        # safetensors' save_file rejects non-contiguous tensors (e.g. a
        # transposed view); .contiguous() is a no-op when already contiguous.
        latent = latent.cpu().contiguous()
        out_path = out_dir / (path.stem + ".st")
        save_file({"audio_z": latent}, str(out_path))
def collate_as_columns(batch):
    """Transpose a list of ``(wav, path)`` samples into ``[wavs, paths]``.

    Used as the DataLoader ``collate_fn``. It lives at module level rather
    than as a lambda because worker processes started with the ``spawn``
    method (the default on macOS and Windows) must pickle the collate
    function, and lambdas cannot be pickled.
    """
    return list(zip(*batch))


def main():
    """Encode every audio file listed in a text file into WavVAE latents.

    Positional CLI arguments: the file list (one path per line), the WavVAE
    checkpoint directory, and the output directory. Errors raised while
    encoding a batch are reported (with the offending path) and skipped; the
    number of failed batches is printed at the end.
    """
    from collections import Counter

    parser = argparse.ArgumentParser(
        description="Batch encode audio files with WavVAE."
    )
    parser.add_argument(
        "file_list",
        type=Path,
        help="Text file listing paths to audio files (one per line)",
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
    )
    parser.add_argument(
        "output_dir", type=Path, help="Directory to save Safetensors latents"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument(
        "--num_workers", type=int, default=3, help="Number of DataLoader workers"
    )
    args = parser.parse_args()
    device = torch.device(args.device)

    # Load model
    wavvae = WavVAE.from_pretrained_local(args.checkpoint)
    wavvae = wavvae.to(device).eval()
    target_sr = wavvae.sampling_rate

    # Prepare dataset and dataloader. Read the list as UTF-8 explicitly so
    # non-ASCII paths do not depend on the platform's locale encoding.
    with open(args.file_list, "r", encoding="utf-8") as f:
        file_paths = [Path(line.strip()) for line in f if line.strip()]

    # Output files are named after the input stem only, so inputs sharing a
    # stem (e.g. a/x.wav and b/x.wav) would silently overwrite each other.
    stem_counts = Counter(p.stem for p in file_paths)
    clashing = sorted(stem for stem, count in stem_counts.items() if count > 1)
    if clashing:
        print(
            f"⚠️ {len(clashing)} file stem(s) appear more than once; "
            f"their latents will overwrite each other: {clashing[:10]}"
        )

    dataset = AudioDataset(file_paths, target_sr)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        collate_fn=collate_as_columns,
    )
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Inference loop. Only encoding/saving errors are caught here; errors
    # raised while loading audio inside the DataLoader still abort the run.
    failures = 0
    for batch in tqdm(dataloader):
        try:
            encode_batch(wavvae, batch, device, args.output_dir)
        except Exception as e:
            failures += 1
            names = ", ".join(str(p) for p in batch[1])
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(f"❌ Batch failed ({names}): {e}")
    if failures:
        print(f"{failures} of {len(dataloader)} batch(es) failed.")


if __name__ == "__main__":
    main()