Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
class PadCrop(nn.Module):
    """Pad or crop a signal along its last dimension to a fixed length.

    Signals shorter than ``n_samples`` are zero-padded on the right; longer
    signals are cropped to a window of ``n_samples`` samples.

    Args:
        n_samples: Size of the last dimension of the output.
        randomize: If true, the crop window starts at an offset drawn with
            ``torch.randint``; otherwise it always starts at 0.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    # Implemented as ``forward`` (not ``__call__``) so that invoking the
    # module goes through ``nn.Module.__call__`` and its hook machinery.
    def forward(self, signal):
        """Return ``signal`` padded/cropped to ``n_samples`` on the last axis.

        Args:
            signal: Tensor with at least one dimension. All leading
                dimensions are preserved unchanged.

        Returns:
            A new tensor (same dtype and device as ``signal``, allocated via
            ``signal.new_zeros``) whose shape is
            ``(*signal.shape[:-1], n_samples)``.
        """
        n_available = signal.shape[-1]

        # Latest offset at which a full window still fits; 0 when the signal
        # is shorter than the window (then everything is kept and padded).
        max_start = max(0, n_available - self.n_samples)
        if self.randomize:
            start = torch.randint(0, max_start + 1, []).item()
        else:
            start = 0
        end = start + self.n_samples

        output = signal.new_zeros([*signal.shape[:-1], self.n_samples])
        # Slicing past the end of ``signal`` simply yields fewer samples, so
        # the destination slice is capped at the number actually available.
        n_copied = min(n_available, self.n_samples)
        output[..., :n_copied] = signal[..., start:end]
        return output
def set_audio_channels(audio, target_channels):
    """Convert ``audio`` to mono or stereo by operating on dimension 1.

    The tensor is indexed as 3-D with channels on dimension 1 (presumably
    ``(batch, channels, samples)`` -- confirm against callers).

    Args:
        audio: Input tensor.
        target_channels: ``1`` averages all channels into one; ``2``
            duplicates a single channel or keeps only the first two. Any
            other value returns ``audio`` unchanged.

    Returns:
        The converted tensor, or ``audio`` itself when no change is needed.
    """
    if target_channels == 1:
        # Downmix to mono by averaging across the channel dimension.
        return audio.mean(1, keepdim=True)

    if target_channels != 2:
        return audio

    n_channels = audio.shape[1]
    if n_channels == 1:
        # Mono source: copy the single channel into both stereo channels.
        return audio.repeat(1, 2, 1)
    if n_channels > 2:
        # Multichannel source: keep only the first two channels.
        return audio[:, :2, :]
    return audio
def prepare_audio(
    audio, in_sr, target_sr, target_length, target_channels, device
):
    """Resample, pad/crop and reshape audio into a batched tensor.

    Steps, in order: move to ``device``; resample from ``in_sr`` to
    ``target_sr`` when they differ; pad/crop the last dimension to
    ``target_length`` (deterministic, window starts at 0); add a leading
    batch dimension; convert the channel count with ``set_audio_channels``.

    Args:
        audio: Waveform tensor, either 1-D or 2-D (presumably ``(samples,)``
            or ``(channels, samples)`` -- confirm against callers).
        in_sr: Sample rate of ``audio``.
        target_sr: Sample rate to resample to.
        target_length: Number of samples in the output's last dimension.
        target_channels: Channel count passed to ``set_audio_channels``.
        device: Device the audio and the resampler are moved to.

    Returns:
        A 3-D tensor with a leading batch dimension of size 1.
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        resample_tf = torchaudio.transforms.Resample(
            in_sr, target_sr
        ).to(device)
        audio = resample_tf(audio)

    # Give a bare 1-D waveform an explicit channel axis *before* PadCrop so
    # PadCrop always receives a tensor with a channel dimension. Doing this
    # only after PadCrop would be too late for 1-D input.
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)

    audio = PadCrop(target_length, randomize=False)(audio)

    # Add batch dimension
    if audio.dim() == 2:
        audio = audio.unsqueeze(0)

    audio = set_audio_channels(audio, target_channels)
    return audio