# Hugging Face Spaces badge residue ("Running on Zero" / ZeroGPU) — not code.
| import itertools | |
| import random | |
| import time | |
| from dataclasses import dataclass | |
| from functools import partial | |
| from pathlib import Path | |
| import numpy as np | |
| import pytorch_lightning as ptl | |
| import torch | |
| import torchaudio | |
| from safetensors.torch import safe_open | |
| from sklearn.model_selection import train_test_split | |
| from torch.nn.utils.rnn import pad_sequence | |
| from torch.utils.data import DataLoader, Dataset | |
| from datasets import load_dataset, load_from_disk | |
@dataclass
class WavVAEDataConfig:
    """Configuration record for one WavVAE dataloader split.

    The ``@dataclass`` decorator was missing: the class declared only bare
    annotations, so it had no generated ``__init__`` and instances carried
    no fields, although the module below clearly uses it as a plain config
    record (``dataclass`` is already imported at the top of the file).

    Attributes (all required, positional):
        filelist_path  -- path to a text file with one audio path per line
        sampling_rate  -- target sampling rate audio is resampled to
        num_samples    -- fixed number of samples per example
        batch_size     -- dataloader batch size
        num_workers    -- dataloader worker count
    """

    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int
class WavVAEDataModule(ptl.LightningDataModule):
    """Lightning data module that wires WavVAE datasets to dataloaders.

    Holds one config per split; each dataloader is built lazily from its
    config when Lightning requests it.
    """

    def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloder(self, cfg: WavVAEDataConfig, train: bool):
        # Only the training loader shuffles; validation keeps file order
        # so runs are comparable. Note: name keeps the original spelling
        # for compatibility with any external callers.
        return DataLoader(
            WavVAEDataset(cfg, train=train),
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=train,
            pin_memory=True,
        )

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.val_config, train=False)
class WavVAEDataset(Dataset):
    """Fixed-length audio clip dataset driven by a filelist.

    Every item is: loaded from disk, downmixed to mono, loudness-normalized
    via sox ``norm``, resampled to the configured rate, and then cut or
    tile-padded to exactly ``num_samples`` samples.
    """

    def __init__(self, cfg: WavVAEDataConfig, train: bool):
        with open(cfg.filelist_path) as listing:
            self.filelist = listing.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        waveform, source_rate = torchaudio.load(self.filelist[index])

        # Downmix multi-channel audio to mono.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Random gain in dB when training (augmentation), fixed -3 dB for
        # validation so results are deterministic.
        gain = np.random.uniform(-1, -6) if self.train else -3
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform, source_rate, [["norm", f"{gain:.2f}"]]
        )

        if source_rate != self.sampling_rate:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=source_rate, new_freq=self.sampling_rate
            )

        length = waveform.size(-1)
        if length < self.num_samples:
            # Tile the clip until the deficit is covered, then trim.
            deficit = self.num_samples - length
            tiled = waveform.repeat(1, 1 + deficit // length)
            waveform = torch.cat((waveform, tiled[:, :deficit]), dim=1)
        elif self.train:
            # Random crop for training.
            offset = np.random.randint(low=0, high=length - self.num_samples + 1)
            waveform = waveform[:, offset : offset + self.num_samples]
        else:
            # During validation, always take the first segment for determinism.
            waveform = waveform[:, : self.num_samples]

        # Drop the channel dim: return a 1-D sample tensor.
        return waveform[0]
| def pad_tensor_list_raw( | |
| tensor_list: list[tuple[torch.Tensor, torch.Tensor]], pad_idx: int = 0 | |
| ) -> dict[str, torch.Tensor | None]: | |
| audio, hubert_maybe = zip(*tensor_list) | |
| audio = torch.cat(audio, dim=0) | |
| if hubert_maybe[0] is not None: | |
| hubert_maybe = torch.stack(hubert_maybe, dim=0) | |
| else: | |
| hubert_maybe = None | |
| return {"audio_z": audio, "hubert": hubert_maybe} | |
class SafeTensorDataset(Dataset):
    """Dataset over raw ``.safetensors`` files, sliced lazily per item.

    On ``__getitem__``, the safetensor is opened and ``get_slice()`` is used
    to inspect the stored shape without materializing the full tensor. Files
    shorter than ``min_length`` yield ``None`` (the collate_fn filters those
    out); otherwise a random subsequence of ``subseq_length`` (or the full
    sequence) is returned. When ``hubert_path`` is set, a time-aligned HuBERT
    feature slice is returned alongside — HuBERT frames appear to run at half
    the main frame rate (hence the ``// 2`` indexing and even start offsets);
    TODO confirm against the feature-extraction pipeline.
    """

    def __init__(
        self,
        file_paths: list[str],
        key: str,
        hubert_path: str | None = None,
        hubert_key: str = "layer_9",
        min_length: int = 1,
        subseq_length: int | None = None,
    ):
        self.file_paths = file_paths
        self.key = key
        self.min_length = min_length
        self.subseq_length = subseq_length
        self.hubert_path = hubert_path
        self.hubert_key = hubert_key

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor | None] | None:
        path = self.file_paths[idx]
        # Open file and get a slice wrapper; nothing is read until indexed.
        with safe_open(path, framework="pt") as f:
            tensor_slice = f.get_slice(self.key)
            Q, N, D = tensor_slice.get_shape()  # full shape [Q, N, D]; N = time axis
            # Drop too-short files (collate_fn filters the Nones).
            if N < self.min_length:
                return None
            L = self.subseq_length or N
            if L < N:
                # Valid starts are 0..N-L inclusive; randint's high bound is
                # exclusive. (Fixes off-by-one: was max(1, N - L - 1), which
                # could never pick the last start positions.)
                start = torch.randint(0, N - L + 1, ()).item()
                # Floor to an even offset so the half-rate HuBERT slice
                # below stays frame-aligned.
                start -= start % 2
                # Indexing the slice wrapper yields a torch.Tensor [Q, L, D].
                seq = tensor_slice[:, start : start + L]
            else:
                # Full length.
                start = 0
                seq = tensor_slice[:, :]
        if self.hubert_path is not None:
            # HuBERT features live in a sibling file of the same name.
            hubert_file = Path(self.hubert_path) / Path(path).name
            with safe_open(hubert_file, framework="pt") as hubert_f:
                hubert_slice = hubert_f.get_slice(self.hubert_key)
                # Half-rate alignment: main offset/length halved.
                seq_hubert = hubert_slice[start // 2 : start // 2 + L // 2]
            return (seq, seq_hubert)
        return (seq, None)
class SafeTensorDataModule(ptl.LightningDataModule):
    """LightningDataModule over raw ``.safetensors`` file lists.

    Builds :class:`SafeTensorDataset` instances (which use ``get_slice``
    inside) for the train and val splits. Validation paths come from
    ``val_file_list`` when given; otherwise the train list is randomly
    split with ``val_split``.
    """

    def __init__(
        self,
        train_file_list: str,
        val_file_list: str | None = None,
        hubert_path: str | None = None,
        key: str = "audio_z",
        hubert_key: str = "layer_9",
        val_split: float = 0.1,
        batch_size: int = 32,
        num_workers: int = 4,
        shuffle: bool = True,
        seed: int = 1234,
        min_length: int = 1,
        subseq_length: int | None = None,
    ):
        super().__init__()
        self.train_file_list = train_file_list
        self.val_file_list = val_file_list
        self.hubert_path = hubert_path
        self.key = key
        # Fix: hubert_key was accepted but never stored, so a non-default
        # value was silently ignored when building the datasets.
        self.hubert_key = hubert_key
        self.val_split = val_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.seed = seed
        self.min_length = min_length
        self.subseq_length = subseq_length

    def setup(self, stage=None):
        """Read file lists and construct train/val datasets."""
        with open(self.train_file_list, "r") as f:
            train_paths = [line.strip() for line in f if line.strip()]

        val_paths = None
        if self.val_file_list is not None:
            # Fix: this previously re-opened train_file_list, so validation
            # silently ran on the training files.
            with open(self.val_file_list, "r") as f:
                val_paths = [line.strip() for line in f if line.strip()]

        # Split train into train/val only when no explicit val list was given.
        if (
            isinstance(self.val_split, float)
            and 0 < self.val_split < 1
            and val_paths is None
        ):
            train_paths, val_paths = train_test_split(
                train_paths, test_size=self.val_split, random_state=self.seed
            )

        self.train_ds = SafeTensorDataset(
            train_paths,
            key=self.key,
            min_length=self.min_length,
            subseq_length=self.subseq_length,
            hubert_path=self.hubert_path,
            hubert_key=self.hubert_key,
        )
        # Fix: pass hubert_path/hubert_key here too so validation batches have
        # the same structure as training batches (val previously never got
        # hubert features).
        self.val_ds = SafeTensorDataset(
            val_paths,
            key=self.key,
            min_length=self.min_length,
            subseq_length=self.subseq_length,
            hubert_path=self.hubert_path,
            hubert_key=self.hubert_key,
        )

    def _collate_fn(
        self, batch: list[tuple[torch.Tensor, torch.Tensor | None] | None]
    ) -> dict[str, torch.Tensor | None]:
        # Drop None items (too-short files rejected by the dataset), then
        # batch the rest. Annotation fixed: pad_tensor_list_raw returns a
        # dict, not a tuple.
        kept = [item for item in batch if item is not None]
        return pad_tensor_list_raw(kept, pad_idx=0)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )