from dataclasses import dataclass

import numpy as np
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

# Limit intra-op threading; with multiple DataLoader workers this avoids
# oversubscribing CPU cores.
torch.set_num_threads(1)

@dataclass
class WavVAEDataConfig:
    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int

class WavVAEDataModule(LightningDataModule):
    def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloader(self, cfg: WavVAEDataConfig, train: bool) -> DataLoader:
        dataset = WavVAEDataset(cfg, train=train)
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=train,
            pin_memory=True,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.val_config, train=False)

class WavVAEDataset(Dataset):
    def __init__(self, cfg: WavVAEDataConfig, train: bool):
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)
    def __getitem__(self, index: int) -> torch.Tensor:
        audio_path = self.filelist[index]
        y, sr = torchaudio.load(audio_path)
        if y.size(0) > 1:
            # mix to mono
            y = y.mean(dim=0, keepdim=True)
        # Peak-normalize: a random gain in [-6, -1] dB during training acts as
        # volume augmentation; validation uses a fixed -3 dB.
        gain = np.random.uniform(-6, -1) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(
            y, sr, [["norm", f"{gain:.2f}"]]
        )
        try:
            if sr != self.sampling_rate:
                y = torchaudio.functional.resample(
                    y, orig_freq=sr, new_freq=self.sampling_rate
                )
        except Exception:
            # Log the offending file and fall back to the un-resampled audio.
            print(audio_path, y.shape)
        if y.size(-1) < self.num_samples:
            # Too short: loop the clip until it covers num_samples.
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            # Random crop during training.
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, always take the first segment for determinism.
            y = y[:, : self.num_samples]
        return y[0]
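

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The filelist paths and the
# hyperparameter values below are hypothetical placeholders; substitute the
# real training configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    train_cfg = WavVAEDataConfig(
        filelist_path="filelists/train.txt",  # plain text, one audio path per line
        sampling_rate=24000,
        num_samples=24000,  # 1-second crops at 24 kHz
        batch_size=16,
        num_workers=4,
    )
    val_cfg = WavVAEDataConfig(
        filelist_path="filelists/val.txt",
        sampling_rate=24000,
        num_samples=24000,
        batch_size=16,
        num_workers=4,
    )
    datamodule = WavVAEDataModule(train_params=train_cfg, val_params=val_cfg)
    batch = next(iter(datamodule.train_dataloader()))
    print(batch.shape)  # expected: torch.Size([16, 24000])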