# pardi-speech / codec/scripts/infer_zflowae_hf.py
# Author: Mehdi Lakbar
# Initial demo of Lina-speech (pardi-speech)
# Commit: 56cfa73
import argparse
from pathlib import Path
import librosa
import soundfile as sf
import torch
import torchaudio
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from datasets import Audio, load_dataset, load_from_disk
from zcodec.models import WavVAE, ZFlowAutoEncoder
class AudioDataset(Dataset):
    """Map-style dataset over a list of audio file paths.

    Each item is read from disk with soundfile, resampled to ``target_sr``
    when the native rate differs, and returned as a ``[1, T]`` float tensor
    together with the originating path.
    """

    def __init__(self, file_list, target_sr):
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        audio_path = self.paths[idx]
        samples, native_sr = sf.read(str(audio_path))
        # Resample lazily, only for files whose rate mismatches the target.
        if native_sr != self.target_sr:
            samples = librosa.resample(
                samples, orig_sr=native_sr, target_sr=self.target_sr
            )
        # assumes mono input — sf.read yields [T, C] for multichannel; TODO confirm
        waveform = torch.tensor(samples).unsqueeze(0).float()
        return waveform, audio_path
@torch.no_grad()
def encode_batch(model, batch, device, out_dir, save_latent=True):
    """Encode each waveform in a batch and save one safetensors file per item.

    Args:
        model: codec exposing ``encode(wav) -> (latent, indices)``.
        batch: ``(wavs, paths)`` pair as yielded by the DataLoader.
        device: torch device the model lives on.
        out_dir: output directory ``Path`` (must already exist).
        save_latent: when True store the continuous latent; when False store
            the discrete codebook indices instead.
    """
    wavs, paths = batch
    for wav, path in zip(wavs, paths):
        wav = wav.to(device)
        latent, indices = model.encode(wav)
        # BUG FIX: original wrote ``save_latent = indices.cpu()`` in the else
        # branch, rebinding the flag and leaving ``to_save`` undefined
        # (NameError) whenever save_latent was False.
        if save_latent:
            to_save = latent.cpu()
        else:
            to_save = indices.cpu()
        out_path = out_dir / (path.stem + ".st")
        save_file({"audio_z": to_save}, str(out_path))
def main():
    """CLI entry point: encode a HF audio dataset into ZFlowAutoEncoder latents.

    Loads a dataset (from the HF hub or a local ``save_to_disk`` directory),
    maps each item's ``--in_column`` through the frozen autoencoder on
    ``--device``, and writes the resulting dataset to ``output_dataset``.
    """
    parser = argparse.ArgumentParser(
        description="Batch encode audio files with WavVAE."
    )
    parser.add_argument(
        "input_dataset",
        type=Path,
        help="Text file listing paths to audio files (one per line)",
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to zflowae checkpoint directory"
    )
    parser.add_argument(
        "output_dataset", type=Path, help="Directory to save Safetensors latents"
    )
    parser.add_argument("--in_column", type=str, default="audio_z")
    parser.add_argument("--out_column", type=str, default="audio_latent")
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--split", type=str, default="all")
    parser.add_argument(
        "--num_workers", type=int, default=1, help="Number of DataLoader workers"
    )
    parser.add_argument("--from_file", action="store_true")
    parser.add_argument("--from_hub", action="store_true")
    parser.add_argument("--file_prefix", type=str, default=None)
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load the autoencoder from a local checkpoint and freeze it for inference.
    zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint)
    zflowae = zflowae.to(device).eval()

    # Prepare dataset: either pulled from the hub or restored from disk.
    if args.from_hub:
        dataset = load_dataset(str(args.input_dataset), args.split)
    else:
        # NOTE(review): recent `datasets` releases' load_from_disk has no
        # positional split parameter — confirm this second argument is valid
        # for the pinned version.
        dataset = load_from_disk(str(args.input_dataset), args.split)

    if args.from_file:
        # BUG FIX: ``raise NotImplemented`` raises TypeError (NotImplemented
        # is not an exception class); NotImplementedError is the correct one.
        raise NotImplementedError("--from_file mode is not implemented yet")

        # Unreachable sketch kept from the original for future work.
        def map_fn(audio_file_path):
            if args.file_prefix is not None:
                audio_file_path = args.file_prefix + "/" + audio_file_path
            wav, sr = torchaudio.load(audio_file_path)
            wav = wav.mean(dim=0, keepdim=True)  # downmix to mono
            with torch.inference_mode():
                # NOTE(review): encode() returns (latent, indices) elsewhere
                # in this file — this sketch does not unpack the tuple.
                latent = zflowae.encode(wav.to(device))
            return {"audio_z": latent}

        dataset = dataset.map(map_fn, input_columns=args.in_column)
    else:
        dataset = dataset.with_format(
            "torch",
            columns=args.in_column,
        )

        def map_fn(audio):
            # Encode one precomputed ``audio_z`` tensor into its latent.
            with torch.inference_mode():
                audio_z = audio.to(device)
                latent, _ = zflowae.encode(audio_z)
            return {args.out_column: latent}

        dataset = dataset.map(
            map_fn, input_columns=args.in_column, remove_columns=args.in_column
        )

    dataset.save_to_disk(args.output_dataset)
# Script entry point: run the batch-encoding CLI.
if __name__ == "__main__":
    main()