# pardi-speech/codec/datamodules.py
import itertools
import random
import time
from dataclasses import dataclass
from functools import partial
from pathlib import Path
import numpy as np
import pytorch_lightning as ptl
import torch
import torchaudio
from safetensors.torch import safe_open
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, load_from_disk
@dataclass
class WavVAEDataConfig:
filelist_path: str
sampling_rate: int
num_samples: int
batch_size: int
num_workers: int
class WavVAEDataModule(ptl.LightningDataModule):
def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
super().__init__()
self.train_config = train_params
self.val_config = val_params
    def _get_dataloader(self, cfg: WavVAEDataConfig, train: bool):
dataset = WavVAEDataset(cfg, train=train)
dataloader = DataLoader(
dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
shuffle=train,
pin_memory=True,
)
return dataloader
def train_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.train_config, train=True)
def val_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.val_config, train=False)
class WavVAEDataset(Dataset):
def __init__(self, cfg: WavVAEDataConfig, train: bool):
with open(cfg.filelist_path) as f:
self.filelist = f.read().splitlines()
self.sampling_rate = cfg.sampling_rate
self.num_samples = cfg.num_samples
self.train = train
def __len__(self) -> int:
return len(self.filelist)
def __getitem__(self, index: int) -> torch.Tensor:
audio_path = self.filelist[index]
y, sr = torchaudio.load(audio_path)
if y.size(0) > 1:
# mix to mono
y = y.mean(dim=0, keepdim=True)
        # normalize the peak level to a random dB value during training, fixed -3 dB for validation
        gain = np.random.uniform(-6, -1) if self.train else -3
y, _ = torchaudio.sox_effects.apply_effects_tensor(
y, sr, [["norm", f"{gain:.2f}"]]
)
if sr != self.sampling_rate:
y = torchaudio.functional.resample(
y, orig_freq=sr, new_freq=self.sampling_rate
)
if y.size(-1) < self.num_samples:
pad_length = self.num_samples - y.size(-1)
padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
elif self.train:
start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
y = y[:, start : start + self.num_samples]
else:
            # During validation, always take the first segment for determinism
y = y[:, : self.num_samples]
return y[0]
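

# Minimal usage sketch for the waveform pipeline above. The file-list paths and
# hyperparameters here are illustrative assumptions, not values taken from the
# original training recipe.
def _example_wavvae_datamodule() -> WavVAEDataModule:
    train_cfg = WavVAEDataConfig(
        filelist_path="filelists/train.txt",  # assumed: one audio path per line
        sampling_rate=24000,                  # assumed target sampling rate
        num_samples=24000,                    # assumed 1 s crop at 24 kHz
        batch_size=16,
        num_workers=4,
    )
    val_cfg = WavVAEDataConfig(
        filelist_path="filelists/val.txt",    # assumed path
        sampling_rate=24000,
        num_samples=24000,
        batch_size=16,
        num_workers=4,
    )
    return WavVAEDataModule(train_params=train_cfg, val_params=val_cfg)
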
def pad_tensor_list_raw(
    tensor_list: list[tuple[torch.Tensor, torch.Tensor | None]], pad_idx: int = 0
) -> dict[str, torch.Tensor | None]:
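    # Each item is an (audio_z slice, HuBERT-slice-or-None) pair produced by
    # SafeTensorDataset. The concatenation/stacking below assumes every item in
    # the batch shares the same temporal length (which holds when subseq_length
    # is set); pad_idx is unused in the current implementation.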
audio, hubert_maybe = zip(*tensor_list)
audio = torch.cat(audio, dim=0)
if hubert_maybe[0] is not None:
hubert_maybe = torch.stack(hubert_maybe, dim=0)
else:
hubert_maybe = None
return {"audio_z": audio, "hubert": hubert_maybe}
class SafeTensorDataset(Dataset):
"""
On __getitem__, opens the safetensor, uses get_slice() to inspect shape,
    then either drops too-short files (returns None) or returns a random subsequence slice.
"""
def __init__(
self,
file_paths: list[str],
key: str,
hubert_path: str | None = None,
hubert_key: str = "layer_9",
min_length: int = 1,
subseq_length: int | None = None,
):
self.file_paths = file_paths
self.key = key
self.min_length = min_length
self.subseq_length = subseq_length
self.hubert_path = hubert_path
self.hubert_key = hubert_key
def __len__(self):
return len(self.file_paths)
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor | None] | None:
path = self.file_paths[idx]
# open file, get a slice wrapper for full tensor
with safe_open(path, framework="pt") as f:
tensor_slice = f.get_slice(self.key)
            Q, N, D = tensor_slice.get_shape()  # full shape [Q, N, D]
# drop too-short
if N < self.min_length:
return None
L = self.subseq_length or N
if L < N:
# sample random start
start = torch.randint(0, max(1, N - L - 1), ()).item()
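                # Round the start down to an even index; the HuBERT slice below is
                # taken at half resolution (start // 2, L // 2), so this presumably
                # keeps the two sequences aligned.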
start -= start % 2
                # this yields a torch.Tensor of shape [Q, L, D]
seq = tensor_slice[:, start : start + L]
else:
# full length
start = 0
seq = tensor_slice[:, :]
if self.hubert_path is not None:
path = Path(self.hubert_path) / Path(path).name
with safe_open(path, framework="pt") as f:
tensor_slice = f.get_slice(self.hubert_key)
                    hubert_N, hubert_D = tensor_slice.get_shape()  # full shape [N, D]
seq_hubert = tensor_slice[start // 2 : start // 2 + L // 2]
return (seq, seq_hubert)
return (seq, None)
class SafeTensorDataModule(ptl.LightningDataModule):
"""
LightningDataModule using raw .safetensors file list + get_slice inside Dataset.
"""
def __init__(
self,
train_file_list: str,
val_file_list: str | None = None,
hubert_path: str | None = None,
key: str = "audio_z",
hubert_key: str = "layer_9",
val_split: float = 0.1,
batch_size: int = 32,
num_workers: int = 4,
shuffle: bool = True,
seed: int = 1234,
min_length: int = 1,
subseq_length: int | None = None,
):
super().__init__()
self.train_file_list = train_file_list
self.val_file_list = val_file_list
self.hubert_path = hubert_path
self.key = key
self.val_split = val_split
self.batch_size = batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.seed = seed
self.min_length = min_length
self.subseq_length = subseq_length
def setup(self, stage=None):
with open(self.train_file_list, "r") as f:
train_paths = [line.strip() for line in f if line.strip()]
val_paths = None
if self.val_file_list is not None:
            with open(self.val_file_list, "r") as f:
val_paths = [line.strip() for line in f if line.strip()]
# Split into train/val
if (
isinstance(self.val_split, float)
and 0 < self.val_split < 1
and val_paths is None
):
train_paths, val_paths = train_test_split(
train_paths, test_size=self.val_split, random_state=self.seed
)
self.train_ds = SafeTensorDataset(
train_paths,
key=self.key,
min_length=self.min_length,
subseq_length=self.subseq_length,
hubert_path=self.hubert_path,
)
self.val_ds = SafeTensorDataset(
val_paths,
key=self.key,
min_length=self.min_length,
subseq_length=self.subseq_length,
)
    def _collate_fn(
        self, batch: list[tuple[torch.Tensor, torch.Tensor | None] | None]
    ) -> dict[str, torch.Tensor | None]:
seqs = [s for s in batch if s is not None]
return pad_tensor_list_raw(seqs, pad_idx=0)
def train_dataloader(self):
return DataLoader(
self.train_ds,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
collate_fn=self._collate_fn,
)
def val_dataloader(self):
return DataLoader(
self.val_ds,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
collate_fn=self._collate_fn,
)
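

# Minimal usage sketch for the latent pipeline above. The file list, HuBERT
# directory, and crop length below are illustrative assumptions.
def _example_safetensor_datamodule() -> SafeTensorDataModule:
    dm = SafeTensorDataModule(
        train_file_list="filelists/latents_train.txt",  # assumed: one .safetensors path per line
        val_file_list=None,              # fall back to a random val_split of the training list
        hubert_path="features/hubert",   # assumed directory of same-named .safetensors files
        key="audio_z",
        hubert_key="layer_9",
        batch_size=32,
        num_workers=4,
        subseq_length=256,               # assumed even crop length in audio_z frames
    )
    dm.setup()
    return dm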