from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pytorch_lightning as ptl
import torch
import torchaudio
from safetensors.torch import safe_open
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset


@dataclass
class WavVAEDataConfig:
    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int


class WavVAEDataModule(ptl.LightningDataModule):
    def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloader(self, cfg: WavVAEDataConfig, train: bool) -> DataLoader:
        dataset = WavVAEDataset(cfg, train=train)
        return DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=train,
            pin_memory=True,
        )

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.val_config, train=False)


class WavVAEDataset(Dataset):
    def __init__(self, cfg: WavVAEDataConfig, train: bool):
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        audio_path = self.filelist[index]
        y, sr = torchaudio.load(audio_path)
        if y.size(0) > 1:
            # mix to mono
            y = y.mean(dim=0, keepdim=True)
        # Random gain in [-6, -1] dBFS for training; fixed -3 dBFS for validation.
        gain = np.random.uniform(-6, -1) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(
            y, sr, [["norm", f"{gain:.2f}"]]
        )
        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(
                y, orig_freq=sr, new_freq=self.sampling_rate
            )
        if y.size(-1) < self.num_samples:
            # Tile the clip with copies of itself until it covers num_samples.
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, always take the first segment for determinism.
            y = y[:, : self.num_samples]
        return y[0]


def pad_tensor_list_raw(
    tensor_list: list[tuple[torch.Tensor, torch.Tensor | None]],
) -> dict[str, torch.Tensor | None]:
    """Collate (audio_z, hubert) pairs into batched tensors.

    Despite the name, no padding happens: every item is assumed to have the
    same subsequence length, so audio latents are concatenated along dim 0
    and HuBERT features (if present) are stacked.
    """
    audio, hubert_maybe = zip(*tensor_list)
    audio = torch.cat(audio, dim=0)
    if hubert_maybe[0] is not None:
        hubert_maybe = torch.stack(hubert_maybe, dim=0)
    else:
        hubert_maybe = None
    return {"audio_z": audio, "hubert": hubert_maybe}
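# A minimal usage sketch for WavVAEDataModule. The file-list paths, sampling
# rate, and batch settings below are illustrative placeholders, not values
# defined by this module:
#
#     train_cfg = WavVAEDataConfig(
#         filelist_path="filelists/train.txt",  # placeholder path
#         sampling_rate=24000,
#         num_samples=24000,  # 1 s crops at 24 kHz
#         batch_size=16,
#         num_workers=4,
#     )
#     val_cfg = WavVAEDataConfig(
#         filelist_path="filelists/val.txt",  # placeholder path
#         sampling_rate=24000,
#         num_samples=24000,
#         batch_size=16,
#         num_workers=4,
#     )
#     datamodule = WavVAEDataModule(train_cfg, val_cfg)
#     batch = next(iter(datamodule.train_dataloader()))
#     # batch has shape [batch_size, num_samples]: WavVAEDataset returns 1-D
#     # clips, which the default collate stacks.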
""" def __init__( self, file_paths: list[str], key: str, hubert_path: str | None = None, hubert_key: str = "layer_9", min_length: int = 1, subseq_length: int | None = None, ): self.file_paths = file_paths self.key = key self.min_length = min_length self.subseq_length = subseq_length self.hubert_path = hubert_path self.hubert_key = hubert_key def __len__(self): return len(self.file_paths) def __getitem__(self, idx: int) -> torch.Tensor | None: path = self.file_paths[idx] # open file, get a slice wrapper for full tensor with safe_open(path, framework="pt") as f: tensor_slice = f.get_slice(self.key) Q, N, D = tensor_slice.get_shape() # full shape [K, N] # drop too-short if N < self.min_length: return None L = self.subseq_length or N if L < N: # sample random start start = torch.randint(0, max(1, N - L - 1), ()).item() start -= start % 2 # this yields a torch.Tensor of shape [K, L] seq = tensor_slice[:, start : start + L] else: # full length start = 0 seq = tensor_slice[:, :] if self.hubert_path is not None: path = Path(self.hubert_path) / Path(path).name with safe_open(path, framework="pt") as f: tensor_slice = f.get_slice(self.hubert_key) hubert_N, hubert_D = tensor_slice.get_shape() # full shape [K, N] seq_hubert = tensor_slice[start // 2 : start // 2 + L // 2] return (seq, seq_hubert) return (seq, None) class SafeTensorDataModule(ptl.LightningDataModule): """ LightningDataModule using raw .safetensors file list + get_slice inside Dataset. """ def __init__( self, train_file_list: str, val_file_list: str | None = None, hubert_path: str | None = None, key: str = "audio_z", hubert_key: str = "layer_9", val_split: float = 0.1, batch_size: int = 32, num_workers: int = 4, shuffle: bool = True, seed: int = 1234, min_length: int = 1, subseq_length: int | None = None, ): super().__init__() self.train_file_list = train_file_list self.val_file_list = val_file_list self.hubert_path = hubert_path self.key = key self.val_split = val_split self.batch_size = batch_size self.num_workers = num_workers self.shuffle = shuffle self.seed = seed self.min_length = min_length self.subseq_length = subseq_length def setup(self, stage=None): with open(self.train_file_list, "r") as f: train_paths = [line.strip() for line in f if line.strip()] val_paths = None if self.val_file_list is not None: with open(self.train_file_list, "r") as f: val_paths = [line.strip() for line in f if line.strip()] # Split into train/val if ( isinstance(self.val_split, float) and 0 < self.val_split < 1 and val_paths is None ): train_paths, val_paths = train_test_split( train_paths, test_size=self.val_split, random_state=self.seed ) self.train_ds = SafeTensorDataset( train_paths, key=self.key, min_length=self.min_length, subseq_length=self.subseq_length, hubert_path=self.hubert_path, ) self.val_ds = SafeTensorDataset( val_paths, key=self.key, min_length=self.min_length, subseq_length=self.subseq_length, ) def _collate_fn( self, batch: list[torch.Tensor | None] ) -> tuple[torch.Tensor, torch.BoolTensor]: seqs = [s for s in batch if s is not None] return pad_tensor_list_raw(seqs, pad_idx=0) def train_dataloader(self): return DataLoader( self.train_ds, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, collate_fn=self._collate_fn, ) def val_dataloader(self): return DataLoader( self.val_ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, collate_fn=self._collate_fn, )