# pardi-speech / codec / scripts / infer_zflowae.py
# Author: Mehdi Lakbar — "Initial demo of Lina-speech (pardi-speech)", commit 56cfa73
import argparse
from pathlib import Path
import librosa
import soundfile as sf
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from zcodec.models import ZFlowAutoEncoder
class AudioDataset(Dataset):
    """Dataset yielding ``(waveform, path)`` pairs loaded from audio files.

    Each waveform is resampled to ``target_sr`` when its native rate differs,
    and returned as a float32 tensor of shape ``[1, T]``.
    """

    def __init__(self, file_list, target_sr):
        # Paths to the audio files and the sample rate to resample to.
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        audio_path = self.paths[idx]
        data, native_sr = sf.read(str(audio_path))
        # Resample only when the file's native rate differs from the target.
        if native_sr != self.target_sr:
            data = librosa.resample(data, orig_sr=native_sr, target_sr=self.target_sr)
        # NOTE(review): assumes mono input — sf.read yields shape (T, C) for
        # multichannel files, which would break the [1, T] layout; confirm inputs.
        waveform = torch.as_tensor(data, dtype=torch.float32).unsqueeze(0)  # [1, T]
        return waveform, audio_path
class SafeTensorDataset(Dataset):
    """Dataset yielding ``(tensor, path)`` pairs read from safetensors files.

    ``__getitem__`` opens the safetensors file at ``file_paths[idx]`` and loads
    the full tensor stored under ``key``.

    Note: an earlier docstring described get_slice()-based shape inspection and
    random subsequence slicing; this implementation does neither — it always
    returns the whole tensor together with its source path.
    """

    def __init__(
        self,
        file_paths: list[str],
        key: str = "audio_z",
    ):
        # Paths to .st files and the tensor key to extract from each.
        self.file_paths = file_paths
        self.key = key

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, str]:
        path = self.file_paths[idx]
        # Load the full tensor stored under self.key (PyTorch framework mode).
        with safe_open(path, framework="pt") as f:
            tensor = f.get_tensor(self.key)
        return tensor, path
@torch.no_grad()
def encode_batch(model, batch, device, out_dir, save_latent=False):
    """Encode each waveform in a batch and write one safetensors file per item.

    Args:
        model: autoencoder exposing ``encode(wav) -> (latent, indices)``.
        batch: ``(waveforms, paths)`` pair as produced by the dataloader collate.
        device: torch device the model lives on.
        out_dir: destination directory; output name is ``<input stem>.st``.
        save_latent: when True store the continuous latent, otherwise the
            discrete indices.
    """
    waveforms, paths = batch
    for waveform, src in zip(waveforms, paths):
        latent, indices = model.encode(waveform.to(device))
        # Move the chosen representation back to CPU before serializing.
        chosen = (latent if save_latent else indices).cpu()
        target = out_dir / f"{src.stem}.st"
        save_file({"audio_z": chosen}, str(target))
def main():
    """CLI entry point: encode every listed file and save safetensors outputs.

    Reads a newline-separated file list, loads a ZFlowAutoEncoder checkpoint,
    and writes one ``<stem>.st`` file per input into ``output_dir``. Failures
    on individual batches are reported and skipped (best-effort run).
    """
    parser = argparse.ArgumentParser(
        description="Batch encode audio files with ZFlowAutoEncoder."
    )
    parser.add_argument(
        "file_list",
        type=Path,
        help="Text file listing paths to audio files (one per line)",
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to ZFlowAutoEncoder checkpoint directory"
    )
    parser.add_argument(
        "output_dir", type=Path, help="Directory to save Safetensors latents"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--num_workers", type=int, default=4, help="Num. workers")
    # New flag: encode_batch already supports save_latent but it was never
    # exposed on the CLI; default False preserves the previous behavior.
    parser.add_argument(
        "--save_latent",
        action="store_true",
        help="Save continuous latents instead of discrete indices",
    )
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load the pretrained autoencoder and switch to eval mode for inference.
    zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint)
    zflowae = zflowae.to(device).eval()

    # Read the file list, skipping blank lines.
    with open(args.file_list, "r") as f:
        file_paths = [Path(line.strip()) for line in f if line.strip()]

    # NOTE(review): the help text says "audio files" but items are loaded with
    # SafeTensorDataset (pre-extracted "audio_z" tensors), not AudioDataset —
    # confirm which input format is intended.
    dataset = SafeTensorDataset(file_paths)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        # Transpose [(tensor, path), ...] into (tensors, paths) without stacking.
        collate_fn=lambda x: list(zip(*x)),
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Inference loop: a single bad file should not abort the whole run.
    for batch in tqdm(dataloader):
        try:
            encode_batch(
                zflowae, batch, device, args.output_dir, save_latent=args.save_latent
            )
        except Exception as e:
            # Name the offending file(s) so failures are actionable.
            failed = ", ".join(str(p) for p in batch[1])
            print(f"❌ Batch failed ({failed}): {e}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()