PicoAudio2 / utils /audio.py
rookie9's picture
Upload 77 files
f582ec6 verified
import torch
import torch.nn as nn
import torchaudio
class PadCrop(nn.Module):
    """Zero-pad or crop a signal along its last axis to exactly ``n_samples``.

    Args:
        n_samples: Length of the output along the last axis.
        randomize: When True and the input is longer than ``n_samples``, the
            crop window starts at a uniformly random offset drawn with
            ``torch.randint``; when False the crop always starts at 0.

    Inputs of any rank >= 1 are accepted and all leading dimensions are
    preserved. Inputs shorter than ``n_samples`` are right-padded with zeros.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def forward(self, signal):
        """Return ``signal`` padded/cropped to shape ``(..., n_samples)``.

        Defined as ``forward`` (not ``__call__``) so ``nn.Module.__call__``
        still runs registered hooks before dispatching here.
        """
        length = signal.shape[-1]
        if self.randomize:
            # Any offset in [0, length - n_samples] keeps the window in
            # bounds; for inputs no longer than n_samples this is always 0.
            start = torch.randint(
                0, max(0, length - self.n_samples) + 1, []
            ).item()
        else:
            start = 0
        end = start + self.n_samples
        output = signal.new_zeros([*signal.shape[:-1], self.n_samples])
        # The source slice is shorter than n_samples only when the input is;
        # the remaining tail of ``output`` stays zero (padding).
        output[..., :min(length, self.n_samples)] = signal[..., start:end]
        return output
def set_audio_channels(audio, target_channels):
    """Coerce the channel axis (dim 1) of ``audio`` to ``target_channels``.

    - ``target_channels == 1``: average over dim 1, keeping it as size 1.
    - ``target_channels == 2``: a single channel is duplicated; more than two
      channels are truncated to the first two; exactly two is left as-is.
    - Any other value returns ``audio`` unchanged.

    The indexing below assumes a 3-D ``(batch, channels, time)`` layout —
    NOTE(review): confirm with callers for the stereo path.
    """
    if target_channels == 1:
        return audio.mean(1, keepdim=True)
    if target_channels != 2:
        return audio
    num_channels = audio.shape[1]
    if num_channels == 1:
        # Duplicate the lone channel so it appears on both sides.
        return audio.repeat(1, 2, 1)
    if num_channels > 2:
        # Keep only the first two channels.
        return audio[:, :2, :]
    return audio
def prepare_audio(
    audio, in_sr, target_sr, target_length, target_channels, device
):
    """Move, resample, pad/crop and channel-adjust a waveform tensor.

    Args:
        audio: Waveform tensor whose last axis is time. 1-D ``(time,)`` and
            2-D ``(channels, time)`` inputs get a batch dimension added;
            higher-rank inputs keep their leading dims.
        in_sr: Sample rate of ``audio``.
        target_sr: Desired sample rate; resampling runs only if it differs
            from ``in_sr``.
        target_length: Output length in samples (at ``target_sr``). Longer
            signals are cropped from the start, shorter ones are zero-padded.
        target_channels: Passed to ``set_audio_channels``.
        device: Device the tensor (and resampler) is moved to.

    Returns:
        The prepared tensor, batched (3-D for 1-D/2-D inputs).
    """
    audio = audio.to(device)
    if in_sr != target_sr:
        resample_tf = torchaudio.transforms.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)
    # Promote a bare waveform to (channels, time) *before* cropping.
    # Previously PadCrop received the 1-D tensor and failed to unpack its shape.
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    # Fold any leading dims into one row axis so PadCrop sees a 2-D
    # (rows, time) tensor, then restore the original leading shape.
    leading_shape = audio.shape[:-1]
    audio = PadCrop(target_length, randomize=False)(audio.flatten(0, -2))
    audio = audio.reshape(*leading_shape, target_length)
    # Add batch dimension
    if audio.dim() == 2:
        audio = audio.unsqueeze(0)
    audio = set_audio_channels(audio, target_channels)
    return audio