File size: 3,493 Bytes
56cfa73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from zcodec.models import ZFlowAutoEncoder


class AudioDataset(Dataset):
    """Dataset that loads audio files as mono waveforms at a fixed sample rate.

    Each item is a ``(wav, path)`` pair: ``wav`` is a float32 tensor of shape
    ``[1, T]`` and ``path`` is the entry of ``file_list`` it was read from.
    """

    def __init__(self, file_list, target_sr):
        """Store the audio paths and the sample rate every file is converted to."""
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Read the file at ``idx``, downmix/resample it and return ``(wav, path)``."""
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        # soundfile returns a (frames, channels) array for multi-channel files.
        # Downmix to mono first so that resampling runs along the time axis and
        # the result really has the [1, T] shape documented below.
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]
        return wav, path


class SafeTensorDataset(Dataset):
    """
    On __getitem__, opens the safetensor, uses get_slice() to inspect shape,
    then either drops too-short files (return None) or returns a random subsequence slice.
    """

    def __init__(
        self,
        file_paths: list[str],
        key: str = "audio_z",
    ):
        self.file_paths = file_paths
        self.key = key

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> torch.Tensor | None:
        path = self.file_paths[idx]
        # open file, get a slice wrapper for full tensor
        with safe_open(path, framework="pt") as f:
            tensor = f.get_tensor(self.key)

        return tensor, path


@torch.no_grad()
def encode_batch(model, batch, device, out_dir, save_latent=False):
    """Encode every item of ``batch`` with ``model`` and write one file per item.

    Args:
        model: Object whose ``encode(x)`` returns a ``(latent, indices)`` pair.
        batch: Pair ``(inputs, paths)`` of equally long sequences. Each input
            is a tensor handed to ``model.encode``; each path names the source
            file the input came from (``str`` or ``Path``).
        device: Device the inputs are moved to before encoding.
        out_dir: Directory (``str`` or ``Path``) that receives
            ``<source stem>.st`` for every item.
        save_latent: Store the latent when true, otherwise the indices.

    The chosen tensor is saved under the ``"audio_z"`` key. Inputs that share
    a file stem map to the same output file, so a later one overwrites an
    earlier one.
    """
    wavs, paths = batch
    out_dir = Path(out_dir)
    for wav, path in zip(wavs, paths):
        latent, indices = model.encode(wav.to(device))
        chosen = latent if save_latent else indices
        # save_file refuses non-contiguous tensors; .contiguous() is a no-op
        # when the tensor is already laid out contiguously.
        to_save = chosen.cpu().contiguous()
        # Path(...) lets plain-string paths work as well as Path objects.
        out_path = out_dir / (Path(path).stem + ".st")
        save_file({"audio_z": to_save}, str(out_path))


def collate_samples(samples):
    """Transpose ``[(item, path), ...]`` into ``[(item, ...), (path, ...)]``.

    Kept at module level instead of being a lambda so the DataLoader can
    pickle it for worker processes that are started with the ``spawn`` method.
    """
    return list(zip(*samples))


def main():
    """Command-line entry point: encode every listed file with ZFlowAutoEncoder.

    Reads the file list (one path per line, blank lines skipped), loads each
    file through ``SafeTensorDataset``, encodes it with the model restored
    from ``checkpoint`` and writes the result into ``output_dir``. A failure
    in one batch is reported on stdout and the loop continues with the next.
    """
    parser = argparse.ArgumentParser(
        description="Batch encode audio files with ZFlowAutoEncoder."
    )
    parser.add_argument(
        "file_list",
        type=Path,
        help="Text file listing paths to audio files (one per line)",
    )
    parser.add_argument(
        "checkpoint", type=Path, help="Path to ZFlowAutoEncoder checkpoint directory"
    )
    parser.add_argument(
        "output_dir", type=Path, help="Directory to save Safetensors latents"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--num_workers", type=int, default=4, help="Num. workers")
    # Exposes the existing encode_batch option; off by default, so the output
    # is unchanged unless the flag is given.
    parser.add_argument(
        "--save_latent",
        action="store_true",
        help="Save latents instead of indices",
    )
    args = parser.parse_args()
    device = torch.device(args.device)

    # Load model
    zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint)
    zflowae = zflowae.to(device).eval()

    # Prepare dataset and dataloader
    with open(args.file_list, "r") as f:
        file_paths = [Path(line.strip()) for line in f if line.strip()]
    dataset = SafeTensorDataset(file_paths)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        collate_fn=collate_samples,
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Inference loop
    for batch in tqdm(dataloader):
        # Deliberately best-effort: one unreadable or failing file must not
        # abort the whole run.
        try:
            encode_batch(
                zflowae,
                batch,
                device,
                args.output_dir,
                save_latent=args.save_latent,
            )
        except Exception as e:
            print(f"❌ Batch failed: {e}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()