diff --git a/audiotools/__init__.py b/audiotools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..573ffd06100ad72614df9363b12cda6672f1b70e
--- /dev/null
+++ b/audiotools/__init__.py
@@ -0,0 +1,10 @@
+__version__ = "0.7.3"
+from .core import AudioSignal
+from .core import STFTParams
+from .core import Meter
+from .core import util
+from . import metrics
+from . import data
+from . import ml
+from .data import datasets
+from .data import transforms
diff --git a/audiotools/core/__init__.py b/audiotools/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8660c4e67f43d0ded584a38939425e2c28d95cd3
--- /dev/null
+++ b/audiotools/core/__init__.py
@@ -0,0 +1,4 @@
+from . import util
+from .audio_signal import AudioSignal
+from .audio_signal import STFTParams
+from .loudness import Meter
diff --git a/audiotools/core/audio_signal.py b/audiotools/core/audio_signal.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb6d751cb968a003656e3e7874c487b83d94c82e
--- /dev/null
+++ b/audiotools/core/audio_signal.py
@@ -0,0 +1,1682 @@
+import copy
+import functools
+import hashlib
+import math
+import pathlib
+import tempfile
+import typing
+import warnings
+from collections import namedtuple
+from pathlib import Path
+
+import julius
+import numpy as np
+import soundfile
+import torch
+
+from . import util
+from .display import DisplayMixin
+from .dsp import DSPMixin
+from .effects import EffectMixin
+from .effects import ImpulseResponseMixin
+from .ffmpeg import FFMPEGMixin
+from .loudness import LoudnessMixin
+from .playback import PlayMixin
+from .whisper import WhisperMixin
+
+
+STFTParams = namedtuple(
+ "STFTParams",
+ ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
+)
+"""
+STFTParams object is a container that holds STFT parameters - window_length,
+hop_length, window_type, match_stride, and padding_type. Not all parameters
+need to be specified. Ones that are not specified are inferred from the
+AudioSignal parameters.
+
+Parameters
+----------
+window_length : int, optional
+ Window length of STFT, by default ``0.032 * self.sample_rate``.
+hop_length : int, optional
+ Hop length of STFT, by default ``window_length // 4``.
+window_type : str, optional
+ Type of window to use, by default ``sqrt\_hann``.
+match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+padding_type : str, optional
+ Type of padding to use, by default 'reflect'
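+
+Examples
+--------
+An illustrative sketch: only ``window_length`` and ``hop_length`` are given
+here; fields left as None fall back to the AudioSignal defaults.
+
+>>> stft_params = STFTParams(window_length=2048, hop_length=512)
+>>> signal = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)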
+"""
+STFTParams.__new__.__defaults__ = (None, None, None, None, None)
+
+
+class AudioSignal(
+ EffectMixin,
+ LoudnessMixin,
+ PlayMixin,
+ ImpulseResponseMixin,
+ DSPMixin,
+ DisplayMixin,
+ FFMPEGMixin,
+ WhisperMixin,
+):
+ """This is the core object of this library. Audio is always
+ loaded into an AudioSignal, which then enables all the features
+ of this library, including audio augmentations, I/O, playback,
+ and more.
+
+ The structure of this object is that the base functionality
+ is defined in ``core/audio_signal.py``, while extensions to
+ that functionality are defined in the other ``core/*.py``
+ files. For example, all the display-based functionality
+ (e.g. plot spectrograms, waveforms, write to tensorboard)
+ are in ``core/display.py``.
+
+ Parameters
+ ----------
+ audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
+ Object to create AudioSignal from. Can be a tensor, numpy array,
+        or a path to a file. The loaded audio is always reshaped to
+        ``(batch_size, num_channels, num_samples)``.
+ sample_rate : int, optional
+ Sample rate of the audio. If different from underlying file, resampling is
+ performed. If passing in an array or tensor, this must be defined,
+ by default None
+ stft_params : STFTParams, optional
+        Parameters of STFT to use, by default None
+ offset : float, optional
+ Offset in seconds to read from file, by default 0
+ duration : float, optional
+ Duration in seconds to read from file, by default None
+ device : str, optional
+ Device to load audio onto, by default None
+
+ Examples
+ --------
+ Loading an AudioSignal from an array, at a sample rate of
+ 44100.
+
+ >>> signal = AudioSignal(torch.randn(5*44100), 44100)
+
+ Note, the signal is reshaped to have a batch size, and one
+ audio channel:
+
+ >>> print(signal.shape)
+    (1, 1, 220500)
+
+ You can treat AudioSignals like tensors, and many of the same
+ functions you might use on tensors are defined for AudioSignals
+ as well:
+
+ >>> signal.to("cuda")
+ >>> signal.cuda()
+ >>> signal.clone()
+ >>> signal.detach()
+
+ Indexing AudioSignals returns an AudioSignal:
+
+ >>> signal[..., 3*44100:4*44100]
+
+ The above signal is 1 second long, and is also an AudioSignal.
+ """
+
+ def __init__(
+ self,
+ audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
+ sample_rate: int = None,
+ stft_params: STFTParams = None,
+ offset: float = 0,
+ duration: float = None,
+ device: str = None,
+ ):
+ audio_path = None
+ audio_array = None
+
+ if isinstance(audio_path_or_array, str):
+ audio_path = audio_path_or_array
+ elif isinstance(audio_path_or_array, pathlib.Path):
+ audio_path = audio_path_or_array
+ elif isinstance(audio_path_or_array, np.ndarray):
+ audio_array = audio_path_or_array
+ elif torch.is_tensor(audio_path_or_array):
+ audio_array = audio_path_or_array
+ else:
+ raise ValueError(
+ "audio_path_or_array must be either a Path, "
+ "string, numpy array, or torch Tensor!"
+ )
+
+ self.path_to_file = None
+
+ self.audio_data = None
+ self.sources = None # List of AudioSignal objects.
+ self.stft_data = None
+ if audio_path is not None:
+ self.load_from_file(
+ audio_path, offset=offset, duration=duration, device=device
+ )
+ elif audio_array is not None:
+ assert sample_rate is not None, "Must set sample rate!"
+ self.load_from_array(audio_array, sample_rate, device=device)
+
+ self.window = None
+ self.stft_params = stft_params
+
+ self.metadata = {
+ "offset": offset,
+ "duration": duration,
+ }
+
+ @property
+ def path_to_input_file(
+ self,
+ ):
+ """
+ Path to input file, if it exists.
+ Alias to ``path_to_file`` for backwards compatibility
+ """
+ return self.path_to_file
+
+ @classmethod
+ def excerpt(
+ cls,
+ audio_path: typing.Union[str, Path],
+ offset: float = None,
+ duration: float = None,
+ state: typing.Union[np.random.RandomState, int] = None,
+ **kwargs,
+ ):
+ """Randomly draw an excerpt of ``duration`` seconds from an
+ audio file specified at ``audio_path``, between ``offset`` seconds
+ and end of file. ``state`` can be used to seed the random draw.
+
+ Parameters
+ ----------
+ audio_path : typing.Union[str, Path]
+ Path to audio file to grab excerpt from.
+ offset : float, optional
+            Lower bound for the start time of the excerpt, in seconds,
+            by default None.
+ duration : float, optional
+ Duration of excerpt, in seconds, by default None
+ state : typing.Union[np.random.RandomState, int], optional
+ RandomState or seed of random state, by default None
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal containing excerpt.
+
+ Examples
+ --------
+ >>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
+ """
+ info = util.info(audio_path)
+ total_duration = info.duration
+
+ state = util.random_state(state)
+ lower_bound = 0 if offset is None else offset
+ upper_bound = max(total_duration - duration, 0)
+ offset = state.uniform(lower_bound, upper_bound)
+
+ signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
+ signal.metadata["offset"] = offset
+ signal.metadata["duration"] = duration
+
+ return signal
+
+ @classmethod
+ def salient_excerpt(
+ cls,
+ audio_path: typing.Union[str, Path],
+ loudness_cutoff: float = None,
+ num_tries: int = 8,
+ state: typing.Union[np.random.RandomState, int] = None,
+ **kwargs,
+ ):
+ """Similar to AudioSignal.excerpt, except it extracts excerpts only
+ if they are above a specified loudness threshold, which is computed via
+ a fast LUFS routine.
+
+ Parameters
+ ----------
+ audio_path : typing.Union[str, Path]
+ Path to audio file to grab excerpt from.
+ loudness_cutoff : float, optional
+ Loudness threshold in dB. Typical values are ``-40, -60``,
+            etc., by default None
+ num_tries : int, optional
+ Number of tries to grab an excerpt above the threshold
+ before giving up, by default 8.
+ state : typing.Union[np.random.RandomState, int], optional
+ RandomState or seed of random state, by default None
+ kwargs : dict
+ Keyword arguments to AudioSignal.excerpt
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal containing excerpt.
+
+
+ .. warning::
+        If ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can
+ result in an infinite loop if ``audio_path`` does not have
+ any loud enough excerpts.
+
+ Examples
+ --------
+        >>> signal = AudioSignal.salient_excerpt(
+        ...     "path/to/audio",
+        ...     loudness_cutoff=-40,
+        ...     duration=5,
+        ... )
+ """
+ state = util.random_state(state)
+ if loudness_cutoff is None:
+ excerpt = cls.excerpt(audio_path, state=state, **kwargs)
+ else:
+ loudness = -np.inf
+ num_try = 0
+ while loudness <= loudness_cutoff:
+ excerpt = cls.excerpt(audio_path, state=state, **kwargs)
+ loudness = excerpt.loudness()
+ num_try += 1
+ if num_tries is not None and num_try >= num_tries:
+ break
+ return excerpt
+
+ @classmethod
+ def zeros(
+ cls,
+ duration: float,
+ sample_rate: int,
+ num_channels: int = 1,
+ batch_size: int = 1,
+ **kwargs,
+ ):
+ """Helper function create an AudioSignal of all zeros.
+
+ Parameters
+ ----------
+ duration : float
+ Duration of AudioSignal
+ sample_rate : int
+ Sample rate of AudioSignal
+ num_channels : int, optional
+ Number of channels, by default 1
+ batch_size : int, optional
+ Batch size, by default 1
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal containing all zeros.
+
+ Examples
+ --------
+ Generate 5 seconds of all zeros at a sample rate of 44100.
+
+ >>> signal = AudioSignal.zeros(5.0, 44100)
+ """
+ n_samples = int(duration * sample_rate)
+ return cls(
+ torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs
+ )
+
+ @classmethod
+ def wave(
+ cls,
+ frequency: float,
+ duration: float,
+ sample_rate: int,
+ num_channels: int = 1,
+ shape: str = "sine",
+ **kwargs,
+ ):
+ """
+ Generate a waveform of a given frequency and shape.
+
+ Parameters
+ ----------
+ frequency : float
+ Frequency of the waveform
+ duration : float
+ Duration of the waveform
+ sample_rate : int
+ Sample rate of the waveform
+ num_channels : int, optional
+ Number of channels, by default 1
+ shape : str, optional
+ Shape of the waveform, by default "saw"
+ One of "sawtooth", "square", "sine", "triangle"
+ kwargs : dict
+ Keyword arguments to AudioSignal
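+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal containing the generated waveform.
+
+        Examples
+        --------
+        A minimal sketch generating one second of a 440 Hz sine tone:
+
+        >>> signal = AudioSignal.wave(440, 1.0, 44100, shape="sine")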
+ """
+ n_samples = int(duration * sample_rate)
+ t = torch.linspace(0, duration, n_samples)
+ if shape == "sawtooth":
+ from scipy.signal import sawtooth
+
+ wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
+ elif shape == "square":
+ from scipy.signal import square
+
+ wave_data = square(2 * np.pi * frequency * t)
+ elif shape == "sine":
+ wave_data = np.sin(2 * np.pi * frequency * t)
+ elif shape == "triangle":
+ from scipy.signal import sawtooth
+
+ # frequency is doubled by the abs call, so omit the 2 in 2pi
+ wave_data = sawtooth(np.pi * frequency * t, 0.5)
+ wave_data = -np.abs(wave_data) * 2 + 1
+ else:
+ raise ValueError(f"Invalid shape {shape}")
+
+ wave_data = torch.tensor(wave_data, dtype=torch.float32)
+ wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
+ return cls(wave_data, sample_rate, **kwargs)
+
+ @classmethod
+ def batch(
+ cls,
+ audio_signals: list,
+ pad_signals: bool = False,
+ truncate_signals: bool = False,
+ resample: bool = False,
+ dim: int = 0,
+ ):
+ """Creates a batched AudioSignal from a list of AudioSignals.
+
+ Parameters
+ ----------
+ audio_signals : list[AudioSignal]
+ List of AudioSignal objects
+ pad_signals : bool, optional
+ Whether to pad signals to length of the maximum length
+ AudioSignal in the list, by default False
+ truncate_signals : bool, optional
+ Whether to truncate signals to length of shortest length
+ AudioSignal in the list, by default False
+ resample : bool, optional
+ Whether to resample AudioSignal to the sample rate of
+ the first AudioSignal in the list, by default False
+ dim : int, optional
+ Dimension along which to batch the signals.
+
+ Returns
+ -------
+ AudioSignal
+ Batched AudioSignal.
+
+ Raises
+ ------
+ RuntimeError
+ If not all AudioSignals are the same sample rate, and
+ ``resample=False``, an error is raised.
+ RuntimeError
+            If not all AudioSignals are the same length, and
+ both ``pad_signals=False`` and ``truncate_signals=False``,
+ an error is raised.
+
+ Examples
+ --------
+ Batching a bunch of random signals:
+
+ >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
+ >>> signal = AudioSignal.batch(signal_list)
+ >>> print(signal.shape)
+ (10, 1, 44100)
+
+ """
+ signal_lengths = [x.signal_length for x in audio_signals]
+ sample_rates = [x.sample_rate for x in audio_signals]
+
+ if len(set(sample_rates)) != 1:
+ if resample:
+ for x in audio_signals:
+ x.resample(sample_rates[0])
+ else:
+ raise RuntimeError(
+ f"Not all signals had the same sample rate! Got {sample_rates}. "
+ f"All signals must have the same sample rate, or resample must be True. "
+ )
+
+ if len(set(signal_lengths)) != 1:
+ if pad_signals:
+ max_length = max(signal_lengths)
+ for x in audio_signals:
+ pad_len = max_length - x.signal_length
+ x.zero_pad(0, pad_len)
+ elif truncate_signals:
+ min_length = min(signal_lengths)
+ for x in audio_signals:
+ x.truncate_samples(min_length)
+ else:
+ raise RuntimeError(
+ f"Not all signals had the same length! Got {signal_lengths}. "
+ f"All signals must be the same length, or pad_signals/truncate_signals "
+ f"must be True. "
+ )
+ # Concatenate along the specified dimension (default 0)
+ audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
+ audio_paths = [x.path_to_file for x in audio_signals]
+
+ batched_signal = cls(
+ audio_data,
+ sample_rate=audio_signals[0].sample_rate,
+ )
+ batched_signal.path_to_file = audio_paths
+ return batched_signal
+
+ # I/O
+ def load_from_file(
+ self,
+ audio_path: typing.Union[str, Path],
+ offset: float,
+ duration: float,
+ device: str = "cpu",
+ ):
+ """Loads data from file. Used internally when AudioSignal
+ is instantiated with a path to a file.
+
+ Parameters
+ ----------
+ audio_path : typing.Union[str, Path]
+ Path to file
+ offset : float
+ Offset in seconds
+ duration : float
+ Duration in seconds
+ device : str, optional
+ Device to put AudioSignal on, by default "cpu"
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal loaded from file
+ """
+ import librosa
+
+ data, sample_rate = librosa.load(
+ audio_path,
+ offset=offset,
+ duration=duration,
+ sr=None,
+ mono=False,
+ )
+ data = util.ensure_tensor(data)
+ if data.shape[-1] == 0:
+ raise RuntimeError(
+ f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
+ )
+
+ if data.ndim < 2:
+ data = data.unsqueeze(0)
+ if data.ndim < 3:
+ data = data.unsqueeze(0)
+ self.audio_data = data
+
+ self.original_signal_length = self.signal_length
+
+ self.sample_rate = sample_rate
+ self.path_to_file = audio_path
+ return self.to(device)
+
+ def load_from_array(
+ self,
+ audio_array: typing.Union[torch.Tensor, np.ndarray],
+ sample_rate: int,
+ device: str = "cpu",
+ ):
+ """Loads data from array, reshaping it to be exactly 3
+ dimensions. Used internally when AudioSignal is called
+ with a tensor or an array.
+
+ Parameters
+ ----------
+ audio_array : typing.Union[torch.Tensor, np.ndarray]
+ Array/tensor of audio of samples.
+ sample_rate : int
+ Sample rate of audio
+ device : str, optional
+ Device to move audio onto, by default "cpu"
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal loaded from array
+ """
+ audio_data = util.ensure_tensor(audio_array)
+
+ if audio_data.dtype == torch.double:
+ audio_data = audio_data.float()
+
+ if audio_data.ndim < 2:
+ audio_data = audio_data.unsqueeze(0)
+ if audio_data.ndim < 3:
+ audio_data = audio_data.unsqueeze(0)
+ self.audio_data = audio_data
+
+ self.original_signal_length = self.signal_length
+
+ self.sample_rate = sample_rate
+ return self.to(device)
+
+ def write(self, audio_path: typing.Union[str, Path]):
+ """Writes audio to a file. Only writes the audio
+ that is in the very first item of the batch. To write other items
+ in the batch, index the signal along the batch dimension
+ before writing. After writing, the signal's ``path_to_file``
+ attribute is updated to the new path.
+
+ Parameters
+ ----------
+ audio_path : typing.Union[str, Path]
+ Path to write audio to.
+
+ Returns
+ -------
+ AudioSignal
+ Returns original AudioSignal, so you can use this in a fluent
+ interface.
+
+ Examples
+ --------
+ Creating and writing a signal to disk:
+
+ >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
+ >>> signal.write("/tmp/out.wav")
+
+ Writing a different element of the batch:
+
+ >>> signal[5].write("/tmp/out.wav")
+
+ Using this in a fluent interface:
+
+ >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")
+
+ """
+ if self.audio_data[0].abs().max() > 1:
+ warnings.warn("Audio amplitude > 1 clipped when saving")
+ soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)
+
+ self.path_to_file = audio_path
+ return self
+
+ def deepcopy(self):
+ """Copies the signal and all of its attributes.
+
+ Returns
+ -------
+ AudioSignal
+ Deep copy of the audio signal.
+ """
+ return copy.deepcopy(self)
+
+ def copy(self):
+ """Shallow copy of signal.
+
+ Returns
+ -------
+ AudioSignal
+ Shallow copy of the audio signal.
+ """
+ return copy.copy(self)
+
+ def clone(self):
+ """Clones all tensors contained in the AudioSignal,
+ and returns a copy of the signal with everything
+ cloned. Useful when using AudioSignal within autograd
+ computation graphs.
+
+ Relevant attributes are the stft data, the audio data,
+ and the loudness of the file.
+
+ Returns
+ -------
+ AudioSignal
+ Clone of AudioSignal.
+ """
+ clone = type(self)(
+ self.audio_data.clone(),
+ self.sample_rate,
+ stft_params=self.stft_params,
+ )
+ if self.stft_data is not None:
+ clone.stft_data = self.stft_data.clone()
+ if self._loudness is not None:
+ clone._loudness = self._loudness.clone()
+ clone.path_to_file = copy.deepcopy(self.path_to_file)
+ clone.metadata = copy.deepcopy(self.metadata)
+ return clone
+
+ def detach(self):
+ """Detaches tensors contained in AudioSignal.
+
+ Relevant attributes are the stft data, the audio data,
+ and the loudness of the file.
+
+ Returns
+ -------
+ AudioSignal
+ Same signal, but with all tensors detached.
+ """
+ if self._loudness is not None:
+ self._loudness = self._loudness.detach()
+ if self.stft_data is not None:
+ self.stft_data = self.stft_data.detach()
+
+ self.audio_data = self.audio_data.detach()
+ return self
+
+ def hash(self):
+ """Writes the audio data to a temporary file, and then
+ hashes it using hashlib. Useful for creating a file
+ name based on the audio content.
+
+ Returns
+ -------
+ str
+ Hash of audio data.
+
+ Examples
+ --------
+ Creating a signal, and writing it to a unique file name:
+
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
+ >>> hash = signal.hash()
+ >>> signal.write(f"{hash}.wav")
+
+ """
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+ self.write(f.name)
+ h = hashlib.sha256()
+ b = bytearray(128 * 1024)
+ mv = memoryview(b)
+ with open(f.name, "rb", buffering=0) as f:
+ for n in iter(lambda: f.readinto(mv), 0):
+ h.update(mv[:n])
+ file_hash = h.hexdigest()
+ return file_hash
+
+ # Signal operations
+ def to_mono(self):
+ """Converts audio data to mono audio, by taking the mean
+ along the channels dimension.
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with mean of channels.
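+
+        Examples
+        --------
+        A minimal sketch, assuming a random 1-second stereo signal:
+
+        >>> signal = AudioSignal(torch.randn(1, 2, 44100), 44100)
+        >>> signal = signal.to_mono()
+        >>> signal.num_channels
+        1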
+ """
+ self.audio_data = self.audio_data.mean(1, keepdim=True)
+ return self
+
+ def resample(self, sample_rate: int):
+ """Resamples the audio, using sinc interpolation. This works on both
+ cpu and gpu, and is much faster on gpu.
+
+ Parameters
+ ----------
+ sample_rate : int
+ Sample rate to resample to.
+
+ Returns
+ -------
+ AudioSignal
+ Resampled AudioSignal
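+
+        Examples
+        --------
+        A minimal sketch resampling a random signal from 44.1 kHz to 16 kHz:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal = signal.resample(16000)
+        >>> signal.sample_rate
+        16000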
+ """
+ if sample_rate == self.sample_rate:
+ return self
+ self.audio_data = julius.resample_frac(
+ self.audio_data, self.sample_rate, sample_rate
+ )
+ self.sample_rate = sample_rate
+ return self
+
+ # Tensor operations
+ def to(self, device: str):
+ """Moves all tensors contained in signal to the specified device.
+
+ Parameters
+ ----------
+ device : str
+ Device to move AudioSignal onto. Typical values are
+ "cuda", "cpu", or "cuda:n" to specify the nth gpu.
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with all tensors moved to specified device.
+ """
+ if self._loudness is not None:
+ self._loudness = self._loudness.to(device)
+ if self.stft_data is not None:
+ self.stft_data = self.stft_data.to(device)
+ if self.audio_data is not None:
+ self.audio_data = self.audio_data.to(device)
+ return self
+
+ def float(self):
+ """Calls ``.float()`` on ``self.audio_data``.
+
+ Returns
+ -------
+ AudioSignal
+ """
+ self.audio_data = self.audio_data.float()
+ return self
+
+ def cpu(self):
+ """Moves AudioSignal to cpu.
+
+ Returns
+ -------
+ AudioSignal
+ """
+ return self.to("cpu")
+
+ def cuda(self): # pragma: no cover
+ """Moves AudioSignal to cuda.
+
+ Returns
+ -------
+ AudioSignal
+ """
+ return self.to("cuda")
+
+ def numpy(self):
+ """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.
+
+ Returns
+ -------
+ np.ndarray
+ Audio data as a numpy array.
+ """
+ return self.audio_data.detach().cpu().numpy()
+
+ def zero_pad(self, before: int, after: int):
+ """Zero pads the audio_data tensor before and after.
+
+ Parameters
+ ----------
+ before : int
+ How many zeros to prepend to audio.
+ after : int
+ How many zeros to append to audio.
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with padding applied.
+ """
+ self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after))
+ return self
+
+ def zero_pad_to(self, length: int, mode: str = "after"):
+ """Pad with zeros to a specified length, either before or after
+ the audio data.
+
+ Parameters
+ ----------
+ length : int
+ Length to pad to
+ mode : str, optional
+ Whether to prepend or append zeros to signal, by default "after"
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with padding applied.
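+
+        Examples
+        --------
+        A minimal sketch padding a half-second signal out to one second:
+
+        >>> signal = AudioSignal(torch.randn(22050), 44100)
+        >>> signal = signal.zero_pad_to(44100)
+        >>> signal.signal_length
+        44100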
+ """
+ if mode == "before":
+ self.zero_pad(max(length - self.signal_length, 0), 0)
+ elif mode == "after":
+ self.zero_pad(0, max(length - self.signal_length, 0))
+ return self
+
+ def trim(self, before: int, after: int):
+ """Trims the audio_data tensor before and after.
+
+ Parameters
+ ----------
+ before : int
+ How many samples to trim from beginning.
+ after : int
+ How many samples to trim from end.
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with trimming applied.
+ """
+ if after == 0:
+ self.audio_data = self.audio_data[..., before:]
+ else:
+ self.audio_data = self.audio_data[..., before:-after]
+ return self
+
+ def truncate_samples(self, length_in_samples: int):
+ """Truncate signal to specified length.
+
+ Parameters
+ ----------
+ length_in_samples : int
+ Truncate to this many samples.
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with truncation applied.
+ """
+ self.audio_data = self.audio_data[..., :length_in_samples]
+ return self
+
+ @property
+ def device(self):
+ """Get device that AudioSignal is on.
+
+ Returns
+ -------
+ torch.device
+ Device that AudioSignal is on.
+ """
+ if self.audio_data is not None:
+ device = self.audio_data.device
+ elif self.stft_data is not None:
+ device = self.stft_data.device
+ return device
+
+ # Properties
+ @property
+ def audio_data(self):
+ """Returns the audio data tensor in the object.
+
+        Audio data is always of the shape
+        (batch_size, num_channels, num_samples). Audio passed to the
+        constructor with fewer than 3 dims (e.g. (num_channels, num_samples))
+        is reshaped to (1, num_channels, num_samples) - a batch size of 1.
+        When setting this property directly, the data must already be 3-dim.
+
+ Parameters
+ ----------
+ data : typing.Union[torch.Tensor, np.ndarray]
+ Audio data to set.
+
+ Returns
+ -------
+ torch.Tensor
+ Audio samples.
+ """
+ return self._audio_data
+
+ @audio_data.setter
+ def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
+ if data is not None:
+ assert torch.is_tensor(data), "audio_data should be torch.Tensor"
+ assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
+ self._audio_data = data
+ # Old loudness value not guaranteed to be right, reset it.
+ self._loudness = None
+ return
+
+ # alias for audio_data
+ samples = audio_data
+
+ @property
+ def stft_data(self):
+ """Returns the STFT data inside the signal. Shape is
+ (batch, channels, frequencies, time).
+
+ Returns
+ -------
+ torch.Tensor
+ Complex spectrogram data.
+ """
+ return self._stft_data
+
+ @stft_data.setter
+ def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
+ if data is not None:
+ assert torch.is_tensor(data) and torch.is_complex(data)
+ if self.stft_data is not None and self.stft_data.shape != data.shape:
+ warnings.warn("stft_data changed shape")
+ self._stft_data = data
+ return
+
+ @property
+ def batch_size(self):
+ """Batch size of audio signal.
+
+ Returns
+ -------
+ int
+ Batch size of signal.
+ """
+ return self.audio_data.shape[0]
+
+ @property
+ def signal_length(self):
+ """Length of audio signal.
+
+ Returns
+ -------
+ int
+ Length of signal in samples.
+ """
+ return self.audio_data.shape[-1]
+
+ # alias for signal_length
+ length = signal_length
+
+ @property
+ def shape(self):
+ """Shape of audio data.
+
+ Returns
+ -------
+ tuple
+ Shape of audio data.
+ """
+ return self.audio_data.shape
+
+ @property
+ def signal_duration(self):
+ """Length of audio signal in seconds.
+
+ Returns
+ -------
+ float
+ Length of signal in seconds.
+ """
+ return self.signal_length / self.sample_rate
+
+ # alias for signal_duration
+ duration = signal_duration
+
+ @property
+ def num_channels(self):
+ """Number of audio channels.
+
+ Returns
+ -------
+ int
+ Number of audio channels.
+ """
+ return self.audio_data.shape[1]
+
+ # STFT
+ @staticmethod
+ @functools.lru_cache(None)
+ def get_window(window_type: str, window_length: int, device: str):
+ """Wrapper around scipy.signal.get_window so one can also get the
+ popular sqrt-hann window. This function caches for efficiency
+ using functools.lru\_cache.
+
+ Parameters
+ ----------
+ window_type : str
+ Type of window to get
+ window_length : int
+ Length of the window
+ device : str
+ Device to put window onto.
+
+ Returns
+ -------
+ torch.Tensor
+ Window returned by scipy.signal.get_window, as a tensor.
+ """
+ from scipy import signal
+
+ if window_type == "average":
+ window = np.ones(window_length) / window_length
+ elif window_type == "sqrt_hann":
+ window = np.sqrt(signal.get_window("hann", window_length))
+ else:
+ window = signal.get_window(window_type, window_length)
+ window = torch.from_numpy(window).to(device).float()
+ return window
+
+ @property
+ def stft_params(self):
+ """Returns STFTParams object, which can be re-used to other
+ AudioSignals.
+
+ This property can be set as well. If values are not defined in STFTParams,
+        they are inferred automatically from the signal properties. The default is a
+        window whose length is the smallest power of two spanning at least 32 ms,
+        a hop of a quarter of the window length, and the square root of the hann window.
+
+ Returns
+ -------
+ STFTParams
+ STFT parameters for the AudioSignal.
+
+ Examples
+ --------
+ >>> stft_params = STFTParams(128, 32)
+ >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
+ >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
+ >>> signal1.stft_params = STFTParams() # Defaults
+ """
+ return self._stft_params
+
+ @stft_params.setter
+ def stft_params(self, value: STFTParams):
+ default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
+ default_hop_len = default_win_len // 4
+ default_win_type = "hann"
+ default_match_stride = False
+ default_padding_type = "reflect"
+
+ default_stft_params = STFTParams(
+ window_length=default_win_len,
+ hop_length=default_hop_len,
+ window_type=default_win_type,
+ match_stride=default_match_stride,
+ padding_type=default_padding_type,
+ )._asdict()
+
+ value = value._asdict() if value else default_stft_params
+
+ for key in default_stft_params:
+ if value[key] is None:
+ value[key] = default_stft_params[key]
+
+ self._stft_params = STFTParams(**value)
+ self.stft_data = None
+
+ def compute_stft_padding(
+ self, window_length: int, hop_length: int, match_stride: bool
+ ):
+ """Compute how the STFT should be padded, based on match\_stride.
+
+ Parameters
+ ----------
+ window_length : int
+ Window length of STFT.
+ hop_length : int
+ Hop length of STFT.
+ match_stride : bool
+ Whether or not to match stride, making the STFT have the same alignment as
+ convolutional layers.
+
+ Returns
+ -------
+ tuple
+ Amount to pad on either side of audio.
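+
+        Examples
+        --------
+        A minimal sketch, assuming a 1-second signal with a 2048-sample
+        window and a 512-sample hop:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal.compute_stft_padding(2048, 512, match_stride=True)
+        (444, 768)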
+ """
+ length = self.signal_length
+
+ if match_stride:
+ assert (
+ hop_length == window_length // 4
+ ), "For match_stride, hop must equal n_fft // 4"
+ right_pad = math.ceil(length / hop_length) * hop_length - length
+ pad = (window_length - hop_length) // 2
+ else:
+ right_pad = 0
+ pad = 0
+
+ return right_pad, pad
+
+ def stft(
+ self,
+ window_length: int = None,
+ hop_length: int = None,
+ window_type: str = None,
+ match_stride: bool = None,
+ padding_type: str = None,
+ ):
+ """Computes the short-time Fourier transform of the audio data,
+ with specified STFT parameters.
+
+ Parameters
+ ----------
+ window_length : int, optional
+ Window length of STFT, by default ``0.032 * self.sample_rate``.
+ hop_length : int, optional
+ Hop length of STFT, by default ``window_length // 4``.
+ window_type : str, optional
+ Type of window to use, by default ``sqrt\_hann``.
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+ padding_type : str, optional
+ Type of padding to use, by default 'reflect'
+
+ Returns
+ -------
+ torch.Tensor
+ STFT of audio data.
+
+ Examples
+ --------
+ Compute the STFT of an AudioSignal:
+
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
+ >>> signal.stft()
+
+ Vary the window and hop length:
+
+        >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
+        >>> for stft_param in stft_params:
+        ...     signal.stft_params = stft_param
+        ...     signal.stft()
+
+ """
+ window_length = (
+ self.stft_params.window_length
+ if window_length is None
+ else int(window_length)
+ )
+ hop_length = (
+ self.stft_params.hop_length if hop_length is None else int(hop_length)
+ )
+ window_type = (
+ self.stft_params.window_type if window_type is None else window_type
+ )
+ match_stride = (
+ self.stft_params.match_stride if match_stride is None else match_stride
+ )
+ padding_type = (
+ self.stft_params.padding_type if padding_type is None else padding_type
+ )
+
+ window = self.get_window(window_type, window_length, self.audio_data.device)
+ window = window.to(self.audio_data.device)
+
+ audio_data = self.audio_data
+ right_pad, pad = self.compute_stft_padding(
+ window_length, hop_length, match_stride
+ )
+ audio_data = torch.nn.functional.pad(
+ audio_data, (pad, pad + right_pad), padding_type
+ )
+ stft_data = torch.stft(
+ audio_data.reshape(-1, audio_data.shape[-1]),
+ n_fft=window_length,
+ hop_length=hop_length,
+ window=window,
+ return_complex=True,
+ center=True,
+ )
+ _, nf, nt = stft_data.shape
+ stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)
+
+ if match_stride:
+ # Drop first two and last two frames, which are added
+ # because of padding. Now num_frames * hop_length = num_samples.
+ stft_data = stft_data[..., 2:-2]
+ self.stft_data = stft_data
+
+ return stft_data
+
+ def istft(
+ self,
+ window_length: int = None,
+ hop_length: int = None,
+ window_type: str = None,
+ match_stride: bool = None,
+ length: int = None,
+ ):
+ """Computes inverse STFT and sets it to audio\_data.
+
+ Parameters
+ ----------
+ window_length : int, optional
+ Window length of STFT, by default ``0.032 * self.sample_rate``.
+ hop_length : int, optional
+ Hop length of STFT, by default ``window_length // 4``.
+ window_type : str, optional
+ Type of window to use, by default ``sqrt\_hann``.
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+ length : int, optional
+ Original length of signal, by default None
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with istft applied.
+
+ Raises
+ ------
+ RuntimeError
+ Raises an error if stft was not called prior to istft on the signal,
+ or if stft_data is not set.
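+
+        Examples
+        --------
+        A minimal round-trip sketch: compute the STFT, then invert it:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> _ = signal.stft()
+        >>> signal = signal.istft()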
+ """
+ if self.stft_data is None:
+ raise RuntimeError("Cannot do inverse STFT without self.stft_data!")
+
+ window_length = (
+ self.stft_params.window_length
+ if window_length is None
+ else int(window_length)
+ )
+ hop_length = (
+ self.stft_params.hop_length if hop_length is None else int(hop_length)
+ )
+ window_type = (
+ self.stft_params.window_type if window_type is None else window_type
+ )
+ match_stride = (
+ self.stft_params.match_stride if match_stride is None else match_stride
+ )
+
+ window = self.get_window(window_type, window_length, self.stft_data.device)
+
+ nb, nch, nf, nt = self.stft_data.shape
+ stft_data = self.stft_data.reshape(nb * nch, nf, nt)
+ right_pad, pad = self.compute_stft_padding(
+ window_length, hop_length, match_stride
+ )
+
+ if length is None:
+ length = self.original_signal_length
+ length = length + 2 * pad + right_pad
+
+ if match_stride:
+ # Zero-pad the STFT on either side, putting back the frames that were
+ # dropped in stft().
+ stft_data = torch.nn.functional.pad(stft_data, (2, 2))
+
+ audio_data = torch.istft(
+ stft_data,
+ n_fft=window_length,
+ hop_length=hop_length,
+ window=window,
+ length=length,
+ center=True,
+ )
+ audio_data = audio_data.reshape(nb, nch, -1)
+ if match_stride:
+ audio_data = audio_data[..., pad : -(pad + right_pad)]
+ self.audio_data = audio_data
+
+ return self
+
+ @staticmethod
+ @functools.lru_cache(None)
+ def get_mel_filters(
+ sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
+ ):
+ """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.
+
+ Parameters
+ ----------
+ sr : int
+ Sample rate of audio
+ n_fft : int
+ Number of FFT bins
+ n_mels : int
+ Number of mels
+ fmin : float, optional
+ Lowest frequency, in Hz, by default 0.0
+ fmax : float, optional
+ Highest frequency, by default None
+
+ Returns
+ -------
+ np.ndarray [shape=(n_mels, 1 + n_fft/2)]
+ Mel transform matrix
+ """
+ from librosa.filters import mel as librosa_mel_fn
+
+ return librosa_mel_fn(
+ sr=sr,
+ n_fft=n_fft,
+ n_mels=n_mels,
+ fmin=fmin,
+ fmax=fmax,
+ )
+
+ def mel_spectrogram(
+ self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
+ ):
+ """Computes a Mel spectrogram.
+
+ Parameters
+ ----------
+ n_mels : int, optional
+ Number of mels, by default 80
+ mel_fmin : float, optional
+ Lowest frequency, in Hz, by default 0.0
+ mel_fmax : float, optional
+ Highest frequency, by default None
+ kwargs : dict, optional
+ Keyword arguments to self.stft().
+
+ Returns
+ -------
+ torch.Tensor [shape=(batch, channels, mels, time)]
+ Mel spectrogram.
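+
+        Examples
+        --------
+        A minimal sketch, assuming the default STFT parameters:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> mels = signal.mel_spectrogram(n_mels=80)
+        >>> mels.shape[2]
+        80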
+ """
+ stft = self.stft(**kwargs)
+ magnitude = torch.abs(stft)
+
+ nf = magnitude.shape[2]
+ mel_basis = self.get_mel_filters(
+ sr=self.sample_rate,
+ n_fft=2 * (nf - 1),
+ n_mels=n_mels,
+ fmin=mel_fmin,
+ fmax=mel_fmax,
+ )
+ mel_basis = torch.from_numpy(mel_basis).to(self.device)
+
+ mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
+ mel_spectrogram = mel_spectrogram.transpose(-1, 2)
+ return mel_spectrogram
+
+ @staticmethod
+ @functools.lru_cache(None)
+ def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
+ """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``),
+ it can be normalized depending on norm. For more information about dct:
+ http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
+
+ Parameters
+ ----------
+ n_mfcc : int
+ Number of mfccs
+ n_mels : int
+ Number of mels
+ norm : str
+ Use "ortho" to get a orthogonal matrix or None, by default "ortho"
+ device : str, optional
+ Device to load the transformation matrix on, by default None
+
+ Returns
+ -------
+        torch.Tensor [shape=(n_mels, n_mfcc)]
+ The dct transformation matrix.
+ """
+ from torchaudio.functional import create_dct
+
+ return create_dct(n_mfcc, n_mels, norm).to(device)
+
+ def mfcc(
+ self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
+ ):
+ """Computes mel-frequency cepstral coefficients (MFCCs).
+
+ Parameters
+ ----------
+ n_mfcc : int, optional
+            Number of MFCCs to compute, by default 40
+ n_mels : int, optional
+ Number of mels, by default 80
+ log_offset: float, optional
+ Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
+ kwargs : dict, optional
+            Keyword arguments to self.mel_spectrogram(); note that some of them are passed on to self.stft()
+
+ Returns
+ -------
+ torch.Tensor [shape=(batch, channels, mfccs, time)]
+ MFCCs.
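+
+        Examples
+        --------
+        A minimal sketch, assuming the default STFT parameters:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> mfccs = signal.mfcc(n_mfcc=40)
+        >>> mfccs.shape[2]
+        40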
+ """
+
+ mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
+ mel_spectrogram = torch.log(mel_spectrogram + log_offset)
+ dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)
+
+ mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat
+ mfcc = mfcc.transpose(-1, -2)
+ return mfcc
+
+ @property
+ def magnitude(self):
+ """Computes and returns the absolute value of the STFT, which
+ is the magnitude. This value can also be set to some tensor.
+ When set, ``self.stft_data`` is manipulated so that its magnitude
+ matches what this is set to, and modulated by the phase.
+
+ Returns
+ -------
+ torch.Tensor
+ Magnitude of STFT.
+
+ Examples
+ --------
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
+ >>> magnitude = signal.magnitude # Computes stft if not computed
+ >>> magnitude[magnitude < magnitude.mean()] = 0
+ >>> signal.magnitude = magnitude
+ >>> signal.istft()
+ """
+ if self.stft_data is None:
+ self.stft()
+ return torch.abs(self.stft_data)
+
+ @magnitude.setter
+ def magnitude(self, value):
+ self.stft_data = value * torch.exp(1j * self.phase)
+ return
+
+ def log_magnitude(
+ self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
+ ):
+ """Computes the log-magnitude of the spectrogram.
+
+ Parameters
+ ----------
+ ref_value : float, optional
+ The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
+ Zeros in the output correspond to positions where ``S == ref``,
+ by default 1.0
+ amin : float, optional
+ Minimum threshold for ``S`` and ``ref``, by default 1e-5
+ top_db : float, optional
+ Threshold the output at ``top_db`` below the peak:
+            ``max(10 * log10(S/ref)) - top_db``, by default 80.0
+
+ Returns
+ -------
+ torch.Tensor
+ Log-magnitude spectrogram
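+
+        Examples
+        --------
+        A minimal sketch computing the log-magnitude spectrogram of a
+        random signal:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> log_mag = signal.log_magnitude()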
+ """
+ magnitude = self.magnitude
+
+ amin = amin**2
+ log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin))
+ log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
+
+ if top_db is not None:
+ log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
+ return log_spec
+
+ @property
+ def phase(self):
+ """Computes and returns the phase of the STFT.
+ This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its phase
+        matches what this is set to, combined with the original magnitude.
+
+ Returns
+ -------
+ torch.Tensor
+ Phase of STFT.
+
+ Examples
+ --------
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
+ >>> phase = signal.phase # Computes stft if not computed
+ >>> phase[phase < phase.mean()] = 0
+ >>> signal.phase = phase
+ >>> signal.istft()
+ """
+ if self.stft_data is None:
+ self.stft()
+ return torch.angle(self.stft_data)
+
+ @phase.setter
+ def phase(self, value):
+ self.stft_data = self.magnitude * torch.exp(1j * value)
+ return
+
+ # Operator overloading
+ def __add__(self, other):
+ new_signal = self.clone()
+ new_signal.audio_data += util._get_value(other)
+ return new_signal
+
+ def __iadd__(self, other):
+ self.audio_data += util._get_value(other)
+ return self
+
+ def __radd__(self, other):
+ return self + other
+
+ def __sub__(self, other):
+ new_signal = self.clone()
+ new_signal.audio_data -= util._get_value(other)
+ return new_signal
+
+ def __isub__(self, other):
+ self.audio_data -= util._get_value(other)
+ return self
+
+ def __mul__(self, other):
+ new_signal = self.clone()
+ new_signal.audio_data *= util._get_value(other)
+ return new_signal
+
+ def __imul__(self, other):
+ self.audio_data *= util._get_value(other)
+ return self
+
+ def __rmul__(self, other):
+ return self * other
+
+ # Representation
+ def _info(self):
+ dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
+ info = {
+ "duration": f"{dur} seconds",
+ "batch_size": self.batch_size,
+ "path": self.path_to_file if self.path_to_file else "path unknown",
+ "sample_rate": self.sample_rate,
+ "num_channels": self.num_channels if self.num_channels else "[unknown]",
+ "audio_data.shape": self.audio_data.shape,
+ "stft_params": self.stft_params,
+ "device": self.device,
+ }
+
+ return info
+
+ def markdown(self):
+ """Produces a markdown representation of AudioSignal, in a markdown table.
+
+ Returns
+ -------
+ str
+ Markdown representation of AudioSignal.
+
+ Examples
+ --------
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
+ >>> print(signal.markdown())
+ | Key | Value
+ |---|---
+ | duration | 1.000 seconds |
+ | batch_size | 1 |
+ | path | path unknown |
+ | sample_rate | 44100 |
+ | num_channels | 1 |
+ | audio_data.shape | torch.Size([1, 1, 44100]) |
+        | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False, padding_type='reflect') |
+ | device | cpu |
+ """
+ info = self._info()
+
+ FORMAT = "| Key | Value \n" "|---|--- \n"
+ for k, v in info.items():
+ row = f"| {k} | {v} |\n"
+ FORMAT += row
+ return FORMAT
+
+ def __str__(self):
+ info = self._info()
+
+ desc = ""
+ for k, v in info.items():
+ desc += f"{k}: {v}\n"
+ return desc
+
+ def __rich__(self):
+ from rich.table import Table
+
+ info = self._info()
+
+ table = Table(title=f"{self.__class__.__name__}")
+ table.add_column("Key", style="green")
+ table.add_column("Value", style="cyan")
+
+ for k, v in info.items():
+ table.add_row(k, str(v))
+ return table
+
+ # Comparison
+ def __eq__(self, other):
+ for k, v in list(self.__dict__.items()):
+ if torch.is_tensor(v):
+ if not torch.allclose(v, other.__dict__[k], atol=1e-6):
+ max_error = (v - other.__dict__[k]).abs().max()
+ print(f"Max abs error for {k}: {max_error}")
+ return False
+ return True
+
+ # Indexing
+ def __getitem__(self, key):
+ if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
+ assert self.batch_size == 1
+ audio_data = self.audio_data
+ _loudness = self._loudness
+ stft_data = self.stft_data
+
+ elif isinstance(key, (bool, int, list, slice, tuple)) or (
+ torch.is_tensor(key) and key.ndim <= 1
+ ):
+ # Indexing only on the batch dimension.
+ # Then let's copy over relevant stuff.
+ # Future work: make this work for time-indexing
+ # as well, using the hop length.
+ audio_data = self.audio_data[key]
+ _loudness = self._loudness[key] if self._loudness is not None else None
+ stft_data = self.stft_data[key] if self.stft_data is not None else None
+
+ sources = None
+
+ copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
+ copy._loudness = _loudness
+ copy._stft_data = stft_data
+ copy.sources = sources
+
+ return copy
+
+ def __setitem__(self, key, value):
+ if not isinstance(value, type(self)):
+ self.audio_data[key] = value
+ return
+
+ if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
+ assert self.batch_size == 1
+ self.audio_data = value.audio_data
+ self._loudness = value._loudness
+ self.stft_data = value.stft_data
+ return
+
+ elif isinstance(key, (bool, int, list, slice, tuple)) or (
+ torch.is_tensor(key) and key.ndim <= 1
+ ):
+ if self.audio_data is not None and value.audio_data is not None:
+ self.audio_data[key] = value.audio_data
+ if self._loudness is not None and value._loudness is not None:
+ self._loudness[key] = value._loudness
+ if self.stft_data is not None and value.stft_data is not None:
+ self.stft_data[key] = value.stft_data
+ return
+
+ def __ne__(self, other):
+ return not self == other
diff --git a/audiotools/core/display.py b/audiotools/core/display.py
new file mode 100644
index 0000000000000000000000000000000000000000..66cbcf34cb2cf9fdf8d67ec4418a887eba73f184
--- /dev/null
+++ b/audiotools/core/display.py
@@ -0,0 +1,194 @@
+import inspect
+import typing
+from functools import wraps
+
+from . import util
+
+
+def format_figure(func):
+ """Decorator for formatting figures produced by the code below.
+ See :py:func:`audiotools.core.util.format_figure` for more.
+
+ Parameters
+ ----------
+ func : Callable
+ Plotting function that is decorated by this function.
+
+ """
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ f_keys = inspect.signature(util.format_figure).parameters.keys()
+ f_kwargs = {}
+ for k, v in list(kwargs.items()):
+ if k in f_keys:
+ kwargs.pop(k)
+ f_kwargs[k] = v
+ func(*args, **kwargs)
+ util.format_figure(**f_kwargs)
+
+ return wrapper
+
+
+class DisplayMixin:
+ @format_figure
+ def specshow(
+ self,
+ preemphasis: bool = False,
+ x_axis: str = "time",
+ y_axis: str = "linear",
+ n_mels: int = 128,
+ **kwargs,
+ ):
+ """Displays a spectrogram, using ``librosa.display.specshow``.
+
+ Parameters
+ ----------
+ preemphasis : bool, optional
+ Whether or not to apply preemphasis, which makes high
+ frequency detail easier to see, by default False
+ x_axis : str, optional
+ How to label the x axis, by default "time"
+ y_axis : str, optional
+ How to label the y axis, by default "linear"
+ n_mels : int, optional
+ If displaying a mel spectrogram with ``y_axis = "mel"``,
+ this controls the number of mels, by default 128.
+ kwargs : dict, optional
+ Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
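+
+        Examples
+        --------
+        A minimal sketch, assuming ``signal`` is an AudioSignal and
+        matplotlib is available:
+
+        >>> import matplotlib.pyplot as plt
+        >>> signal.specshow(y_axis="mel", n_mels=128)
+        >>> plt.savefig("spectrogram.png")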
+ """
+ import librosa
+ import librosa.display
+
+ # Always re-compute the STFT data before showing it, in case
+ # it changed.
+ signal = self.clone()
+ signal.stft_data = None
+
+ if preemphasis:
+ signal.preemphasis()
+
+ ref = signal.magnitude.max()
+ log_mag = signal.log_magnitude(ref_value=ref)
+
+ if y_axis == "mel":
+ log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
+ log_mag -= log_mag.max()
+
+ librosa.display.specshow(
+ log_mag.numpy()[0].mean(axis=0),
+ x_axis=x_axis,
+ y_axis=y_axis,
+ sr=signal.sample_rate,
+ **kwargs,
+ )
+
+ @format_figure
+ def waveplot(self, x_axis: str = "time", **kwargs):
+ """Displays a waveform plot, using ``librosa.display.waveshow``.
+
+ Parameters
+ ----------
+ x_axis : str, optional
+ How to label the x axis, by default "time"
+ kwargs : dict, optional
+ Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
+ """
+ import librosa
+ import librosa.display
+
+ audio_data = self.audio_data[0].mean(dim=0)
+ audio_data = audio_data.cpu().numpy()
+
+ plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
+ wave_plot_fn = getattr(librosa.display, plot_fn)
+ wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)
+
+ @format_figure
+ def wavespec(self, x_axis: str = "time", **kwargs):
+ """Displays a waveform plot, using ``librosa.display.waveshow``.
+
+ Parameters
+ ----------
+ x_axis : str, optional
+ How to label the x axis, by default "time"
+ kwargs : dict, optional
+ Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
+ """
+ import matplotlib.pyplot as plt
+ from matplotlib.gridspec import GridSpec
+
+ gs = GridSpec(6, 1)
+ plt.subplot(gs[0, :])
+ self.waveplot(x_axis=x_axis)
+ plt.subplot(gs[1:, :])
+ self.specshow(x_axis=x_axis, **kwargs)
+
+ def write_audio_to_tb(
+ self,
+ tag: str,
+ writer,
+ step: int = None,
+ plot_fn: typing.Union[typing.Callable, str] = "specshow",
+ **kwargs,
+ ):
+ """Writes a signal and its spectrogram to Tensorboard. Will show up
+ under the Audio and Images tab in Tensorboard.
+
+ Parameters
+ ----------
+ tag : str
+ Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
+ written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
+ writer : SummaryWriter
+ A SummaryWriter object from PyTorch library.
+ step : int, optional
+ The step to write the signal to, by default None
+ plot_fn : typing.Union[typing.Callable, str], optional
+ How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
+ kwargs : dict, optional
+ Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
+ whatever ``plot_fn`` is set to.
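+
+        Examples
+        --------
+        A minimal sketch, assuming ``signal`` is an AudioSignal and a
+        ``SummaryWriter`` logging to a hypothetical ``logs/`` directory:
+
+        >>> from torch.utils.tensorboard import SummaryWriter
+        >>> writer = SummaryWriter("logs")
+        >>> signal.write_audio_to_tb("clean/sample_0.wav", writer, step=0)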
+ """
+ import matplotlib.pyplot as plt
+
+ audio_data = self.audio_data[0, 0].detach().cpu()
+ sample_rate = self.sample_rate
+ writer.add_audio(tag, audio_data, step, sample_rate)
+
+ if plot_fn is not None:
+ if isinstance(plot_fn, str):
+ plot_fn = getattr(self, plot_fn)
+ fig = plt.figure()
+ plt.clf()
+ plot_fn(**kwargs)
+ writer.add_figure(tag.replace("wav", "png"), fig, step)
+
+ def save_image(
+ self,
+ image_path: str,
+ plot_fn: typing.Union[typing.Callable, str] = "specshow",
+ **kwargs,
+ ):
+ """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
+ a specified file.
+
+ Parameters
+ ----------
+ image_path : str
+ Where to save the file to.
+ plot_fn : typing.Union[typing.Callable, str], optional
+ How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
+ kwargs : dict, optional
+ Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
+ whatever ``plot_fn`` is set to.
+ """
+ import matplotlib.pyplot as plt
+
+ if isinstance(plot_fn, str):
+ plot_fn = getattr(self, plot_fn)
+
+ plt.clf()
+ plot_fn(**kwargs)
+ plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
+ plt.close()
diff --git a/audiotools/core/dsp.py b/audiotools/core/dsp.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9be51a119537b77e497ddc2dac126d569533d7c
--- /dev/null
+++ b/audiotools/core/dsp.py
@@ -0,0 +1,390 @@
+import typing
+
+import julius
+import numpy as np
+import torch
+
+from . import util
+
+
+class DSPMixin:
+ _original_batch_size = None
+ _original_num_channels = None
+ _padded_signal_length = None
+
+ def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
+ self._original_batch_size = self.batch_size
+ self._original_num_channels = self.num_channels
+
+ window_length = int(window_duration * self.sample_rate)
+ hop_length = int(hop_duration * self.sample_rate)
+
+ if window_length % hop_length != 0:
+ factor = window_length // hop_length
+ window_length = factor * hop_length
+
+ self.zero_pad(hop_length, hop_length)
+ self._padded_signal_length = self.signal_length
+
+ return window_length, hop_length
+
+ def windows(
+ self, window_duration: float, hop_duration: float, preprocess: bool = True
+ ):
+ """Generator which yields windows of specified duration from signal with a specified
+ hop length.
+
+ Parameters
+ ----------
+ window_duration : float
+ Duration of every window in seconds.
+ hop_duration : float
+ Hop between windows in seconds.
+ preprocess : bool, optional
+ Whether to preprocess the signal, so that the first sample is in
+ the middle of the first window, by default True
+
+ Yields
+ ------
+ AudioSignal
+ Each window is returned as an AudioSignal.
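+
+        Examples
+        --------
+        A minimal sketch iterating over 1-second windows with a 0.5-second hop:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> windows = list(signal.windows(1.0, 0.5))
+        >>> len(windows)
+        3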
+ """
+ if preprocess:
+ window_length, hop_length = self._preprocess_signal_for_windowing(
+ window_duration, hop_duration
+ )
+
+ self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)
+
+ for b in range(self.batch_size):
+            i = 0
+            while True:
+                start_idx = i * hop_length
+ i += 1
+ end_idx = start_idx + window_length
+ if end_idx > self.signal_length:
+ break
+ yield self[b, ..., start_idx:end_idx]
+
+ def collect_windows(
+ self, window_duration: float, hop_duration: float, preprocess: bool = True
+ ):
+ """Reshapes signal into windows of specified duration from signal with a specified
+        hop length. Windows are placed along the batch dimension. Use with
+ :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
+ original signal.
+
+ Parameters
+ ----------
+ window_duration : float
+ Duration of every window in seconds.
+ hop_duration : float
+ Hop between windows in seconds.
+ preprocess : bool, optional
+ Whether to preprocess the signal, so that the first sample is in
+ the middle of the first window, by default True
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
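+
+        Examples
+        --------
+        A minimal round-trip sketch: window the signal, then reconstruct it
+        with :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add`:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal = signal.collect_windows(1.0, 0.5)
+        >>> signal = signal.overlap_and_add(0.5)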
+ """
+ if preprocess:
+ window_length, hop_length = self._preprocess_signal_for_windowing(
+ window_duration, hop_duration
+ )
+
+ # self.audio_data: (nb, nch, nt).
+ unfolded = torch.nn.functional.unfold(
+ self.audio_data.reshape(-1, 1, 1, self.signal_length),
+ kernel_size=(1, window_length),
+ stride=(1, hop_length),
+ )
+ # unfolded: (nb * nch, window_length, num_windows).
+ # -> (nb * nch * num_windows, 1, window_length)
+ unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
+ self.audio_data = unfolded
+ return self
+
+ def overlap_and_add(self, hop_duration: float):
+ """Function which takes a list of windows and overlap adds them into a
+ signal the same length as ``audio_signal``.
+
+ Parameters
+ ----------
+ hop_duration : float
+ How much to shift for each window
+ (overlap is window_duration - hop_duration) in seconds.
+
+ Returns
+ -------
+ AudioSignal
+ overlap-and-added signal.
+ """
+ hop_length = int(hop_duration * self.sample_rate)
+ window_length = self.signal_length
+
+ nb, nch = self._original_batch_size, self._original_num_channels
+
+ unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
+ folded = torch.nn.functional.fold(
+ unfolded,
+ output_size=(1, self._padded_signal_length),
+ kernel_size=(1, window_length),
+ stride=(1, hop_length),
+ )
+
+ norm = torch.ones_like(unfolded, device=unfolded.device)
+ norm = torch.nn.functional.fold(
+ norm,
+ output_size=(1, self._padded_signal_length),
+ kernel_size=(1, window_length),
+ stride=(1, hop_length),
+ )
+
+ folded = folded / norm
+
+ folded = folded.reshape(nb, nch, -1)
+ self.audio_data = folded
+ self.trim(hop_length, hop_length)
+ return self
+
+ def low_pass(
+ self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
+ ):
+ """Low-passes the signal in-place. Each item in the batch
+ can have a different low-pass cutoff, if the input
+ to this signal is an array or tensor. If a float, all
+ items are given the same low-pass filter.
+
+ Parameters
+ ----------
+ cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
+ Cutoff in Hz of low-pass filter.
+ zeros : int, optional
+ Number of taps to use in low-pass filter, by default 51
+
+ Returns
+ -------
+ AudioSignal
+ Low-passed AudioSignal.
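+
+        Examples
+        --------
+        A minimal sketch applying a 4 kHz low-pass to a random signal:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal = signal.low_pass(4000)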
+ """
+ cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
+ cutoffs = cutoffs / self.sample_rate
+ filtered = torch.empty_like(self.audio_data)
+
+ for i, cutoff in enumerate(cutoffs):
+ lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
+ filtered[i] = lp_filter(self.audio_data[i])
+
+ self.audio_data = filtered
+ self.stft_data = None
+ return self
+
+ def high_pass(
+ self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
+ ):
+ """High-passes the signal in-place. Each item in the batch
+ can have a different high-pass cutoff, if the input
+ to this signal is an array or tensor. If a float, all
+ items are given the same high-pass filter.
+
+ Parameters
+ ----------
+ cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
+ Cutoff in Hz of high-pass filter.
+ zeros : int, optional
+ Number of taps to use in high-pass filter, by default 51
+
+ Returns
+ -------
+ AudioSignal
+ High-passed AudioSignal.
+ """
+ cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
+ cutoffs = cutoffs / self.sample_rate
+ filtered = torch.empty_like(self.audio_data)
+
+ for i, cutoff in enumerate(cutoffs):
+ hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
+ filtered[i] = hp_filter(self.audio_data[i])
+
+ self.audio_data = filtered
+ self.stft_data = None
+ return self
+
+ def mask_frequencies(
+ self,
+ fmin_hz: typing.Union[torch.Tensor, np.ndarray, float],
+ fmax_hz: typing.Union[torch.Tensor, np.ndarray, float],
+ val: float = 0.0,
+ ):
+ """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
+ with the value specified by ``val``. Useful for implementing SpecAug.
+ The min and max can be different for every item in the batch.
+
+ Parameters
+ ----------
+ fmin_hz : typing.Union[torch.Tensor, np.ndarray, float]
+ Lower end of band to mask out.
+ fmax_hz : typing.Union[torch.Tensor, np.ndarray, float]
+ Upper end of band to mask out.
+ val : float, optional
+ Value to fill in, by default 0.0
+
+ Returns
+ -------
+ AudioSignal
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+ masked audio data.
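+
+ Examples
+ --------
+ A SpecAug-style band mask; the path is a placeholder, and the STFT is
+ taken from the signal's ``magnitude`` and ``phase``.
+
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal("speech.wav")
+ >>> signal = signal.mask_frequencies(fmin_hz=500, fmax_hz=2000)
+ >>> signal = signal.istft()  # back to masked audio data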
+ """
+ # SpecAug
+ mag, phase = self.magnitude, self.phase
+ fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim)
+ fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim)
+ assert torch.all(fmin_hz < fmax_hz)
+
+ # build mask
+ nbins = mag.shape[-2]
+ bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device)
+ bins_hz = bins_hz[None, None, :, None].repeat(
+ self.batch_size, 1, 1, mag.shape[-1]
+ )
+ mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
+ mask = mask.to(self.device)
+
+ mag = mag.masked_fill(mask, val)
+ phase = phase.masked_fill(mask, val)
+ self.stft_data = mag * torch.exp(1j * phase)
+ return self
+
+ def mask_timesteps(
+ self,
+ tmin_s: typing.Union[torch.Tensor, np.ndarray, float],
+ tmax_s: typing.Union[torch.Tensor, np.ndarray, float],
+ val: float = 0.0,
+ ):
+ """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
+ with the value specified by ``val``. Useful for implementing SpecAug.
+ The min and max can be different for every item in the batch.
+
+ Parameters
+ ----------
+ tmin_s : typing.Union[torch.Tensor, np.ndarray, float]
+ Lower end of timesteps to mask out.
+ tmax_s : typing.Union[torch.Tensor, np.ndarray, float]
+ Upper end of timesteps to mask out.
+ val : float, optional
+ Value to fill in, by default 0.0
+
+ Returns
+ -------
+ AudioSignal
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+ masked audio data.
+ """
+ # SpecAug
+ mag, phase = self.magnitude, self.phase
+ tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
+ tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
+
+ assert torch.all(tmin_s < tmax_s)
+
+ # build mask
+ nt = mag.shape[-1]
+ bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device)
+ bins_t = bins_t[None, None, None, :].repeat(
+ self.batch_size, 1, mag.shape[-2], 1
+ )
+ mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
+
+ mag = mag.masked_fill(mask, val)
+ phase = phase.masked_fill(mask, val)
+ self.stft_data = mag * torch.exp(1j * phase)
+ return self
+
+ def mask_low_magnitudes(
+ self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0
+ ):
+ """Mask away magnitudes below a specified threshold, which
+ can be different for every item in the batch.
+
+ Parameters
+ ----------
+ db_cutoff : typing.Union[torch.Tensor, np.ndarray, float]
+ Decibel threshold; magnitude bins below it are masked away.
+ val : float, optional
+ Value to fill in for masked portions, by default 0.0
+
+ Returns
+ -------
+ AudioSignal
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+ masked audio data.
+ """
+ mag = self.magnitude
+ log_mag = self.log_magnitude()
+
+ db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
+ mask = log_mag < db_cutoff
+ mag = mag.masked_fill(mask, val)
+
+ self.magnitude = mag
+ return self
+
+ def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]):
+ """Shifts the phase by a constant value.
+
+ Parameters
+ ----------
+ shift : typing.Union[torch.Tensor, np.ndarray, float]
+ What to shift the phase by.
+
+ Returns
+ -------
+ AudioSignal
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+ phase-shifted audio data.
+ """
+ shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
+ self.phase = self.phase + shift
+ return self
+
+ def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]):
+ """Corrupts the phase randomly by some scaled value.
+
+ Parameters
+ ----------
+ scale : typing.Union[torch.Tensor, np.ndarray, float]
+ Standard deviation of noise to add to the phase.
+
+ Returns
+ -------
+ AudioSignal
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
+ phase-corrupted audio data.
+ """
+ scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
+ self.phase = self.phase + scale * torch.randn_like(self.phase)
+ return self
+
+ def preemphasis(self, coef: float = 0.85):
+ """Applies pre-emphasis to audio signal.
+
+ Parameters
+ ----------
+ coef : float, optional
+ How much pre-emphasis to apply; lower values do less, and 0 does
+ nothing. By default 0.85.
+
+ Returns
+ -------
+ AudioSignal
+ Pre-emphasized signal.
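+
+ Examples
+ --------
+ A minimal sketch; the path is a placeholder. Pre-emphasis boosts high
+ frequencies relative to low ones before further processing.
+
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal("speech.wav")
+ >>> signal = signal.preemphasis(coef=0.85)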
+ """
+ kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device)
+ x = self.audio_data.reshape(-1, 1, self.signal_length)
+ x = torch.nn.functional.conv1d(x, kernel, padding=1)
+ self.audio_data = x.reshape(*self.audio_data.shape)
+ return self
diff --git a/audiotools/core/effects.py b/audiotools/core/effects.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb534cbcb2d457575de685fc9248d1716879145b
--- /dev/null
+++ b/audiotools/core/effects.py
@@ -0,0 +1,647 @@
+import typing
+
+import julius
+import numpy as np
+import torch
+import torchaudio
+
+from . import util
+
+
+class EffectMixin:
+ GAIN_FACTOR = np.log(10) / 20
+ """Gain factor for converting between amplitude and decibels."""
+ CODEC_PRESETS = {
+ "8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8},
+ "GSM-FR": {"format": "gsm"},
+ "MP3": {"format": "mp3", "compression": -9},
+ "Vorbis": {"format": "vorbis", "compression": -1},
+ "Ogg": {
+ "format": "ogg",
+ "compression": -1,
+ },
+ "Amr-nb": {"format": "amr-nb"},
+ }
+ """Presets for applying codecs via torchaudio."""
+
+ def mix(
+ self,
+ other,
+ snr: typing.Union[torch.Tensor, np.ndarray, float] = 10,
+ other_eq: typing.Union[torch.Tensor, np.ndarray] = None,
+ ):
+ """Mixes noise with signal at specified
+ signal-to-noise ratio. Optionally, the
+ other signal can be equalized in-place before mixing.
+
+ Parameters
+ ----------
+ other : AudioSignal
+ AudioSignal object to mix with.
+ snr : typing.Union[torch.Tensor, np.ndarray, float], optional
+ Signal to noise ratio, by default 10
+ other_eq : typing.Union[torch.Tensor, np.ndarray], optional
+ EQ curve to apply to other signal, if any, by default None
+
+ Returns
+ -------
+ AudioSignal
+ In-place modification of AudioSignal.
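+
+ Examples
+ --------
+ Mixing speech with noise at 10 dB SNR; both paths are placeholders.
+
+ >>> from audiotools import AudioSignal
+ >>> speech = AudioSignal("speech.wav")
+ >>> noise = AudioSignal("noise.wav")
+ >>> speech = speech.mix(noise, snr=10)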
+ """
+ snr = util.ensure_tensor(snr).to(self.device)
+
+ pad_len = max(0, self.signal_length - other.signal_length)
+ other.zero_pad(0, pad_len)
+ other.truncate_samples(self.signal_length)
+ if other_eq is not None:
+ other = other.equalizer(other_eq)
+
+ tgt_loudness = self.loudness() - snr
+ other = other.normalize(tgt_loudness)
+
+ self.audio_data = self.audio_data + other.audio_data
+ return self
+
+ def convolve(self, other, start_at_max: bool = True):
+ """Convolves self with other.
+ This function uses FFTs to do the convolution.
+
+ Parameters
+ ----------
+ other : AudioSignal
+ Signal to convolve with.
+ start_at_max : bool, optional
+ Whether to start at the max value of other signal, to
+ avoid inducing delays, by default True
+
+ Returns
+ -------
+ AudioSignal
+ Convolved signal, in-place.
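+
+ Examples
+ --------
+ Convolving speech with a room impulse response; both paths are
+ placeholders. The ``@`` operator is shorthand for this method.
+
+ >>> from audiotools import AudioSignal
+ >>> speech = AudioSignal("speech.wav")
+ >>> ir = AudioSignal("room_ir.wav")
+ >>> wet = speech @ ir  # equivalent to speech.convolve(ir)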
+ """
+ from . import AudioSignal
+
+ pad_len = self.signal_length - other.signal_length
+
+ if pad_len > 0:
+ other.zero_pad(0, pad_len)
+ else:
+ other.truncate_samples(self.signal_length)
+
+ if start_at_max:
+ # Use roll to rotate over the max for every item
+ # so that the impulse responses don't induce any
+ # delay.
+ idx = other.audio_data.abs().argmax(axis=-1)
+ irs = torch.zeros_like(other.audio_data)
+ for i in range(other.batch_size):
+ irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1)
+ other = AudioSignal(irs, other.sample_rate)
+
+ delta = torch.zeros_like(other.audio_data)
+ delta[..., 0] = 1
+
+ length = self.signal_length
+ delta_fft = torch.fft.rfft(delta, length)
+ other_fft = torch.fft.rfft(other.audio_data, length)
+ self_fft = torch.fft.rfft(self.audio_data, length)
+
+ convolved_fft = other_fft * self_fft
+ convolved_audio = torch.fft.irfft(convolved_fft, length)
+
+ delta_convolved_fft = other_fft * delta_fft
+ delta_audio = torch.fft.irfft(delta_convolved_fft, length)
+
+ # Use the delta to rescale the audio exactly as needed.
+ delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0]
+ scale = 1 / delta_max.clamp(1e-5)
+ convolved_audio = convolved_audio * scale
+
+ self.audio_data = convolved_audio
+
+ return self
+
+ def apply_ir(
+ self,
+ ir,
+ drr: typing.Union[torch.Tensor, np.ndarray, float] = None,
+ ir_eq: typing.Union[torch.Tensor, np.ndarray] = None,
+ use_original_phase: bool = False,
+ ):
+ """Applies an impulse response to the signal. If ``ir_eq``
+ is specified, the impulse response is equalized before
+ it is applied, using the given curve.
+
+ Parameters
+ ----------
+ ir : AudioSignal
+ Impulse response to convolve with.
+ drr : typing.Union[torch.Tensor, np.ndarray, float], optional
+ Direct-to-reverberant ratio that impulse response will be
+ altered to, if specified, by default None
+ ir_eq : typing.Union[torch.Tensor, np.ndarray], optional
+ Equalization that will be applied to impulse response
+ if specified, by default None
+ use_original_phase : bool, optional
+ Whether to use the original phase, instead of the convolved
+ phase, by default False
+
+ Returns
+ -------
+ AudioSignal
+ Signal with impulse response applied to it
+ """
+ if ir_eq is not None:
+ ir = ir.equalizer(ir_eq)
+ if drr is not None:
+ ir = ir.alter_drr(drr)
+
+ # Save the peak before
+ max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values
+
+ # Augment the impulse response to simulate microphone effects
+ # and with varying direct-to-reverberant ratio.
+ phase = self.phase
+ self.convolve(ir)
+
+ # Use the input phase
+ if use_original_phase:
+ self.stft()
+ self.stft_data = self.magnitude * torch.exp(1j * phase)
+ self.istft()
+
+ # Rescale to the input's amplitude
+ max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values
+ scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8)
+ self = self * scale_factor
+
+ return self
+
+ def ensure_max_of_audio(self, max: float = 1.0):
+ """Ensures that ``abs(audio_data) <= max``.
+
+ Parameters
+ ----------
+ max : float, optional
+ Max absolute value of signal, by default 1.0
+
+ Returns
+ -------
+ AudioSignal
+ Signal with values scaled between -max and max.
+ """
+ peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0]
+ peak_gain = torch.ones_like(peak)
+ peak_gain[peak > max] = max / peak[peak > max]
+ self.audio_data = self.audio_data * peak_gain
+ return self
+
+ def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0):
+ """Normalizes the signal's volume to the specified db, in LUFS.
+ This is GPU-compatible, making for very fast loudness normalization.
+
+ Parameters
+ ----------
+ db : typing.Union[torch.Tensor, np.ndarray, float], optional
+ Loudness to normalize to, by default -24.0
+
+ Returns
+ -------
+ AudioSignal
+ Normalized audio signal.
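+
+ Examples
+ --------
+ A minimal sketch; the path is a placeholder.
+
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal("speech.wav")
+ >>> signal = signal.normalize(-24)  # now roughly -24 LUFS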
+ """
+ db = util.ensure_tensor(db).to(self.device)
+ ref_db = self.loudness()
+ gain = db - ref_db
+ gain = torch.exp(gain * self.GAIN_FACTOR)
+
+ self.audio_data = self.audio_data * gain[:, None, None]
+ return self
+
+ def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]):
+ """Change volume of signal by some amount, in dB.
+
+ Parameters
+ ----------
+ db : typing.Union[torch.Tensor, np.ndarray, float]
+ Amount to change volume by.
+
+ Returns
+ -------
+ AudioSignal
+ Signal at new volume.
+ """
+ db = util.ensure_tensor(db, ndim=1).to(self.device)
+ gain = torch.exp(db * self.GAIN_FACTOR)
+ self.audio_data = self.audio_data * gain[:, None, None]
+ return self
+
+ def _to_2d(self):
+ waveform = self.audio_data.reshape(-1, self.signal_length)
+ return waveform
+
+ def _to_3d(self, waveform):
+ return waveform.reshape(self.batch_size, self.num_channels, -1)
+
+ def pitch_shift(self, n_semitones: int, quick: bool = True):
+ """Pitch shift the signal. All items in the batch
+ get the same pitch shift.
+
+ Parameters
+ ----------
+ n_semitones : int
+ How many semitones to shift the signal by.
+ quick : bool, optional
+ Whether to use quick pitch shifting, by default True
+
+ Returns
+ -------
+ AudioSignal
+ Pitch shifted audio signal.
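+
+ Examples
+ --------
+ Shifting a signal up two semitones. Requires torchaudio's sox effects;
+ the path is a placeholder.
+
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal("speech.wav")
+ >>> signal = signal.pitch_shift(2)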
+ """
+ device = self.device
+ effects = [
+ ["pitch", str(n_semitones * 100)],
+ ["rate", str(self.sample_rate)],
+ ]
+ if quick:
+ effects[0].insert(1, "-q")
+
+ waveform = self._to_2d().cpu()
+ waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, self.sample_rate, effects, channels_first=True
+ )
+ self.sample_rate = sample_rate
+ self.audio_data = self._to_3d(waveform)
+ return self.to(device)
+
+ def time_stretch(self, factor: float, quick: bool = True):
+ """Time stretch the audio signal.
+
+ Parameters
+ ----------
+ factor : float
+ Factor by which to stretch the AudioSignal. Typically
+ between 0.8 and 1.2.
+ quick : bool, optional
+ Whether to use quick time stretching, by default True
+
+ Returns
+ -------
+ AudioSignal
+ Time-stretched AudioSignal.
+ """
+ device = self.device
+ effects = [
+ ["tempo", str(factor)],
+ ["rate", str(self.sample_rate)],
+ ]
+ if quick:
+ effects[0].insert(1, "-q")
+
+ waveform = self._to_2d().cpu()
+ waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, self.sample_rate, effects, channels_first=True
+ )
+ self.sample_rate = sample_rate
+ self.audio_data = self._to_3d(waveform)
+ return self.to(device)
+
+ def apply_codec(
+ self,
+ preset: str = None,
+ format: str = "wav",
+ encoding: str = None,
+ bits_per_sample: int = None,
+ compression: int = None,
+ ): # pragma: no cover
+ """Applies an audio codec to the signal.
+
+ Parameters
+ ----------
+ preset : str, optional
+ One of the keys in ``self.CODEC_PRESETS``, by default None
+ format : str, optional
+ Format for audio codec, by default "wav"
+ encoding : str, optional
+ Encoding to use, by default None
+ bits_per_sample : int, optional
+ How many bits per sample, by default None
+ compression : int, optional
+ Compression amount of codec, by default None
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with codec applied.
+
+ Raises
+ ------
+ ValueError
+ If preset is not in ``self.CODEC_PRESETS``, an error
+ is thrown.
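+
+ Examples
+ --------
+ Simulating a low-bitrate codec via a preset; the path is a placeholder.
+
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal("speech.wav")
+ >>> signal = signal.apply_codec("8-bit")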
+ """
+ torchaudio_version_070 = "0.7" in torchaudio.__version__
+ if torchaudio_version_070:
+ return self
+
+ kwargs = {
+ "format": format,
+ "encoding": encoding,
+ "bits_per_sample": bits_per_sample,
+ "compression": compression,
+ }
+
+ if preset is not None:
+ if preset in self.CODEC_PRESETS:
+ kwargs = self.CODEC_PRESETS[preset]
+ else:
+ raise ValueError(
+ f"Unknown preset: {preset}. "
+ f"Known presets: {list(self.CODEC_PRESETS.keys())}"
+ )
+
+ waveform = self._to_2d()
+ if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]:
+ # Apply it in a for loop
+ augmented = torch.cat(
+ [
+ torchaudio.functional.apply_codec(
+ waveform[i][None, :], self.sample_rate, **kwargs
+ )
+ for i in range(waveform.shape[0])
+ ],
+ dim=0,
+ )
+ else:
+ augmented = torchaudio.functional.apply_codec(
+ waveform, self.sample_rate, **kwargs
+ )
+ augmented = self._to_3d(augmented)
+
+ self.audio_data = augmented
+ return self
+
+ def mel_filterbank(self, n_bands: int):
+ """Breaks signal into mel bands.
+
+ Parameters
+ ----------
+ n_bands : int
+ Number of mel bands to use.
+
+ Returns
+ -------
+ torch.Tensor
+ Mel-filtered bands, with last axis being the band index.
+ """
+ filterbank = (
+ julius.SplitBands(self.sample_rate, n_bands).float().to(self.device)
+ )
+ filtered = filterbank(self.audio_data)
+ return filtered.permute(1, 2, 3, 0)
+
+ def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]):
+ """Applies a mel-spaced equalizer to the audio signal.
+
+ Parameters
+ ----------
+ db : typing.Union[torch.Tensor, np.ndarray]
+ EQ curve to apply.
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal with equalization applied.
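+
+ Examples
+ --------
+ Applying a random 6-band cut between -10 and 0 dB; the path is a
+ placeholder.
+
+ >>> import torch
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal("speech.wav")
+ >>> db = -10 * torch.rand(6)
+ >>> signal = signal.equalizer(db)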
+ """
+ db = util.ensure_tensor(db)
+ n_bands = db.shape[-1]
+ fbank = self.mel_filterbank(n_bands)
+
+ # If there's a batch dimension, make sure it's the same.
+ if db.ndim == 2:
+ if db.shape[0] != 1:
+ assert db.shape[0] == fbank.shape[0]
+ else:
+ db = db.unsqueeze(0)
+
+ weights = (10**db).to(self.device).float()
+ fbank = fbank * weights[:, None, None, :]
+ eq_audio_data = fbank.sum(-1)
+ self.audio_data = eq_audio_data
+ return self
+
+ def clip_distortion(
+ self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float]
+ ):
+ """Clips the signal at a given percentile. The higher it is,
+ the lower the threshold for clipping.
+
+ Parameters
+ ----------
+ clip_percentile : typing.Union[torch.Tensor, np.ndarray, float]
+ Values are between 0.0 and 1.0. Typical values are 0.1 or below.
+
+ Returns
+ -------
+ AudioSignal
+ Audio signal with clipped audio data.
+ """
+ clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
+ min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1)
+ max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1)
+
+ nc = self.audio_data.shape[1]
+ min_thresh = min_thresh[:, :nc, :]
+ max_thresh = max_thresh[:, :nc, :]
+
+ self.audio_data = self.audio_data.clamp(min_thresh, max_thresh)
+
+ return self
+
+ def quantization(
+ self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
+ ):
+ """Applies quantization to the input waveform.
+
+ Parameters
+ ----------
+ quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
+ Number of evenly spaced quantization channels to quantize
+ to.
+
+ Returns
+ -------
+ AudioSignal
+ Quantized AudioSignal.
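+
+ Examples
+ --------
+ Quantizing to 256 evenly spaced levels; the straight-through residual
+ keeps gradients flowing to ``audio_data``. The path is a placeholder.
+
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal("speech.wav")
+ >>> signal = signal.quantization(256)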
+ """
+ quantization_channels = util.ensure_tensor(quantization_channels, ndim=3)
+
+ x = self.audio_data
+ x = (x + 1) / 2
+ x = x * quantization_channels
+ x = x.floor()
+ x = x / quantization_channels
+ x = 2 * x - 1
+
+ residual = (self.audio_data - x).detach()
+ self.audio_data = self.audio_data - residual
+ return self
+
+ def mulaw_quantization(
+ self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
+ ):
+ """Applies mu-law quantization to the input waveform.
+
+ Parameters
+ ----------
+ quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
+ Number of mu-law spaced quantization channels to quantize
+ to.
+
+ Returns
+ -------
+ AudioSignal
+ Quantized AudioSignal.
+ """
+ mu = quantization_channels - 1.0
+ mu = util.ensure_tensor(mu, ndim=3)
+
+ x = self.audio_data
+
+ # quantize
+ x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
+ x = ((x + 1) / 2 * mu + 0.5).to(torch.int64)
+
+ # unquantize
+ x = (x / mu) * 2 - 1.0
+ x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
+
+ residual = (self.audio_data - x).detach()
+ self.audio_data = self.audio_data - residual
+ return self
+
+ def __matmul__(self, other):
+ return self.convolve(other)
+
+
+class ImpulseResponseMixin:
+ """These functions are generally only used with AudioSignals that are derived
+ from impulse responses, not other sources like music or speech. These methods
+ are used to replicate the data augmentation described in [1].
+
+ 1. Bryan, Nicholas J. "Impulse response data augmentation and deep
+ neural networks for blind room acoustic parameter estimation."
+ ICASSP 2020-2020 IEEE International Conference on Acoustics,
+ Speech and Signal Processing (ICASSP). IEEE, 2020.
+ """
+
+ def decompose_ir(self):
+ """Decomposes an impulse response into early and late
+ field responses.
+ """
+ # Equations 1 and 2
+ # -----------------
+ # Breaking up into early
+ # response + late field response.
+
+ td = torch.argmax(self.audio_data, dim=-1, keepdim=True)
+ t0 = int(self.sample_rate * 0.0025)
+
+ idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :]
+ idx = idx.expand(self.batch_size, -1, -1)
+ early_idx = (idx >= td - t0) * (idx <= td + t0)
+
+ early_response = torch.zeros_like(self.audio_data, device=self.device)
+ early_response[early_idx] = self.audio_data[early_idx]
+
+ late_idx = ~early_idx
+ late_field = torch.zeros_like(self.audio_data, device=self.device)
+ late_field[late_idx] = self.audio_data[late_idx]
+
+ # Equation 4
+ # ----------
+ # Decompose early response into windowed
+ # direct path and windowed residual.
+
+ window = torch.zeros_like(self.audio_data, device=self.device)
+ for idx in range(self.batch_size):
+ window_idx = early_idx[idx, 0].nonzero()
+ window[idx, ..., window_idx] = self.get_window(
+ "hann", window_idx.shape[-1], self.device
+ )
+ return early_response, late_field, window
+
+ def measure_drr(self):
+ """Measures the direct-to-reverberant ratio of the impulse
+ response.
+
+ Returns
+ -------
+ float
+ Direct-to-reverberant ratio
+ """
+ early_response, late_field, _ = self.decompose_ir()
+ num = (early_response**2).sum(dim=-1)
+ den = (late_field**2).sum(dim=-1)
+ drr = 10 * torch.log10(num / den)
+ return drr
+
+ @staticmethod
+ def solve_alpha(early_response, late_field, wd, target_drr):
+ """Used to solve for the alpha value, which is used
+ to alter the drr.
+ """
+ # Equation 5
+ # ----------
+ # Apply the good ol' quadratic formula.
+
+ wd_sq = wd**2
+ wd_sq_1 = (1 - wd) ** 2
+ e_sq = early_response**2
+ l_sq = late_field**2
+ a = (wd_sq * e_sq).sum(dim=-1)
+ b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1)
+ c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum(
+ dim=-1
+ )
+
+ expr = ((b**2) - 4 * a * c).sqrt()
+ alpha = torch.maximum(
+ (-b - expr) / (2 * a),
+ (-b + expr) / (2 * a),
+ )
+ return alpha
+
+ def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]):
+ """Alters the direct-to-reverberant ratio of the impulse response.
+
+ Parameters
+ ----------
+ drr : typing.Union[torch.Tensor, np.ndarray, float]
+ Direct-to-reverberant ratio that the impulse response will be
+ altered to.
+
+ Returns
+ -------
+ AudioSignal
+ Altered impulse response.
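+
+ Examples
+ --------
+ Measuring and then altering the DRR of an impulse response; the path
+ is a placeholder.
+
+ >>> from audiotools import AudioSignal
+ >>> ir = AudioSignal("room_ir.wav")
+ >>> original_drr = ir.measure_drr()
+ >>> ir = ir.alter_drr(20)  # target roughly 20 dB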
+ """
+ drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device)
+
+ early_response, late_field, window = self.decompose_ir()
+ alpha = self.solve_alpha(early_response, late_field, window, drr)
+ min_alpha = (
+ late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0]
+ )
+ alpha = torch.maximum(alpha, min_alpha)[..., None]
+
+ aug_ir_data = (
+ alpha * window * early_response
+ + ((1 - window) * early_response)
+ + late_field
+ )
+ self.audio_data = aug_ir_data
+ self.ensure_max_of_audio()
+ return self
diff --git a/audiotools/core/ffmpeg.py b/audiotools/core/ffmpeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..baf27ccca25ffbf9e915aa870ca8797c37187cdd
--- /dev/null
+++ b/audiotools/core/ffmpeg.py
@@ -0,0 +1,204 @@
+import json
+import shlex
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import Tuple
+
+import ffmpy
+import numpy as np
+import torch
+
+
+def r128stats(filepath: str, quiet: bool):
+ """Takes a path to an audio file, returns a dict with the loudness
+ stats computed by the ffmpeg ebur128 filter.
+
+ Parameters
+ ----------
+ filepath : str
+ Path to compute loudness stats on.
+ quiet : bool
+ Whether to show FFMPEG output during computation.
+
+ Returns
+ -------
+ dict
+ Dictionary containing loudness stats.
+ """
+ ffargs = [
+ "ffmpeg",
+ "-nostats",
+ "-i",
+ filepath,
+ "-filter_complex",
+ "ebur128",
+ "-f",
+ "null",
+ "-",
+ ]
+ if quiet:
+ ffargs += ["-hide_banner"]
+ proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True)
+ stats = proc.communicate()[1]
+ summary_index = stats.rfind("Summary:")
+
+ summary_list = stats[summary_index:].split()
+ i_lufs = float(summary_list[summary_list.index("I:") + 1])
+ i_thresh = float(summary_list[summary_list.index("I:") + 4])
+ lra = float(summary_list[summary_list.index("LRA:") + 1])
+ lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
+ lra_low = float(summary_list[summary_list.index("low:") + 1])
+ lra_high = float(summary_list[summary_list.index("high:") + 1])
+ stats_dict = {
+ "I": i_lufs,
+ "I Threshold": i_thresh,
+ "LRA": lra,
+ "LRA Threshold": lra_thresh,
+ "LRA Low": lra_low,
+ "LRA High": lra_high,
+ }
+
+ return stats_dict
+
+
+def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
+ """Given a path to a file, returns the start time offset and codec of
+ the first audio stream.
+ """
+ ff = ffmpy.FFprobe(
+ inputs={path: None},
+ global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
+ )
+ streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
+ seconds_offset = 0.0
+ codec = None
+
+ # Get the offset and codec of the first audio stream we find
+ # and return its start time, if it has one.
+ for stream in streams:
+ if stream["codec_type"] == "audio":
+ seconds_offset = stream.get("start_time", 0.0)
+ codec = stream.get("codec_name")
+ break
+ return float(seconds_offset), codec
+
+
+class FFMPEGMixin:
+ _loudness = None
+
+ def ffmpeg_loudness(self, quiet: bool = True):
+ """Computes loudness of audio file using FFMPEG.
+
+ Parameters
+ ----------
+ quiet : bool, optional
+ Whether to show FFMPEG output during computation,
+ by default True
+
+ Returns
+ -------
+ torch.Tensor
+ Loudness of every item in the batch, computed via
+ FFMPEG.
+ """
+ loudness = []
+
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+ for i in range(self.batch_size):
+ self[i].write(f.name)
+ loudness_stats = r128stats(f.name, quiet=quiet)
+ loudness.append(loudness_stats["I"])
+
+ self._loudness = torch.from_numpy(np.array(loudness)).float()
+ return self.loudness()
+
+ def ffmpeg_resample(self, sample_rate: int, quiet: bool = True):
+ """Resamples AudioSignal using FFMPEG. More memory-efficient
+ than using julius.resample for long audio files.
+
+ Parameters
+ ----------
+ sample_rate : int
+ Sample rate to resample to.
+ quiet : bool, optional
+ Whether to show FFMPEG output during computation,
+ by default True
+
+ Returns
+ -------
+ AudioSignal
+ Resampled AudioSignal.
+ """
+ from audiotools import AudioSignal
+
+ if sample_rate == self.sample_rate:
+ return self
+
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+ self.write(f.name)
+ f_out = f.name.replace("wav", "rs.wav")
+ command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}"
+ if quiet:
+ command += " -hide_banner -loglevel error"
+ subprocess.check_call(shlex.split(command))
+ resampled = AudioSignal(f_out)
+ Path.unlink(Path(f_out))
+ return resampled
+
+ @classmethod
+ def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs):
+ """Loads AudioSignal object after decoding it to a wav file using FFMPEG.
+ Useful for loading audio that isn't covered by librosa's loading mechanism. Also
+ useful for loading mp3 files, without any offset.
+
+ Parameters
+ ----------
+ audio_path : str
+ Path to load AudioSignal from.
+ quiet : bool, optional
+ Whether to show FFMPEG output during computation,
+ by default True
+
+ Returns
+ -------
+ AudioSignal
+ AudioSignal loaded from file with FFMPEG.
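+
+ Examples
+ --------
+ Decoding the audio track of a video container; requires ffmpeg on the
+ PATH, and the path below is a placeholder.
+
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal.load_from_file_with_ffmpeg("clip.mp4")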
+ """
+ audio_path = str(audio_path)
+ with tempfile.TemporaryDirectory() as d:
+ wav_file = str(Path(d) / "extracted.wav")
+ padded_wav = str(Path(d) / "padded.wav")
+
+ global_options = "-y"
+ if quiet:
+ global_options += " -loglevel error"
+
+ ff = ffmpy.FFmpeg(
+ inputs={audio_path: None},
+ outputs={wav_file: None},
+ global_options=global_options,
+ )
+ ff.run()
+
+ # We pad the file using the start time offset in case it's an audio
+ # stream starting at some offset in a video container.
+ pad, codec = ffprobe_offset_and_codec(audio_path)
+
+ # For mp3s, don't pad files with discrepancies less than 0.027s -
+ # they're likely due to codec latency. The encoder delay introduced
+ # by mp3 is 1152 samples, which is about 0.026s at 44.1 kHz, so we
+ # set the threshold here slightly above that.
+ # Source: https://lame.sourceforge.io/tech-FAQ.txt.
+ if codec == "mp3" and pad < 0.027:
+ pad = 0.0
+ ff = ffmpy.FFmpeg(
+ inputs={wav_file: None},
+ outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"},
+ global_options=global_options,
+ )
+ ff.run()
+
+ signal = cls(padded_wav, **kwargs)
+
+ return signal
diff --git a/audiotools/core/loudness.py b/audiotools/core/loudness.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb3ee2675d7cb71f4c00106b0c1e901b8e51b842
--- /dev/null
+++ b/audiotools/core/loudness.py
@@ -0,0 +1,320 @@
+import copy
+
+import julius
+import numpy as np
+import scipy
+import torch
+import torch.nn.functional as F
+import torchaudio
+
+
+class Meter(torch.nn.Module):
+ """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
+
+ Parameters
+ ----------
+ rate : int
+ Sample rate of audio.
+ filter_class : str, optional
+ Class of weighting filter used. One of 'K-weighting' (default),
+ 'Fenton/Lee 1', 'Fenton/Lee 2', or 'Dash et al.',
+ by default 'K-weighting'
+ block_size : float, optional
+ Gating block size in seconds, by default 0.400
+ zeros : int, optional
+ Number of zeros to use in FIR approximation of
+ IIR filters, by default 512
+ use_fir : bool, optional
+ Whether to use the FIR approximation or the exact IIR formulation.
+ When computing on GPU, the FIR approximation is always used, as it
+ is much faster, by default False
+ """
+
+ def __init__(
+ self,
+ rate: int,
+ filter_class: str = "K-weighting",
+ block_size: float = 0.400,
+ zeros: int = 512,
+ use_fir: bool = False,
+ ):
+ super().__init__()
+
+ self.rate = rate
+ self.filter_class = filter_class
+ self.block_size = block_size
+ self.use_fir = use_fir
+
+ G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
+ self.register_buffer("G", G)
+
+ # Compute impulse responses so that filtering is fast via
+ # a convolution at runtime, on GPU, unlike lfilter.
+ impulse = np.zeros((zeros,))
+ impulse[..., 0] = 1.0
+
+ firs = np.zeros((len(self._filters), 1, zeros))
+ passband_gain = torch.zeros(len(self._filters))
+
+ for i, (_, filter_stage) in enumerate(self._filters.items()):
+ firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse)
+ passband_gain[i] = filter_stage.passband_gain
+
+ firs = torch.from_numpy(firs[..., ::-1].copy()).float()
+
+ self.register_buffer("firs", firs)
+ self.register_buffer("passband_gain", passband_gain)
+
+ def apply_filter_gpu(self, data: torch.Tensor):
+ """Performs FIR approximation of loudness computation.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Audio data of shape (nb, nt, nch).
+
+ Returns
+ -------
+ torch.Tensor
+ Filtered audio data.
+ """
+ # Data comes in as (nb, nt, nch).
+ # Reshape to (nb*nch, 1, nt) for convolution.
+ nb, nt, nch = data.shape
+ data = data.permute(0, 2, 1)
+ data = data.reshape(nb * nch, 1, nt)
+
+ # Apply padding
+ pad_length = self.firs.shape[-1]
+
+ # Apply filtering in sequence
+ for i in range(self.firs.shape[0]):
+ data = F.pad(data, (pad_length, pad_length))
+ data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...])
+ data = self.passband_gain[i] * data
+ data = data[..., 1 : nt + 1]
+
+ data = data.permute(0, 2, 1)
+ data = data[:, :nt, :]
+ return data
+
+ def apply_filter_cpu(self, data: torch.Tensor):
+ """Performs IIR formulation of loudness computation.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Audio data of shape (nb, nt, nch).
+
+ Returns
+ -------
+ torch.Tensor
+ Filtered audio data.
+ """
+ for _, filter_stage in self._filters.items():
+ passband_gain = filter_stage.passband_gain
+
+ a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device)
+ b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device)
+
+ _data = data.permute(0, 2, 1)
+ filtered = torchaudio.functional.lfilter(
+ _data, a_coeffs, b_coeffs, clamp=False
+ )
+ data = passband_gain * filtered.permute(0, 2, 1)
+ return data
+
+ def apply_filter(self, data: torch.Tensor):
+ """Applies filter on either CPU or GPU, depending
+ on if the audio is on GPU or is on CPU, or if
+ ``self.use_fir`` is True.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Audio data of shape (nb, nt, nch).
+
+ Returns
+ -------
+ torch.Tensor
+ Filtered audio data.
+ """
+ if data.is_cuda or self.use_fir:
+ data = self.apply_filter_gpu(data)
+ else:
+ data = self.apply_filter_cpu(data)
+ return data
+
+ def forward(self, data: torch.Tensor):
+ """Computes integrated loudness of data.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Audio data of shape (nb, nt, nch).
+
+ Returns
+ -------
+ torch.Tensor
+ Integrated loudness of each item in the batch, in LUFS.
+ """
+ return self.integrated_loudness(data)
+
+ def _unfold(self, input_data):
+ T_g = self.block_size
+ overlap = 0.75 # overlap of 75% of the block duration
+ step = 1.0 - overlap # step size by percentage
+
+ kernel_size = int(T_g * self.rate)
+ stride = int(T_g * self.rate * step)
+ unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride)
+ unfolded = unfolded.transpose(-1, -2)
+
+ return unfolded
+
+ def integrated_loudness(self, data: torch.Tensor):
+ """Computes integrated loudness of data.
+
+ Parameters
+ ----------
+ data : torch.Tensor
+ Audio data of shape (nb, nt, nch).
+
+ Returns
+ -------
+ torch.Tensor
+ Integrated loudness of each item in the batch, in LUFS.
+ """
+ if not torch.is_tensor(data):
+ data = torch.from_numpy(data).float()
+ else:
+ data = data.float()
+
+ input_data = copy.copy(data)
+ # Data always has a batch and channel dimension.
+ # Is of shape (nb, nt, nch)
+ if input_data.ndim < 2:
+ input_data = input_data.unsqueeze(-1)
+ if input_data.ndim < 3:
+ input_data = input_data.unsqueeze(0)
+
+ nb, nt, nch = input_data.shape
+
+ # Apply frequency weighting filters - account
+ # for the acoustic response of the head and auditory system
+ input_data = self.apply_filter(input_data)
+
+ G = self.G # channel gains
+ T_g = self.block_size # 400 ms gating block standard
+ Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold
+
+ unfolded = self._unfold(input_data)
+
+ z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
+ l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
+ l = l.expand_as(z)
+
+ # find gating block indices above absolute threshold
+ z_avg_gated = z
+ z_avg_gated[l <= Gamma_a] = 0
+ masked = l > Gamma_a
+ z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
+
+ # calculate the relative threshold value (see eq. 6)
+ Gamma_r = (
+ -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
+ )
+ Gamma_r = Gamma_r[:, None, None]
+ Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])
+
+ # find gating block indices above relative and absolute thresholds (end of eq. 7)
+ z_avg_gated = z
+ z_avg_gated[l <= Gamma_a] = 0
+ z_avg_gated[l <= Gamma_r] = 0
+ masked = (l > Gamma_a) * (l > Gamma_r)
+ z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
+
+ # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
+ # z_avg_gated = torch.nan_to_num(z_avg_gated)
+ z_avg_gated = torch.where(
+ z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
+ )
+ z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
+ z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)
+
+ LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
+ return LUFS.float()
+
+ @property
+ def filter_class(self):
+ return self._filter_class
+
+ @filter_class.setter
+ def filter_class(self, value):
+ from pyloudnorm import Meter
+
+ meter = Meter(self.rate)
+ meter.filter_class = value
+ self._filter_class = value
+ self._filters = meter._filters
+
+
+class LoudnessMixin:
+ _loudness = None
+ MIN_LOUDNESS = -70
+ """Minimum loudness possible."""
+
+ def loudness(
+ self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
+ ):
+ """Calculates loudness using an implementation of ITU-R BS.1770-4.
+ Allows control over the gating block size and frequency weighting
+ filters. Measures the integrated gated loudness of a signal.
+
+ API is derived from PyLoudnorm, but this implementation is ported to PyTorch
+ and is tensorized across batches. When on GPU, an FIR approximation of the IIR
+ filters is used to compute loudness for speed.
+
+ Uses the weighting filters and block size defined by the meter;
+ the integrated loudness is measured based upon the gating algorithm
+ defined in the ITU-R BS.1770-4 specification.
+
+ Parameters
+ ----------
+ filter_class : str, optional
+ Class of weighting filter used. One of 'K-weighting' (default),
+ 'Fenton/Lee 1', 'Fenton/Lee 2', or 'Dash et al.',
+ by default 'K-weighting'
+ block_size : float, optional
+ Gating block size in seconds, by default 0.400
+ kwargs : dict, optional
+ Keyword arguments to :py:func:`audiotools.core.loudness.Meter`.
+
+ Returns
+ -------
+ torch.Tensor
+ Loudness of audio data.
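+
+ Examples
+ --------
+ A minimal sketch; the path is a placeholder.
+
+ >>> from audiotools import AudioSignal
+ >>> signal = AudioSignal("speech.wav")
+ >>> lufs = signal.loudness()  # one value per item in the batch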
+ """
+ if self._loudness is not None:
+ return self._loudness.to(self.device)
+ original_length = self.signal_length
+ if self.signal_duration < 0.5:
+ pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
+ self.zero_pad(0, pad_len)
+
+ # create BS.1770 meter
+ meter = Meter(
+ self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
+ )
+ meter = meter.to(self.device)
+ # measure loudness
+ loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1))
+ self.truncate_samples(original_length)
+ min_loudness = (
+ torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS
+ )
+ self._loudness = torch.maximum(loudness, min_loudness)
+
+ return self._loudness.to(self.device)
diff --git a/audiotools/core/playback.py b/audiotools/core/playback.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d0f21aaa392494f35305c0084c05b87667ea14d
--- /dev/null
+++ b/audiotools/core/playback.py
@@ -0,0 +1,252 @@
+"""
+These are utilities that allow one to embed an AudioSignal
+as a playable object in a Jupyter notebook, or to play audio from
+the terminal, etc.
+""" # fmt: skip
+import base64
+import io
+import random
+import string
+import subprocess
+from tempfile import NamedTemporaryFile
+
+import importlib_resources as pkg_resources
+
+from . import templates
+from .util import _close_temp_files
+from .util import format_figure
+
+headers = pkg_resources.files(templates).joinpath("headers.html").read_text()
+widget = pkg_resources.files(templates).joinpath("widget.html").read_text()
+
+DEFAULT_EXTENSION = ".wav"
+
+
+def _check_imports(): # pragma: no cover
+ try:
+ import ffmpy
+ except:
+ ffmpy = False
+
+ try:
+ import IPython
+ except:
+ raise ImportError("IPython must be installed in order to use this function!")
+ return ffmpy, IPython
+
+
+class PlayMixin:
+ def embed(self, ext: str = None, display: bool = True, return_html: bool = False):
+ """Embeds audio as a playable audio embed in a notebook, or HTML
+ document, etc.
+
+ Parameters
+ ----------
+ ext : str, optional
+ Extension to use when saving the audio, by default ".wav"
+ display : bool, optional
+ This controls whether or not to display the audio when called. This
+ is used when the embed is the last line in a Jupyter cell, to prevent
+ the audio from being embedded twice, by default True
+ return_html : bool, optional
+ Whether to return the data wrapped in an HTML audio element, by default False
+
+ Returns
+ -------
+ str
+ Either the element for display, or the HTML string of it.
+ """
+ if ext is None:
+ ext = DEFAULT_EXTENSION
+ ext = f".{ext}" if not ext.startswith(".") else ext
+ ffmpy, IPython = _check_imports()
+ sr = self.sample_rate
+ tmpfiles = []
+
+ with _close_temp_files(tmpfiles):
+ tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False)
+ tmpfiles.append(tmp_wav)
+ self.write(tmp_wav.name)
+ if ext != ".wav" and ffmpy:
+ tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False)
+ tmpfiles.append(tmp_wav)
+ ff = ffmpy.FFmpeg(
+ inputs={tmp_wav.name: None},
+ outputs={
+ tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error"
+ },
+ )
+ ff.run()
+ else:
+ tmp_converted = tmp_wav
+
+ audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr)
+ if display:
+ IPython.display.display(audio_element)
+
+ if return_html:
+ audio_element = (
+ f" "
+ )
+ return audio_element
+
+ def widget(
+ self,
+ title: str = None,
+ ext: str = ".wav",
+ add_headers: bool = True,
+ player_width: str = "100%",
+ margin: str = "10px",
+ plot_fn: str = "specshow",
+ return_html: bool = False,
+ **kwargs,
+ ):
+ """Creates a playable widget with spectrogram. Inspired (heavily) by
+ https://sjvasquez.github.io/blog/melnet/.
+
+ Parameters
+ ----------
+ title : str, optional
+ Title of plot, placed in upper right of top-most axis.
+ ext : str, optional
+ Extension for embedding, by default ".wav"
+ add_headers : bool, optional
+ Whether or not to add headers (use for first embed, False for later embeds), by default True
+ player_width : str, optional
+ Width of the player, as a string in a CSS rule, by default "100%"
+ margin : str, optional
+ Margin on all sides of player, by default "10px"
+ plot_fn : typing.Union[str, typing.Callable], optional
+ Plotting function to use, or the name of a plotting method on this
+ signal, by default "specshow".
+ return_html : bool, optional
+ Whether to return the data wrapped in an HTML audio element, by default False
+ kwargs : dict, optional
+ Keyword arguments to plot_fn (by default self.specshow).
+
+ Returns
+ -------
+ HTML
+ HTML object.
+ """
+ import matplotlib.pyplot as plt
+
+ def _save_fig_to_tag():
+ buffer = io.BytesIO()
+
+ plt.savefig(buffer, bbox_inches="tight", pad_inches=0)
+ plt.close()
+
+ buffer.seek(0)
+ data_uri = base64.b64encode(buffer.read()).decode("ascii")
+ tag = "data:image/png;base64,{0}".format(data_uri)
+
+ return tag
+
+ _, IPython = _check_imports()
+
+ header_html = ""
+
+ if add_headers:
+ header_html = headers.replace("PLAYER_WIDTH", str(player_width))
+ header_html = header_html.replace("MARGIN", str(margin))
+ IPython.display.display(IPython.display.HTML(header_html))
+
+ widget_html = widget
+ if isinstance(plot_fn, str):
+ plot_fn = getattr(self, plot_fn)
+ kwargs["title"] = title
+ plot_fn(**kwargs)
+
+ fig = plt.gcf()
+ pixels = fig.get_size_inches() * fig.dpi
+
+ tag = _save_fig_to_tag()
+
+ # Make the source image for the levels
+ self.specshow()
+ format_figure((12, 1.5))
+ levels_tag = _save_fig_to_tag()
+
+ player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10))
+
+ audio_elem = self.embed(ext=ext, display=False)
+ widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr())
+ widget_html = widget_html.replace("IMAGE_SRC", tag)
+ widget_html = widget_html.replace("LEVELS_SRC", levels_tag)
+ widget_html = widget_html.replace("PLAYER_ID", player_id)
+
+ # Calculate width/height of figure based on figure size.
+ widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px")
+ widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px")
+
+ IPython.display.display(IPython.display.HTML(widget_html))
+
+ if return_html:
+ html = header_html if add_headers else ""
+ html += widget_html
+ return html
+
+ def play(self):
+ """
+ Plays an audio signal if ffplay from the ffmpeg suite of tools is installed.
+ Otherwise, will fail. The audio signal is written to a temporary file
+ and then played with ffplay.
+ """
+ tmpfiles = []
+ with _close_temp_files(tmpfiles):
+ tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False)
+ tmpfiles.append(tmp_wav)
+ self.write(tmp_wav.name)
+ print(self)
+ subprocess.call(
+ [
+ "ffplay",
+ "-nodisp",
+ "-autoexit",
+ "-hide_banner",
+ "-loglevel",
+ "error",
+ tmp_wav.name,
+ ]
+ )
+ return self
+
+
+if __name__ == "__main__": # pragma: no cover
+ from audiotools import AudioSignal
+
+ signal = AudioSignal(
+ "tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5
+ )
+
+ wave_html = signal.widget(
+ "Waveform",
+ plot_fn="waveplot",
+ return_html=True,
+ )
+
+ spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False)
+
+ combined_html = signal.widget(
+ "Waveform + spectrogram",
+ plot_fn="wavespec",
+ return_html=True,
+ add_headers=False,
+ )
+
+ signal.low_pass(8000)
+ lowpass_html = signal.widget(
+ "Lowpassed audio",
+ plot_fn="wavespec",
+ return_html=True,
+ add_headers=False,
+ )
+
+ with open("/tmp/index.html", "w") as f:
+ f.write(wave_html)
+ f.write(spec_html)
+ f.write(combined_html)
+ f.write(lowpass_html)
diff --git a/audiotools/core/templates/__init__.py b/audiotools/core/templates/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audiotools/core/templates/headers.html b/audiotools/core/templates/headers.html
new file mode 100644
index 0000000000000000000000000000000000000000..9eaef4a94d575f7826608ad63dcc77fab13b7b19
--- /dev/null
+++ b/audiotools/core/templates/headers.html
@@ -0,0 +1,322 @@
+
+
+
+
+
+
diff --git a/audiotools/core/templates/pandoc.css b/audiotools/core/templates/pandoc.css
new file mode 100644
index 0000000000000000000000000000000000000000..842be7be6d65580dab44c6a8013259644f38e6ee
--- /dev/null
+++ b/audiotools/core/templates/pandoc.css
@@ -0,0 +1,407 @@
+/*
+Copyright (c) 2017 Chris Patuzzo
+https://twitter.com/chrispatuzzo
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+*/
+
+body {
+ font-family: Helvetica, arial, sans-serif;
+ font-size: 14px;
+ line-height: 1.6;
+ padding-top: 10px;
+ padding-bottom: 10px;
+ background-color: white;
+ padding: 30px;
+ color: #333;
+}
+
+body > *:first-child {
+ margin-top: 0 !important;
+}
+
+body > *:last-child {
+ margin-bottom: 0 !important;
+}
+
+a {
+ color: #4183C4;
+ text-decoration: none;
+}
+
+a.absent {
+ color: #cc0000;
+}
+
+a.anchor {
+ display: block;
+ padding-left: 30px;
+ margin-left: -30px;
+ cursor: pointer;
+ position: absolute;
+ top: 0;
+ left: 0;
+ bottom: 0;
+}
+
+h1, h2, h3, h4, h5, h6 {
+ margin: 20px 0 10px;
+ padding: 0;
+ font-weight: bold;
+ -webkit-font-smoothing: antialiased;
+ cursor: text;
+ position: relative;
+}
+
+h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child {
+ margin-top: 0;
+ padding-top: 0;
+}
+
+h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor {
+ text-decoration: none;
+}
+
+h1 tt, h1 code {
+ font-size: inherit;
+}
+
+h2 tt, h2 code {
+ font-size: inherit;
+}
+
+h3 tt, h3 code {
+ font-size: inherit;
+}
+
+h4 tt, h4 code {
+ font-size: inherit;
+}
+
+h5 tt, h5 code {
+ font-size: inherit;
+}
+
+h6 tt, h6 code {
+ font-size: inherit;
+}
+
+h1 {
+ font-size: 28px;
+ color: black;
+}
+
+h2 {
+ font-size: 24px;
+ border-bottom: 1px solid #cccccc;
+ color: black;
+}
+
+h3 {
+ font-size: 18px;
+}
+
+h4 {
+ font-size: 16px;
+}
+
+h5 {
+ font-size: 14px;
+}
+
+h6 {
+ color: #777777;
+ font-size: 14px;
+}
+
+p, blockquote, ul, ol, dl, li, table, pre {
+ margin: 15px 0;
+}
+
+hr {
+ border: 0 none;
+ color: #cccccc;
+ height: 4px;
+ padding: 0;
+}
+
+body > h2:first-child {
+ margin-top: 0;
+ padding-top: 0;
+}
+
+body > h1:first-child {
+ margin-top: 0;
+ padding-top: 0;
+}
+
+body > h1:first-child + h2 {
+ margin-top: 0;
+ padding-top: 0;
+}
+
+body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child {
+ margin-top: 0;
+ padding-top: 0;
+}
+
+a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 {
+ margin-top: 0;
+ padding-top: 0;
+}
+
+h1 p, h2 p, h3 p, h4 p, h5 p, h6 p {
+ margin-top: 0;
+}
+
+li p.first {
+ display: inline-block;
+}
+
+ul, ol {
+ padding-left: 30px;
+}
+
+ul :first-child, ol :first-child {
+ margin-top: 0;
+}
+
+ul :last-child, ol :last-child {
+ margin-bottom: 0;
+}
+
+dl {
+ padding: 0;
+}
+
+dl dt {
+ font-size: 14px;
+ font-weight: bold;
+ font-style: italic;
+ padding: 0;
+ margin: 15px 0 5px;
+}
+
+dl dt:first-child {
+ padding: 0;
+}
+
+dl dt > :first-child {
+ margin-top: 0;
+}
+
+dl dt > :last-child {
+ margin-bottom: 0;
+}
+
+dl dd {
+ margin: 0 0 15px;
+ padding: 0 15px;
+}
+
+dl dd > :first-child {
+ margin-top: 0;
+}
+
+dl dd > :last-child {
+ margin-bottom: 0;
+}
+
+blockquote {
+ border-left: 4px solid #dddddd;
+ padding: 0 15px;
+ color: #777777;
+}
+
+blockquote > :first-child {
+ margin-top: 0;
+}
+
+blockquote > :last-child {
+ margin-bottom: 0;
+}
+
+table {
+ padding: 0;
+}
+table tr {
+ border-top: 1px solid #cccccc;
+ background-color: white;
+ margin: 0;
+ padding: 0;
+}
+
+table tr:nth-child(2n) {
+ background-color: #f8f8f8;
+}
+
+table tr th {
+ font-weight: bold;
+ border: 1px solid #cccccc;
+ text-align: left;
+ margin: 0;
+ padding: 6px 13px;
+}
+
+table tr td {
+ border: 1px solid #cccccc;
+ text-align: left;
+ margin: 0;
+ padding: 6px 13px;
+}
+
+table tr th :first-child, table tr td :first-child {
+ margin-top: 0;
+}
+
+table tr th :last-child, table tr td :last-child {
+ margin-bottom: 0;
+}
+
+img {
+ max-width: 100%;
+}
+
+span.frame {
+ display: block;
+ overflow: hidden;
+}
+
+span.frame > span {
+ border: 1px solid #dddddd;
+ display: block;
+ float: left;
+ overflow: hidden;
+ margin: 13px 0 0;
+ padding: 7px;
+ width: auto;
+}
+
+span.frame span img {
+ display: block;
+ float: left;
+}
+
+span.frame span span {
+ clear: both;
+ color: #333333;
+ display: block;
+ padding: 5px 0 0;
+}
+
+span.align-center {
+ display: block;
+ overflow: hidden;
+ clear: both;
+}
+
+span.align-center > span {
+ display: block;
+ overflow: hidden;
+ margin: 13px auto 0;
+ text-align: center;
+}
+
+span.align-center span img {
+ margin: 0 auto;
+ text-align: center;
+}
+
+span.align-right {
+ display: block;
+ overflow: hidden;
+ clear: both;
+}
+
+span.align-right > span {
+ display: block;
+ overflow: hidden;
+ margin: 13px 0 0;
+ text-align: right;
+}
+
+span.align-right span img {
+ margin: 0;
+ text-align: right;
+}
+
+span.float-left {
+ display: block;
+ margin-right: 13px;
+ overflow: hidden;
+ float: left;
+}
+
+span.float-left span {
+ margin: 13px 0 0;
+}
+
+span.float-right {
+ display: block;
+ margin-left: 13px;
+ overflow: hidden;
+ float: right;
+}
+
+span.float-right > span {
+ display: block;
+ overflow: hidden;
+ margin: 13px auto 0;
+ text-align: right;
+}
+
+code, tt {
+ margin: 0 2px;
+ padding: 0 5px;
+ white-space: nowrap;
+ border-radius: 3px;
+}
+
+pre code {
+ margin: 0;
+ padding: 0;
+ white-space: pre;
+ border: none;
+ background: transparent;
+}
+
+.highlight pre {
+ font-size: 13px;
+ line-height: 19px;
+ overflow: auto;
+ padding: 6px 10px;
+ border-radius: 3px;
+}
+
+pre {
+ font-size: 13px;
+ line-height: 19px;
+ overflow: auto;
+ padding: 6px 10px;
+ border-radius: 3px;
+}
+
+pre code, pre tt {
+ background-color: transparent;
+ border: none;
+}
+
+body {
+ max-width: 600px;
+}
diff --git a/audiotools/core/templates/widget.html b/audiotools/core/templates/widget.html
new file mode 100644
index 0000000000000000000000000000000000000000..0b44e8aec64fd1db929da5fa6208dee00247c967
--- /dev/null
+++ b/audiotools/core/templates/widget.html
@@ -0,0 +1,52 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/audiotools/core/util.py b/audiotools/core/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece1344658d10836aa2eb693f275294ad8cdbb52
--- /dev/null
+++ b/audiotools/core/util.py
@@ -0,0 +1,671 @@
+import csv
+import glob
+import math
+import numbers
+import os
+import random
+import typing
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict
+from typing import List
+
+import numpy as np
+import torch
+import torchaudio
+from flatten_dict import flatten
+from flatten_dict import unflatten
+
+
+@dataclass
+class Info:
+ """Shim for torchaudio.info API changes."""
+
+ sample_rate: float
+ num_frames: int
+
+ @property
+ def duration(self) -> float:
+ return self.num_frames / self.sample_rate
+
+
+def info(audio_path: str):
+ """Shim for torchaudio.info to make 0.7.2 API match 0.8.0.
+
+ Parameters
+ ----------
+ audio_path : str
+ Path to audio file.
+ """
+ # try default backend first, then fallback to soundfile
+ try:
+ info = torchaudio.info(str(audio_path))
+ except: # pragma: no cover
+ info = torchaudio.backend.soundfile_backend.info(str(audio_path))
+
+ if isinstance(info, tuple): # pragma: no cover
+ signal_info = info[0]
+ info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length)
+ else:
+ info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames)
+
+ return info
+
+
+def ensure_tensor(
+ x: typing.Union[np.ndarray, torch.Tensor, float, int],
+ ndim: int = None,
+ batch_size: int = None,
+):
+ """Ensures that the input ``x`` is a tensor of specified
+ dimensions and batch size.
+
+ Parameters
+ ----------
+ x : typing.Union[np.ndarray, torch.Tensor, float, int]
+ Data that will become a tensor on its way out.
+ ndim : int, optional
+ How many dimensions should be in the output, by default None
+ batch_size : int, optional
+ The batch size of the output, by default None
+
+ Returns
+ -------
+ torch.Tensor
+ Modified version of ``x`` as a tensor.
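+
+ Examples
+ --------
+ A scalar is promoted to the requested number of dimensions and then
+ broadcast to the batch size.
+
+ >>> from audiotools.core.util import ensure_tensor
+ >>> ensure_tensor(0.5, ndim=2, batch_size=4).shape
+ torch.Size([4, 1])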
+ """
+ if not torch.is_tensor(x):
+ x = torch.as_tensor(x)
+ if ndim is not None:
+ assert x.ndim <= ndim
+ while x.ndim < ndim:
+ x = x.unsqueeze(-1)
+ if batch_size is not None:
+ if x.shape[0] != batch_size:
+ shape = list(x.shape)
+ shape[0] = batch_size
+ x = x.expand(*shape)
+ return x
+
+
+def _get_value(other):
+ from . import AudioSignal
+
+ if isinstance(other, AudioSignal):
+ return other.audio_data
+ return other
+
+
+def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
+ """Closest frequency bin given a frequency, number
+ of bins, and a sampling rate.
+
+ Parameters
+ ----------
+ hz : torch.Tensor
+ Tensor of frequencies in Hz.
+ n_fft : int
+ Number of FFT bins.
+ sample_rate : int
+ Sample rate of audio.
+
+ Returns
+ -------
+ torch.Tensor
+ Closest bins to the data.
+ """
+ shape = hz.shape
+ hz = hz.flatten()
+ freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2)
+ hz[hz > sample_rate / 2] = sample_rate / 2
+
+ closest = (hz[None, :] - freqs[:, None]).abs()
+ closest_bins = closest.min(dim=0).indices
+
+ return closest_bins.reshape(*shape)
+
+
+def random_state(seed: typing.Union[int, np.random.RandomState]):
+ """
+ Turn seed into a np.random.RandomState instance.
+
+ Parameters
+ ----------
+ seed : typing.Union[int, np.random.RandomState] or None
+ If seed is None, return the RandomState singleton used by np.random.
+ If seed is an int, return a new RandomState instance seeded with seed.
+ If seed is already a RandomState instance, return it.
+ Otherwise raise ValueError.
+
+ Returns
+ -------
+ np.random.RandomState
+ Random state object.
+
+ Raises
+ ------
+ ValueError
+ If seed is not valid, an error is thrown.
+ """
+ if seed is None or seed is np.random:
+ return np.random.mtrand._rand
+ elif isinstance(seed, (numbers.Integral, np.integer, int)):
+ return np.random.RandomState(seed)
+ elif isinstance(seed, np.random.RandomState):
+ return seed
+ else:
+ raise ValueError(
+ "%r cannot be used to seed a numpy.random.RandomState" " instance" % seed
+ )
+
+
+def seed(random_seed, set_cudnn=False):
+ """
+ Seeds all random states with the same random seed
+ for reproducibility. Seeds ``numpy``, ``random`` and ``torch``
+ random generators.
+ For full reproducibility, two further options must be set
+ according to the torch documentation:
+ https://pytorch.org/docs/stable/notes/randomness.html
+ To do this, ``set_cudnn`` must be True. It defaults to
+ False, since setting it to True results in a performance
+ hit.
+
+ Args:
+ random_seed (int): integer corresponding to random seed to
+ use.
+ set_cudnn (bool): Whether or not to set cudnn into deterministic
+ mode and off of benchmark mode. Defaults to False.
+ """
+
+ torch.manual_seed(random_seed)
+ np.random.seed(random_seed)
+ random.seed(random_seed)
+
+ if set_cudnn:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+@contextmanager
+def _close_temp_files(tmpfiles: list):
+ """Utility function for creating a context and closing all temporary files
+ once the context is exited. For correct functionality, all temporary file
+    handles created inside the context must be appended to the ``tmpfiles``
+ list.
+
+ This function is taken wholesale from Scaper.
+
+ Parameters
+ ----------
+ tmpfiles : list
+ List of temporary file handles
+ """
+
+ def _close():
+ for t in tmpfiles:
+ try:
+ t.close()
+ os.unlink(t.name)
+ except:
+ pass
+
+ try:
+ yield
+ except: # pragma: no cover
+ _close()
+ raise
+ _close()
+
+
+AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
+
+
+def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
+ """Finds all audio files in a directory recursively.
+ Returns a list.
+
+ Parameters
+ ----------
+ folder : str
+ Folder to look for audio files in, recursively.
+ ext : List[str], optional
+        Extensions to look for (including the leading ``.``), by default
+ ``['.wav', '.flac', '.mp3', '.mp4']``.
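+
+    Examples
+    --------
+    A hedged sketch; the folder path is illustrative:
+
+    >>> files = find_audio("tests/audio", ext=[".wav"])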
+ """
+ folder = Path(folder)
+ # Take care of case where user has passed in an audio file directly
+ # into one of the calling functions.
+ if str(folder).endswith(tuple(ext)):
+ # if, however, there's a glob in the path, we need to
+ # return the glob, not the file.
+ if "*" in str(folder):
+ return glob.glob(str(folder), recursive=("**" in str(folder)))
+ else:
+ return [folder]
+
+ files = []
+ for x in ext:
+ files += folder.glob(f"**/*{x}")
+ return files
+
+
+def read_sources(
+ sources: List[str],
+ remove_empty: bool = True,
+ relative_path: str = "",
+ ext: List[str] = AUDIO_EXTENSIONS,
+):
+ """Reads audio sources that can either be folders
+ full of audio files, or CSV files that contain paths
+ to audio files. CSV files that adhere to the expected
+ format can be generated by
+ :py:func:`audiotools.data.preprocess.create_csv`.
+
+ Parameters
+ ----------
+ sources : List[str]
+ List of audio sources to be converted into a
+ list of lists of audio files.
+    remove_empty : bool, optional
+        Whether or not to remove rows with an empty "path"
+        from each CSV file, by default True.
+    relative_path : str, optional
+        Path that the audio files should be loaded relative to, by default "".
+    ext : List[str], optional
+        Extensions to look for when a source is a folder, by default
+        ``['.wav', '.flac', '.mp3', '.mp4']``.
+
+ Returns
+ -------
+ list
+ List of lists of rows of CSV files.
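+
+    Examples
+    --------
+    A hedged sketch; the CSV path is illustrative:
+
+    >>> sources = read_sources(["tests/audio/noises.csv"])
+    >>> sources[0][0]["path"]  # each row is a dict with at least a "path" key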
+ """
+ files = []
+ relative_path = Path(relative_path)
+ for source in sources:
+ source = str(source)
+ _files = []
+ if source.endswith(".csv"):
+ with open(source, "r") as f:
+ reader = csv.DictReader(f)
+ for x in reader:
+ if remove_empty and x["path"] == "":
+ continue
+ if x["path"] != "":
+ x["path"] = str(relative_path / x["path"])
+ _files.append(x)
+ else:
+ for x in find_audio(source, ext=ext):
+ x = str(relative_path / x)
+ _files.append({"path": x})
+ files.append(sorted(_files, key=lambda x: x["path"]))
+ return files
+
+
+def choose_from_list_of_lists(
+ state: np.random.RandomState, list_of_lists: list, p: float = None
+):
+ """Choose a single item from a list of lists.
+
+ Parameters
+ ----------
+ state : np.random.RandomState
+ Random state to use when choosing an item.
+ list_of_lists : list
+ A list of lists from which items will be drawn.
+ p : float, optional
+ Probabilities of each list, by default None
+
+ Returns
+ -------
+    tuple
+        A tuple of ``(item, source_idx, item_idx)``: the chosen item, the
+        index of the list it came from, and its index within that list.
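+
+    Examples
+    --------
+    A minimal sketch with illustrative lists:
+
+    >>> state = random_state(0)
+    >>> item, source_idx, item_idx = choose_from_list_of_lists(
+    >>>     state, [["a", "b"], ["c", "d"]]
+    >>> )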
+ """
+ source_idx = state.choice(list(range(len(list_of_lists))), p=p)
+ item_idx = state.randint(len(list_of_lists[source_idx]))
+ return list_of_lists[source_idx][item_idx], source_idx, item_idx
+
+
+@contextmanager
+def chdir(newdir: typing.Union[Path, str]):
+ """
+ Context manager for switching directories to run a
+ function. Useful for when you want to use relative
+ paths to different runs.
+
+ Parameters
+ ----------
+ newdir : typing.Union[Path, str]
+ Directory to switch to.
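+
+    Examples
+    --------
+    A minimal sketch; the directory is illustrative:
+
+    >>> with chdir("/tmp"):
+    >>>     print(os.getcwd())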
+ """
+ curdir = os.getcwd()
+ try:
+ os.chdir(newdir)
+ yield
+ finally:
+ os.chdir(curdir)
+
+
+def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
+ """Moves items in a batch (typically generated by a DataLoader as a list
+ or a dict) to the specified device. This works even if dictionaries
+ are nested.
+
+ Parameters
+ ----------
+ batch : typing.Union[dict, list, torch.Tensor]
+ Batch, typically generated by a dataloader, that will be moved to
+ the device.
+ device : str, optional
+ Device to move batch to, by default "cpu"
+
+ Returns
+ -------
+ typing.Union[dict, list, torch.Tensor]
+ Batch with all values moved to the specified device.
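+
+    Examples
+    --------
+    A minimal sketch with a nested dictionary:
+
+    >>> batch = {"x": torch.zeros(2), "meta": {"y": torch.ones(2)}}
+    >>> batch = prepare_batch(batch, device="cpu")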
+ """
+ if isinstance(batch, dict):
+ batch = flatten(batch)
+ for key, val in batch.items():
+ try:
+ batch[key] = val.to(device)
+ except:
+ pass
+ batch = unflatten(batch)
+ elif torch.is_tensor(batch):
+ batch = batch.to(device)
+ elif isinstance(batch, list):
+ for i in range(len(batch)):
+ try:
+ batch[i] = batch[i].to(device)
+ except:
+ pass
+ return batch
+
+
+def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
+ """Samples from a distribution defined by a tuple. The first
+ item in the tuple is the distribution type, and the rest of the
+ items are arguments to that distribution. The distribution function
+ is gotten from the ``np.random.RandomState`` object.
+
+ Parameters
+ ----------
+ dist_tuple : tuple
+ Distribution tuple
+ state : np.random.RandomState, optional
+ Random state, or seed to use, by default None
+
+ Returns
+ -------
+ typing.Union[float, int, str]
+ Draw from the distribution.
+
+ Examples
+ --------
+ Sample from a uniform distribution:
+
+ >>> dist_tuple = ("uniform", 0, 1)
+ >>> sample_from_dist(dist_tuple)
+
+ Sample from a constant distribution:
+
+ >>> dist_tuple = ("const", 0)
+ >>> sample_from_dist(dist_tuple)
+
+ Sample from a normal distribution:
+
+ >>> dist_tuple = ("normal", 0, 0.5)
+ >>> sample_from_dist(dist_tuple)
+
+ """
+ if dist_tuple[0] == "const":
+ return dist_tuple[1]
+ state = random_state(state)
+ dist_fn = getattr(state, dist_tuple[0])
+ return dist_fn(*dist_tuple[1:])
+
+
+def collate(list_of_dicts: list, n_splits: int = None):
+ """Collates a list of dictionaries (e.g. as returned by a
+ dataloader) into a dictionary with batched values. This routine
+ uses the default torch collate function for everything
+ except AudioSignal objects, which are handled by the
+ :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
+ function.
+
+ This function takes n_splits to enable splitting a batch
+ into multiple sub-batches for the purposes of gradient accumulation,
+ etc.
+
+ Parameters
+ ----------
+ list_of_dicts : list
+ List of dictionaries to be collated.
+ n_splits : int
+ Number of splits to make when creating the batches (split into
+ sub-batches). Useful for things like gradient accumulation.
+
+ Returns
+ -------
+ dict
+ Dictionary containing batched data.
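+
+    Examples
+    --------
+    A minimal sketch with plain tensors (AudioSignal values are batched
+    via ``AudioSignal.batch`` instead):
+
+    >>> items = [{"x": torch.zeros(1)} for _ in range(4)]
+    >>> batch = collate(items)                    # batch["x"] has shape (4, 1)
+    >>> sub_batches = collate(items, n_splits=2)  # list of 2 batches of 2 items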
+ """
+
+ from . import AudioSignal
+
+ batches = []
+ list_len = len(list_of_dicts)
+
+ return_list = False if n_splits is None else True
+ n_splits = 1 if n_splits is None else n_splits
+ n_items = int(math.ceil(list_len / n_splits))
+
+ for i in range(0, list_len, n_items):
+ # Flatten the dictionaries to avoid recursion.
+ list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
+ dict_of_lists = {
+ k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
+ }
+
+ batch = {}
+ for k, v in dict_of_lists.items():
+ if isinstance(v, list):
+ if all(isinstance(s, AudioSignal) for s in v):
+ batch[k] = AudioSignal.batch(v, pad_signals=True)
+ else:
+ # Borrow the default collate fn from torch.
+ batch[k] = torch.utils.data._utils.collate.default_collate(v)
+ batches.append(unflatten(batch))
+
+ batches = batches[0] if not return_list else batches
+ return batches
+
+
+BASE_SIZE = 864
+DEFAULT_FIG_SIZE = (9, 3)
+
+
+def format_figure(
+ fig_size: tuple = None,
+ title: str = None,
+ fig=None,
+ format_axes: bool = True,
+ format: bool = True,
+ font_color: str = "white",
+):
+ """Prettifies the spectrogram and waveform plots. A title
+ can be inset into the top right corner, and the axes can be
+ inset into the figure, allowing the data to take up the entire
+ image. Used in
+
+ - :py:func:`audiotools.core.display.DisplayMixin.specshow`
+ - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
+ - :py:func:`audiotools.core.display.DisplayMixin.wavespec`
+
+ Parameters
+ ----------
+ fig_size : tuple, optional
+ Size of figure, by default (9, 3)
+ title : str, optional
+ Title to inset in top right, by default None
+ fig : matplotlib.figure.Figure, optional
+ Figure object, if None ``plt.gcf()`` will be used, by default None
+ format_axes : bool, optional
+ Format the axes to be inside the figure, by default True
+ format : bool, optional
+ This formatting can be skipped entirely by passing ``format=False``
+        to any of the plotting functions that use this formatter, by default True
+ font_color : str, optional
+ Color of font of axes, by default "white"
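+
+    Examples
+    --------
+    A hedged sketch, assuming a spectrogram has already been drawn onto the
+    current figure and that the plotting function forwards ``format`` as
+    described above:
+
+    >>> signal.specshow(format=False)
+    >>> format_figure(fig_size=(9, 3), title="Spectrogram")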
+ """
+ import matplotlib
+ import matplotlib.pyplot as plt
+
+ if fig_size is None:
+ fig_size = DEFAULT_FIG_SIZE
+ if not format:
+ return
+ if fig is None:
+ fig = plt.gcf()
+ fig.set_size_inches(*fig_size)
+ axs = fig.axes
+
+ pixels = (fig.get_size_inches() * fig.dpi)[0]
+ font_scale = pixels / BASE_SIZE
+
+ if format_axes:
+ axs = fig.axes
+
+ for ax in axs:
+ ymin, _ = ax.get_ylim()
+ xmin, _ = ax.get_xlim()
+
+ ticks = ax.get_yticks()
+ for t in ticks[2:-1]:
+ t = axs[0].annotate(
+ f"{(t / 1000):2.1f}k",
+ xy=(xmin, t),
+ xycoords="data",
+ xytext=(5, -5),
+ textcoords="offset points",
+ ha="left",
+ va="top",
+ color=font_color,
+ fontsize=12 * font_scale,
+ alpha=0.75,
+ )
+
+ ticks = ax.get_xticks()[2:]
+ for t in ticks[:-1]:
+ t = axs[0].annotate(
+ f"{t:2.1f}s",
+ xy=(t, ymin),
+ xycoords="data",
+ xytext=(5, 5),
+ textcoords="offset points",
+ ha="center",
+ va="bottom",
+ color=font_color,
+ fontsize=12 * font_scale,
+ alpha=0.75,
+ )
+
+ ax.margins(0, 0)
+ ax.set_axis_off()
+ ax.xaxis.set_major_locator(plt.NullLocator())
+ ax.yaxis.set_major_locator(plt.NullLocator())
+
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+
+ if title is not None:
+ t = axs[0].annotate(
+ title,
+ xy=(1, 1),
+ xycoords="axes fraction",
+ fontsize=20 * font_scale,
+ xytext=(-5, -5),
+ textcoords="offset points",
+ ha="right",
+ va="top",
+ color="white",
+ )
+ t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
+
+
+def generate_chord_dataset(
+ max_voices: int = 8,
+ sample_rate: int = 44100,
+ num_items: int = 5,
+ duration: float = 1.0,
+ min_note: str = "C2",
+ max_note: str = "C6",
+ output_dir: Path = "chords",
+):
+ """
+ Generates a toy multitrack dataset of chords, synthesized from sine waves.
+
+
+ Parameters
+ ----------
+ max_voices : int, optional
+ Maximum number of voices in a chord, by default 8
+ sample_rate : int, optional
+ Sample rate of audio, by default 44100
+ num_items : int, optional
+ Number of items to generate, by default 5
+ duration : float, optional
+ Duration of each item, by default 1.0
+ min_note : str, optional
+ Minimum note in the dataset, by default "C2"
+ max_note : str, optional
+ Maximum note in the dataset, by default "C6"
+ output_dir : Path, optional
+ Directory to save the dataset, by default "chords"
+
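+    Examples
+    --------
+    A hedged sketch (requires ``librosa``; the output path is illustrative):
+
+    >>> output_dir = generate_chord_dataset(num_items=2, output_dir="/tmp/chords")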
+ """
+ import librosa
+ from . import AudioSignal
+ from ..data.preprocess import create_csv
+
+ min_midi = librosa.note_to_midi(min_note)
+ max_midi = librosa.note_to_midi(max_note)
+
+ tracks = []
+ for idx in range(num_items):
+ track = {}
+ # figure out how many voices to put in this track
+ num_voices = random.randint(1, max_voices)
+ for voice_idx in range(num_voices):
+ # choose some random params
+ midinote = random.randint(min_midi, max_midi)
+ dur = random.uniform(0.85 * duration, duration)
+
+ sig = AudioSignal.wave(
+ frequency=librosa.midi_to_hz(midinote),
+ duration=dur,
+ sample_rate=sample_rate,
+ shape="sine",
+ )
+ track[f"voice_{voice_idx}"] = sig
+ tracks.append(track)
+
+ # save the tracks to disk
+ output_dir = Path(output_dir)
+ output_dir.mkdir(exist_ok=True)
+ for idx, track in enumerate(tracks):
+ track_dir = output_dir / f"track_{idx}"
+ track_dir.mkdir(exist_ok=True)
+ for voice_name, sig in track.items():
+ sig.write(track_dir / f"{voice_name}.wav")
+
+ all_voices = list(set([k for track in tracks for k in track.keys()]))
+ voice_lists = {voice: [] for voice in all_voices}
+ for track in tracks:
+ for voice_name in all_voices:
+ if voice_name in track:
+ voice_lists[voice_name].append(track[voice_name].path_to_file)
+ else:
+ voice_lists[voice_name].append("")
+
+ for voice_name, paths in voice_lists.items():
+ create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
+
+ return output_dir
diff --git a/audiotools/core/whisper.py b/audiotools/core/whisper.py
new file mode 100644
index 0000000000000000000000000000000000000000..46c071f934fc3e2be3138e7596b1c6d2ef79eade
--- /dev/null
+++ b/audiotools/core/whisper.py
@@ -0,0 +1,97 @@
+import torch
+
+
+class WhisperMixin:
+ is_initialized = False
+
+ def setup_whisper(
+ self,
+ pretrained_model_name_or_path: str = "openai/whisper-base.en",
+ device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+ ):
+ from transformers import WhisperForConditionalGeneration
+ from transformers import WhisperProcessor
+
+ self.whisper_device = device
+ self.whisper_processor = WhisperProcessor.from_pretrained(
+ pretrained_model_name_or_path
+ )
+ self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
+ pretrained_model_name_or_path
+ ).to(self.whisper_device)
+ self.is_initialized = True
+
+ def get_whisper_features(self) -> torch.Tensor:
+ """Preprocess audio signal as per the whisper model's training config.
+
+ Returns
+ -------
+ torch.Tensor
+            The preprocessed input features of the audio signal. Shape: (1, channels, seq_len)
+ """
+ import torch
+
+ if not self.is_initialized:
+ self.setup_whisper()
+
+ signal = self.to(self.device)
+ raw_speech = list(
+ (
+ signal.clone()
+ .resample(self.whisper_processor.feature_extractor.sampling_rate)
+ .audio_data[:, 0, :]
+ .numpy()
+ )
+ )
+
+ with torch.inference_mode():
+ input_features = self.whisper_processor(
+ raw_speech,
+ sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
+ return_tensors="pt",
+ ).input_features
+
+ return input_features
+
+ def get_whisper_transcript(self) -> str:
+ """Get the transcript of the audio signal using the whisper model.
+
+ Returns
+ -------
+ str
+ The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>.
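+
+        Examples
+        --------
+        A hedged sketch; requires ``transformers`` and the audio path is illustrative:
+
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav")
+        >>> transcript = signal.get_whisper_transcript()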
+ """
+
+ if not self.is_initialized:
+ self.setup_whisper()
+
+ input_features = self.get_whisper_features()
+
+ with torch.inference_mode():
+ input_features = input_features.to(self.whisper_device)
+ generated_ids = self.whisper_model.generate(inputs=input_features)
+
+ transcription = self.whisper_processor.batch_decode(generated_ids)
+ return transcription[0]
+
+ def get_whisper_embeddings(self) -> torch.Tensor:
+ """Get the last hidden state embeddings of the audio signal using the whisper model.
+
+ Returns
+ -------
+ torch.Tensor
+ The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size)
+ """
+ import torch
+
+ if not self.is_initialized:
+ self.setup_whisper()
+
+ input_features = self.get_whisper_features()
+ encoder = self.whisper_model.get_encoder()
+
+ with torch.inference_mode():
+ input_features = input_features.to(self.whisper_device)
+ embeddings = encoder(input_features)
+
+ return embeddings.last_hidden_state
diff --git a/audiotools/data/__init__.py b/audiotools/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aead269f26f3782043e68418b4c87ee323cbd015
--- /dev/null
+++ b/audiotools/data/__init__.py
@@ -0,0 +1,3 @@
+from . import datasets
+from . import preprocess
+from . import transforms
diff --git a/audiotools/data/datasets.py b/audiotools/data/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..12e7a60963399aa15ff865de2d06537818ce18ee
--- /dev/null
+++ b/audiotools/data/datasets.py
@@ -0,0 +1,517 @@
+from pathlib import Path
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Union
+
+import numpy as np
+from torch.utils.data import SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+
+from ..core import AudioSignal
+from ..core import util
+
+
+class AudioLoader:
+ """Loads audio endlessly from a list of audio sources
+ containing paths to audio files. Audio sources can be
+ folders full of audio files (which are found via file
+    extension) or CSV files that contain paths
+    to audio files.
+
+ Parameters
+ ----------
+ sources : List[str], optional
+ Sources containing folders, or CSVs with
+ paths to audio files, by default None
+ weights : List[float], optional
+ Weights to sample audio files from each source, by default None
+ relative_path : str, optional
+ Path audio should be loaded relative to, by default ""
+ transform : Callable, optional
+ Transform to instantiate alongside audio sample,
+ by default None
+ ext : List[str]
+ List of extensions to find audio within each source by. Can
+        also be a file name (e.g. ``vocals.wav``), by default
+ ``['.wav', '.flac', '.mp3', '.mp4']``.
+    shuffle : bool
+        Whether to shuffle the files within the dataloader. Defaults to True.
+    shuffle_state : int
+        State to use to seed the shuffle of the files.
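+
+    Examples
+    --------
+    A minimal sketch; the source folder is illustrative:
+
+    >>> from audiotools.data.datasets import AudioLoader
+    >>> from audiotools.core import util
+    >>>
+    >>> loader = AudioLoader(sources=["tests/audio/spk"])
+    >>> item = loader(util.random_state(0), sample_rate=44100, duration=0.5)
+    >>> item["signal"]  # an AudioSignal excerpt drawn from the source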
+ """
+
+ def __init__(
+ self,
+ sources: List[str] = None,
+ weights: List[float] = None,
+ transform: Callable = None,
+ relative_path: str = "",
+ ext: List[str] = util.AUDIO_EXTENSIONS,
+ shuffle: bool = True,
+ shuffle_state: int = 0,
+ ):
+ self.audio_lists = util.read_sources(
+ sources, relative_path=relative_path, ext=ext
+ )
+
+ self.audio_indices = [
+ (src_idx, item_idx)
+ for src_idx, src in enumerate(self.audio_lists)
+ for item_idx in range(len(src))
+ ]
+ if shuffle:
+ state = util.random_state(shuffle_state)
+ state.shuffle(self.audio_indices)
+
+ self.sources = sources
+ self.weights = weights
+ self.transform = transform
+
+ def __call__(
+ self,
+ state,
+ sample_rate: int,
+ duration: float,
+ loudness_cutoff: float = -40,
+ num_channels: int = 1,
+ offset: float = None,
+ source_idx: int = None,
+ item_idx: int = None,
+ global_idx: int = None,
+ ):
+ if source_idx is not None and item_idx is not None:
+ try:
+ audio_info = self.audio_lists[source_idx][item_idx]
+ except:
+ audio_info = {"path": "none"}
+ elif global_idx is not None:
+ source_idx, item_idx = self.audio_indices[
+ global_idx % len(self.audio_indices)
+ ]
+ audio_info = self.audio_lists[source_idx][item_idx]
+ else:
+ audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
+ state, self.audio_lists, p=self.weights
+ )
+
+ path = audio_info["path"]
+ signal = AudioSignal.zeros(duration, sample_rate, num_channels)
+
+ if path != "none":
+ if offset is None:
+ signal = AudioSignal.salient_excerpt(
+ path,
+ duration=duration,
+ state=state,
+ loudness_cutoff=loudness_cutoff,
+ )
+ else:
+ signal = AudioSignal(
+ path,
+ offset=offset,
+ duration=duration,
+ )
+
+ if num_channels == 1:
+ signal = signal.to_mono()
+ signal = signal.resample(sample_rate)
+
+ if signal.duration < duration:
+ signal = signal.zero_pad_to(int(duration * sample_rate))
+
+ for k, v in audio_info.items():
+ signal.metadata[k] = v
+
+ item = {
+ "signal": signal,
+ "source_idx": source_idx,
+ "item_idx": item_idx,
+ "source": str(self.sources[source_idx]),
+ "path": str(path),
+ }
+ if self.transform is not None:
+ item["transform_args"] = self.transform.instantiate(state, signal=signal)
+ return item
+
+
+def default_matcher(x, y):
+ return Path(x).parent == Path(y).parent
+
+
+def align_lists(lists, matcher: Callable = default_matcher):
+ longest_list = lists[np.argmax([len(l) for l in lists])]
+ for i, x in enumerate(longest_list):
+ for l in lists:
+ if i >= len(l):
+ l.append({"path": "none"})
+ elif not matcher(l[i]["path"], x["path"]):
+ l.insert(i, {"path": "none"})
+ return lists
+
+
+class AudioDataset:
+ """Loads audio from multiple loaders (with associated transforms)
+    for a specified number of samples. Excerpts of the specified duration
+    are drawn randomly, above a specified loudness threshold, and are
+    resampled on the fly to the desired sample rate
+    (if it is different from the audio source sample rate).
+
+ This takes either a single AudioLoader object,
+    a list of AudioLoader objects, or a dictionary of AudioLoader
+ objects. Each AudioLoader is called by the dataset, and the
+ result is placed in the output dictionary. A transform can also be
+ specified for the entire dataset, rather than for each specific
+ loader. This transform can be applied to the output of all the
+ loaders if desired.
+
+ AudioLoader objects can be specified as aligned, which means the
+ loaders correspond to multitrack audio (e.g. a vocals, bass,
+ drums, and other loader for multitrack music mixtures).
+
+
+ Parameters
+ ----------
+ loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
+ AudioLoaders to sample audio from.
+ sample_rate : int
+ Desired sample rate.
+ n_examples : int, optional
+ Number of examples (length of dataset), by default 1000
+ duration : float, optional
+ Duration of audio samples, by default 0.5
+ loudness_cutoff : float, optional
+ Loudness cutoff threshold for audio samples, by default -40
+ num_channels : int, optional
+ Number of channels in output audio, by default 1
+ transform : Callable, optional
+ Transform to instantiate alongside each dataset item, by default None
+ aligned : bool, optional
+ Whether the loaders should be sampled in an aligned manner (e.g. same
+ offset, duration, and matched file name), by default False
+ shuffle_loaders : bool, optional
+ Whether to shuffle the loaders before sampling from them, by default False
+ matcher : Callable
+ How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
+ by default uses the parent directory of each file.
+ without_replacement : bool
+ Whether to choose files with or without replacement, by default True.
+
+
+ Examples
+ --------
+ >>> from audiotools.data.datasets import AudioLoader
+ >>> from audiotools.data.datasets import AudioDataset
+ >>> from audiotools import transforms as tfm
+ >>> import numpy as np
+ >>>
+ >>> loaders = [
+ >>> AudioLoader(
+ >>> sources=[f"tests/audio/spk"],
+ >>> transform=tfm.Equalizer(),
+ >>> ext=["wav"],
+ >>> )
+ >>> for i in range(5)
+ >>> ]
+ >>>
+ >>> dataset = AudioDataset(
+ >>> loaders = loaders,
+ >>> sample_rate = 44100,
+ >>> duration = 1.0,
+ >>> transform = tfm.RescaleAudio(),
+ >>> )
+ >>>
+ >>> item = dataset[np.random.randint(len(dataset))]
+ >>>
+ >>> for i in range(len(loaders)):
+ >>> item[i]["signal"] = loaders[i].transform(
+ >>> item[i]["signal"], **item[i]["transform_args"]
+ >>> )
+ >>> item[i]["signal"].widget(i)
+ >>>
+ >>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
+ >>> mix = dataset.transform(mix, **item["transform_args"])
+ >>> mix.widget("mix")
+
+ Below is an example of how one could load MUSDB multitrack data:
+
+ >>> import audiotools as at
+ >>> from pathlib import Path
+ >>> from audiotools import transforms as tfm
+ >>> import numpy as np
+ >>> import torch
+ >>>
+ >>> def build_dataset(
+ >>> sample_rate: int = 44100,
+ >>> duration: float = 5.0,
+ >>> musdb_path: str = "~/.data/musdb/",
+ >>> ):
+ >>> musdb_path = Path(musdb_path).expanduser()
+ >>> loaders = {
+ >>> src: at.datasets.AudioLoader(
+ >>> sources=[musdb_path],
+ >>> transform=tfm.Compose(
+ >>> tfm.VolumeNorm(("uniform", -20, -10)),
+ >>> tfm.Silence(prob=0.1),
+ >>> ),
+ >>> ext=[f"{src}.wav"],
+ >>> )
+ >>> for src in ["vocals", "bass", "drums", "other"]
+ >>> }
+ >>>
+ >>> dataset = at.datasets.AudioDataset(
+ >>> loaders=loaders,
+ >>> sample_rate=sample_rate,
+ >>> duration=duration,
+ >>> num_channels=1,
+ >>> aligned=True,
+ >>> transform=tfm.RescaleAudio(),
+ >>> shuffle_loaders=True,
+ >>> )
+ >>> return dataset, list(loaders.keys())
+ >>>
+ >>> train_data, sources = build_dataset()
+ >>> dataloader = torch.utils.data.DataLoader(
+ >>> train_data,
+ >>> batch_size=16,
+ >>> num_workers=0,
+ >>> collate_fn=train_data.collate,
+ >>> )
+ >>> batch = next(iter(dataloader))
+ >>>
+ >>> for k in sources:
+ >>> src = batch[k]
+ >>> src["transformed"] = train_data.loaders[k].transform(
+ >>> src["signal"].clone(), **src["transform_args"]
+ >>> )
+ >>>
+ >>> mixture = sum(batch[k]["transformed"] for k in sources)
+ >>> mixture = train_data.transform(mixture, **batch["transform_args"])
+ >>>
+ >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
+ >>> # Construct the targets:
+ >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)
+
+ Similarly, here's example code for loading Slakh data:
+
+ >>> import audiotools as at
+ >>> from pathlib import Path
+ >>> from audiotools import transforms as tfm
+ >>> import numpy as np
+ >>> import torch
+ >>> import glob
+ >>>
+ >>> def build_dataset(
+ >>> sample_rate: int = 16000,
+ >>> duration: float = 10.0,
+ >>> slakh_path: str = "~/.data/slakh/",
+ >>> ):
+ >>> slakh_path = Path(slakh_path).expanduser()
+ >>>
+ >>> # Find the max number of sources in Slakh
+ >>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
+ >>> n_sources = len(list(set(src_names)))
+ >>>
+ >>> loaders = {
+ >>> f"S{i:02d}": at.datasets.AudioLoader(
+ >>> sources=[slakh_path],
+ >>> transform=tfm.Compose(
+ >>> tfm.VolumeNorm(("uniform", -20, -10)),
+ >>> tfm.Silence(prob=0.1),
+ >>> ),
+ >>> ext=[f"S{i:02d}.wav"],
+ >>> )
+ >>> for i in range(n_sources)
+ >>> }
+ >>> dataset = at.datasets.AudioDataset(
+ >>> loaders=loaders,
+ >>> sample_rate=sample_rate,
+ >>> duration=duration,
+ >>> num_channels=1,
+ >>> aligned=True,
+ >>> transform=tfm.RescaleAudio(),
+ >>> shuffle_loaders=False,
+ >>> )
+ >>>
+ >>> return dataset, list(loaders.keys())
+ >>>
+ >>> train_data, sources = build_dataset()
+ >>> dataloader = torch.utils.data.DataLoader(
+ >>> train_data,
+ >>> batch_size=16,
+ >>> num_workers=0,
+ >>> collate_fn=train_data.collate,
+ >>> )
+ >>> batch = next(iter(dataloader))
+ >>>
+ >>> for k in sources:
+ >>> src = batch[k]
+ >>> src["transformed"] = train_data.loaders[k].transform(
+ >>> src["signal"].clone(), **src["transform_args"]
+ >>> )
+ >>>
+ >>> mixture = sum(batch[k]["transformed"] for k in sources)
+ >>> mixture = train_data.transform(mixture, **batch["transform_args"])
+
+ """
+
+ def __init__(
+ self,
+ loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
+ sample_rate: int,
+ n_examples: int = 1000,
+ duration: float = 0.5,
+ offset: float = None,
+ loudness_cutoff: float = -40,
+ num_channels: int = 1,
+ transform: Callable = None,
+ aligned: bool = False,
+ shuffle_loaders: bool = False,
+ matcher: Callable = default_matcher,
+ without_replacement: bool = True,
+ ):
+ # Internally we convert loaders to a dictionary
+ if isinstance(loaders, list):
+ loaders = {i: l for i, l in enumerate(loaders)}
+ elif isinstance(loaders, AudioLoader):
+ loaders = {0: loaders}
+
+ self.loaders = loaders
+ self.loudness_cutoff = loudness_cutoff
+ self.num_channels = num_channels
+
+ self.length = n_examples
+ self.transform = transform
+ self.sample_rate = sample_rate
+ self.duration = duration
+ self.offset = offset
+ self.aligned = aligned
+ self.shuffle_loaders = shuffle_loaders
+ self.without_replacement = without_replacement
+
+ if aligned:
+ loaders_list = list(loaders.values())
+ for i in range(len(loaders_list[0].audio_lists)):
+ input_lists = [l.audio_lists[i] for l in loaders_list]
+ # Alignment happens in-place
+ align_lists(input_lists, matcher)
+
+ def __getitem__(self, idx):
+ state = util.random_state(idx)
+ offset = None if self.offset is None else self.offset
+ item = {}
+
+ keys = list(self.loaders.keys())
+ if self.shuffle_loaders:
+ state.shuffle(keys)
+
+ loader_kwargs = {
+ "state": state,
+ "sample_rate": self.sample_rate,
+ "duration": self.duration,
+ "loudness_cutoff": self.loudness_cutoff,
+ "num_channels": self.num_channels,
+ "global_idx": idx if self.without_replacement else None,
+ }
+
+ # Draw item from first loader
+ loader = self.loaders[keys[0]]
+ item[keys[0]] = loader(**loader_kwargs)
+
+ for key in keys[1:]:
+ loader = self.loaders[key]
+ if self.aligned:
+ # Path mapper takes the current loader + everything
+ # returned by the first loader.
+ offset = item[keys[0]]["signal"].metadata["offset"]
+ loader_kwargs.update(
+ {
+ "offset": offset,
+ "source_idx": item[keys[0]]["source_idx"],
+ "item_idx": item[keys[0]]["item_idx"],
+ }
+ )
+ item[key] = loader(**loader_kwargs)
+
+ # Sort dictionary back into original order
+ keys = list(self.loaders.keys())
+ item = {k: item[k] for k in keys}
+
+ item["idx"] = idx
+ if self.transform is not None:
+ item["transform_args"] = self.transform.instantiate(
+ state=state, signal=item[keys[0]]["signal"]
+ )
+
+ # If there's only one loader, pop it up
+ # to the main dictionary, instead of keeping it
+ # nested.
+ if len(keys) == 1:
+ item.update(item.pop(keys[0]))
+
+ return item
+
+ def __len__(self):
+ return self.length
+
+ @staticmethod
+ def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
+ """Collates items drawn from this dataset. Uses
+ :py:func:`audiotools.core.util.collate`.
+
+ Parameters
+ ----------
+ list_of_dicts : typing.Union[list, dict]
+ Data drawn from each item.
+ n_splits : int
+ Number of splits to make when creating the batches (split into
+ sub-batches). Useful for things like gradient accumulation.
+
+ Returns
+ -------
+ dict
+ Dictionary of batched data.
+ """
+ return util.collate(list_of_dicts, n_splits=n_splits)
+
+
+class ConcatDataset(AudioDataset):
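+    """Concatenates multiple datasets into one. Items are interleaved
+    round-robin across the datasets: item ``idx`` is drawn from
+    ``datasets[idx % n]`` at position ``idx // n``, where ``n`` is the
+    number of datasets.
+    """
+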
+ def __init__(self, datasets: list):
+ self.datasets = datasets
+
+ def __len__(self):
+ return sum([len(d) for d in self.datasets])
+
+ def __getitem__(self, idx):
+ dataset = self.datasets[idx % len(self.datasets)]
+ return dataset[idx // len(self.datasets)]
+
+
+class ResumableDistributedSampler(DistributedSampler): # pragma: no cover
+ """Distributed sampler that can be resumed from a given start index."""
+
+ def __init__(self, dataset, start_idx: int = None, **kwargs):
+ super().__init__(dataset, **kwargs)
+ # Start index, allows to resume an experiment at the index it was
+ self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
+
+ def __iter__(self):
+ for i, idx in enumerate(super().__iter__()):
+ if i >= self.start_idx:
+ yield idx
+        self.start_idx = 0  # reset so the next epoch starts from the beginning
+
+
+class ResumableSequentialSampler(SequentialSampler): # pragma: no cover
+ """Sequential sampler that can be resumed from a given start index."""
+
+ def __init__(self, dataset, start_idx: int = None, **kwargs):
+ super().__init__(dataset, **kwargs)
+ # Start index, allows to resume an experiment at the index it was
+ self.start_idx = start_idx if start_idx is not None else 0
+
+ def __iter__(self):
+ for i, idx in enumerate(super().__iter__()):
+ if i >= self.start_idx:
+ yield idx
+        self.start_idx = 0  # reset so the next epoch starts from the beginning
diff --git a/audiotools/data/preprocess.py b/audiotools/data/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..d90de210115e45838bc8d69b350f7516ba730406
--- /dev/null
+++ b/audiotools/data/preprocess.py
@@ -0,0 +1,81 @@
+import csv
+import os
+from pathlib import Path
+
+from tqdm import tqdm
+
+from ..core import AudioSignal
+
+
+def create_csv(
+ audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None
+):
+ """Converts a folder of audio files to a CSV file. If ``loudness = True``,
+ the output of this function will create a CSV file that looks something
+ like:
+
+ .. csv-table::
+ :header: path,loudness
+
+ daps/produced/f1_script1_produced.wav,-16.299999237060547
+ daps/produced/f1_script2_produced.wav,-16.600000381469727
+ daps/produced/f1_script3_produced.wav,-17.299999237060547
+ daps/produced/f1_script4_produced.wav,-16.100000381469727
+ daps/produced/f1_script5_produced.wav,-16.700000762939453
+ daps/produced/f3_script1_produced.wav,-16.5
+
+ .. note::
+ The paths above are written relative to the ``data_path`` argument
+ which defaults to the environment variable ``PATH_TO_DATA`` if
+ it isn't passed to this function, and defaults to the empty string
+ if that environment variable is not set.
+
+ You can produce a CSV file from a directory of audio files via:
+
+ >>> import audiotools
+ >>> directory = ...
+ >>> audio_files = audiotools.util.find_audio(directory)
+    >>> output_csv = "train.csv"
+ >>> audiotools.data.preprocess.create_csv(
+ >>> audio_files, output_csv, loudness=True
+ >>> )
+
+ Note that you can create empty rows in the CSV file by passing an empty
+ string or None in the ``audio_files`` list. This is useful if you want to
+ sync multiple CSV files in a multitrack setting. The loudness of these
+ empty rows will be set to -inf.
+
+ Parameters
+ ----------
+ audio_files : list
+ List of audio files.
+ output_csv : Path
+ Output CSV, with each row containing the relative path of every file
+ to ``data_path``, if specified (defaults to None).
+ loudness : bool
+ Compute loudness of entire file and store alongside path.
+ """
+
+ info = []
+ pbar = tqdm(audio_files)
+ for af in pbar:
+ af = Path(af)
+ pbar.set_description(f"Processing {af.name}")
+ _info = {}
+ if af.name == "":
+ _info["path"] = ""
+ if loudness:
+ _info["loudness"] = -float("inf")
+ else:
+ _info["path"] = af.relative_to(data_path) if data_path is not None else af
+ if loudness:
+ _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
+
+ info.append(_info)
+
+ with open(output_csv, "w") as f:
+ writer = csv.DictWriter(f, fieldnames=list(info[0].keys()))
+ writer.writeheader()
+
+ for item in info:
+ writer.writerow(item)
diff --git a/audiotools/data/transforms.py b/audiotools/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..504e87dc61777e36ba95eb794f497bed4cdc7d2c
--- /dev/null
+++ b/audiotools/data/transforms.py
@@ -0,0 +1,1592 @@
+import copy
+from contextlib import contextmanager
+from inspect import signature
+from typing import List
+
+import numpy as np
+import torch
+from flatten_dict import flatten
+from flatten_dict import unflatten
+from numpy.random import RandomState
+
+from .. import ml
+from ..core import AudioSignal
+from ..core import util
+from .datasets import AudioLoader
+
+tt = torch.tensor
+"""Shorthand for converting things to torch.tensor."""
+
+
+class BaseTransform:
+ """This is the base class for all transforms that are implemented
+ in this library. Transforms have two main operations: ``transform``
+ and ``instantiate``.
+
+ ``instantiate`` sets the parameters randomly
+ from distribution tuples for each parameter. For example, for the
+ ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``)
+    is chosen randomly by instantiate. By default, it is chosen uniformly
+ between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``).
+
+ ``transform`` applies the transform using the instantiated parameters.
+ A simple example is as follows:
+
+ >>> seed = 0
+ >>> signal = ...
+ >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0))
+ >>> kwargs = transform.instantiate()
+ >>> output = transform(signal.clone(), **kwargs)
+
+ By breaking apart the instantiation of parameters from the actual audio
+ processing of the transform, we can make things more reproducible, while
+ also applying the transform on batches of data efficiently on GPU,
+ rather than on individual audio samples.
+
+ .. note::
+ We call ``signal.clone()`` for the input to the ``transform`` function
+ because signals are modified in-place! If you don't clone the signal,
+ you will lose the original data.
+
+ Parameters
+ ----------
+ keys : list, optional
+ Keys that the transform looks for when
+ calling ``self.transform``, by default []. In general this is
+ set automatically, and you won't need to manipulate this argument.
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+
+ Examples
+ --------
+
+ >>> seed = 0
+ >>>
+ >>> audio_path = "tests/audio/spk/f10_script4_produced.wav"
+ >>> signal = AudioSignal(audio_path, offset=10, duration=2)
+ >>> transform = tfm.Compose(
+ >>> [
+ >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
+ >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
+ >>> ],
+ >>> )
+ >>>
+ >>> kwargs = transform.instantiate(seed, signal)
+ >>> output = transform(signal, **kwargs)
+
+ """
+
+ def __init__(self, keys: list = [], name: str = None, prob: float = 1.0):
+ # Get keys from the _transform signature.
+ tfm_keys = list(signature(self._transform).parameters.keys())
+
+ # Filter out signal and kwargs keys.
+ ignore_keys = ["signal", "kwargs"]
+ tfm_keys = [k for k in tfm_keys if k not in ignore_keys]
+
+ # Combine keys specified by the child class, the keys found in
+ # _transform signature, and the mask key.
+ self.keys = keys + tfm_keys + ["mask"]
+
+ self.prob = prob
+
+ if name is None:
+ name = self.__class__.__name__
+ self.name = name
+
+ def _prepare(self, batch: dict):
+ sub_batch = batch[self.name]
+
+ for k in self.keys:
+ assert k in sub_batch.keys(), f"{k} not in batch"
+
+ return sub_batch
+
+ def _transform(self, signal):
+ return signal
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+ return {}
+
+ @staticmethod
+ def apply_mask(batch: dict, mask: torch.Tensor):
+ """Applies a mask to the batch.
+
+ Parameters
+ ----------
+ batch : dict
+ Batch whose values will be masked in the ``transform`` pass.
+ mask : torch.Tensor
+ Mask to apply to batch.
+
+ Returns
+ -------
+ dict
+ A dictionary that contains values only where ``mask = True``.
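+
+        Examples
+        --------
+        A minimal sketch with an illustrative sub-batch:
+
+        >>> batch = {"VolumeNorm": {"db": torch.zeros(4)}}
+        >>> mask = torch.tensor([True, False, True, True])
+        >>> masked = BaseTransform.apply_mask(batch, mask)  # keeps rows where mask is True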
+ """
+ masked_batch = {k: v[mask] for k, v in flatten(batch).items()}
+ return unflatten(masked_batch)
+
+ def transform(self, signal: AudioSignal, **kwargs):
+ """Apply the transform to the audio signal,
+ with given keyword arguments.
+
+ Parameters
+ ----------
+ signal : AudioSignal
+ Signal that will be modified by the transforms in-place.
+ kwargs: dict
+ Keyword arguments to the specific transforms ``self._transform``
+ function.
+
+ Returns
+ -------
+ AudioSignal
+ Transformed AudioSignal.
+
+ Examples
+ --------
+
+ >>> for seed in range(10):
+ >>> kwargs = transform.instantiate(seed, signal)
+ >>> output = transform(signal.clone(), **kwargs)
+
+ """
+ tfm_kwargs = self._prepare(kwargs)
+ mask = tfm_kwargs["mask"]
+
+ if torch.any(mask):
+ tfm_kwargs = self.apply_mask(tfm_kwargs, mask)
+ tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"}
+ signal[mask] = self._transform(signal[mask], **tfm_kwargs)
+
+ return signal
+
+ def __call__(self, *args, **kwargs):
+ return self.transform(*args, **kwargs)
+
+ def instantiate(
+ self,
+ state: RandomState = None,
+ signal: AudioSignal = None,
+ ):
+ """Instantiates parameters for the transform.
+
+ Parameters
+ ----------
+ state : RandomState, optional
+            Random state (or seed) used to instantiate the parameters,
+            by default None
+        signal : AudioSignal, optional
+            Signal passed along to ``self._instantiate`` if the transform
+            needs it (e.g. to match its duration or loudness), by default None
+
+ Returns
+ -------
+ dict
+ Dictionary containing instantiated arguments for every keyword
+ argument to ``self._transform``.
+
+ Examples
+ --------
+
+ >>> for seed in range(10):
+ >>> kwargs = transform.instantiate(seed, signal)
+ >>> output = transform(signal.clone(), **kwargs)
+
+ """
+ state = util.random_state(state)
+
+ # Not all instantiates need the signal. Check if signal
+ # is needed before passing it in, so that the end-user
+ # doesn't need to have variables they're not using flowing
+ # into their function.
+ needs_signal = "signal" in set(signature(self._instantiate).parameters.keys())
+ kwargs = {}
+ if needs_signal:
+ kwargs = {"signal": signal}
+
+ # Instantiate the parameters for the transform.
+ params = self._instantiate(state, **kwargs)
+ for k in list(params.keys()):
+ v = params[k]
+ if isinstance(v, (AudioSignal, torch.Tensor, dict)):
+ params[k] = v
+ else:
+ params[k] = tt(v)
+ mask = state.rand() <= self.prob
+ params[f"mask"] = tt(mask)
+
+ # Put the params into a nested dictionary that will be
+ # used later when calling the transform. This is to avoid
+ # collisions in the dictionary.
+ params = {self.name: params}
+
+ return params
+
+ def batch_instantiate(
+ self,
+ states: list = None,
+ signal: AudioSignal = None,
+ ):
+ """Instantiates arguments for every item in a batch,
+ given a list of states. Each state in the list
+ corresponds to one item in the batch.
+
+ Parameters
+ ----------
+ states : list, optional
+ List of states, by default None
+ signal : AudioSignal, optional
+ AudioSignal to pass to the ``self.instantiate`` section
+ if it is needed for this transform, by default None
+
+ Returns
+ -------
+ dict
+ Collated dictionary of arguments.
+
+ Examples
+ --------
+
+ >>> batch_size = 4
+ >>> signal = AudioSignal(audio_path, offset=10, duration=2)
+ >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
+ >>>
+ >>> states = [seed + idx for idx in list(range(batch_size))]
+ >>> kwargs = transform.batch_instantiate(states, signal_batch)
+ >>> batch_output = transform(signal_batch, **kwargs)
+ """
+ kwargs = []
+ for state in states:
+ kwargs.append(self.instantiate(state, signal))
+ kwargs = util.collate(kwargs)
+ return kwargs
+
+
+class Identity(BaseTransform):
+ """This transform just returns the original signal."""
+
+ pass
+
+
+class SpectralTransform(BaseTransform):
+ """Spectral transforms require STFT data to exist, since manipulations
+ of the STFT require the spectrogram. This just calls ``stft`` before
+ the transform is called, and calls ``istft`` after the transform is
+ called so that the audio data is written to after the spectral
+ manipulation.
+ """
+
+ def transform(self, signal, **kwargs):
+ signal.stft()
+ super().transform(signal, **kwargs)
+ signal.istft()
+ return signal
+
+
+class Compose(BaseTransform):
+ """Compose applies transforms in sequence, one after the other. The
+ transforms are passed in as positional arguments or as a list like so:
+
+ >>> transform = tfm.Compose(
+ >>> [
+ >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
+ >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
+ >>> ],
+ >>> )
+
+ This will convolve the signal with a room impulse response, and then
+ add background noise to the signal. Instantiate instantiates
+ all the parameters for every transform in the transform list so the
+ interface for using the Compose transform is the same as everything
+ else:
+
+ >>> kwargs = transform.instantiate()
+ >>> output = transform(signal.clone(), **kwargs)
+
+    Under the hood, ``Compose`` maps each transform to a unique name
+    of the form ``{position}.{name}``, where ``position``
+ is the index of the transform in the list. ``Compose`` can nest
+ within other ``Compose`` transforms, like so:
+
+ >>> preprocess = transforms.Compose(
+ >>> tfm.GlobalVolumeNorm(),
+ >>> tfm.CrossTalk(),
+ >>> name="preprocess",
+ >>> )
+ >>> augment = transforms.Compose(
+ >>> tfm.RoomImpulseResponse(),
+ >>> tfm.BackgroundNoise(),
+ >>> name="augment",
+ >>> )
+ >>> postprocess = transforms.Compose(
+ >>> tfm.VolumeChange(),
+ >>> tfm.RescaleAudio(),
+ >>> tfm.ShiftPhase(),
+ >>> name="postprocess",
+ >>> )
+    >>> transform = transforms.Compose(preprocess, augment, postprocess)
+
+ This defines 3 composed transforms, and then composes them in sequence
+ with one another.
+
+ Parameters
+ ----------
+ *transforms : list
+ List of transforms to apply
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(self, *transforms: list, name: str = None, prob: float = 1.0):
+ if isinstance(transforms[0], list):
+ transforms = transforms[0]
+
+ for i, tfm in enumerate(transforms):
+ tfm.name = f"{i}.{tfm.name}"
+
+ keys = [tfm.name for tfm in transforms]
+ super().__init__(keys=keys, name=name, prob=prob)
+
+ self.transforms = transforms
+ self.transforms_to_apply = keys
+
+ @contextmanager
+ def filter(self, *names: list):
+ """This can be used to skip transforms entirely when applying
+ the sequence of transforms to a signal. For example, take
+ the following transforms with the names ``preprocess, augment, postprocess``.
+
+ >>> preprocess = transforms.Compose(
+ >>> tfm.GlobalVolumeNorm(),
+ >>> tfm.CrossTalk(),
+ >>> name="preprocess",
+ >>> )
+ >>> augment = transforms.Compose(
+ >>> tfm.RoomImpulseResponse(),
+ >>> tfm.BackgroundNoise(),
+ >>> name="augment",
+ >>> )
+ >>> postprocess = transforms.Compose(
+ >>> tfm.VolumeChange(),
+ >>> tfm.RescaleAudio(),
+ >>> tfm.ShiftPhase(),
+ >>> name="postprocess",
+ >>> )
+ >>> transform = transforms.Compose(preprocess, augment, postprocess)
+
+ If we wanted to apply all 3 to a signal, we do:
+
+ >>> kwargs = transform.instantiate()
+ >>> output = transform(signal.clone(), **kwargs)
+
+ But if we only wanted to apply the ``preprocess`` and ``postprocess``
+ transforms to the signal, we do:
+
+        >>> with transform.filter("preprocess", "postprocess"):
+ >>> output = transform(signal.clone(), **kwargs)
+
+ Parameters
+ ----------
+ *names : list
+ List of transforms, identified by name, to apply to signal.
+ """
+ old_transforms = self.transforms_to_apply
+ self.transforms_to_apply = names
+ yield
+ self.transforms_to_apply = old_transforms
+
+ def _transform(self, signal, **kwargs):
+ for transform in self.transforms:
+ if any([x in transform.name for x in self.transforms_to_apply]):
+ signal = transform(signal, **kwargs)
+ return signal
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+ parameters = {}
+ for transform in self.transforms:
+ parameters.update(transform.instantiate(state, signal=signal))
+ return parameters
+
+ def __getitem__(self, idx):
+ return self.transforms[idx]
+
+ def __len__(self):
+ return len(self.transforms)
+
+ def __iter__(self):
+ for transform in self.transforms:
+ yield transform
+
+
+class Choose(Compose):
+ """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`,
+ but instead of applying all the transforms in sequence, it applies just a single transform,
+ which is chosen for each item in the batch.
+
+ Parameters
+ ----------
+ *transforms : list
+ List of transforms to apply
+ weights : list
+ Probability of choosing any specific transform.
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+
+ Examples
+ --------
+
+ >>> transforms.Choose(tfm.LowPass(), tfm.HighPass())
+ """
+
+ def __init__(
+ self,
+ *transforms: list,
+ weights: list = None,
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(*transforms, name=name, prob=prob)
+
+ if weights is None:
+ _len = len(self.transforms)
+ weights = [1 / _len for _ in range(_len)]
+ self.weights = np.array(weights)
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+ kwargs = super()._instantiate(state, signal)
+ tfm_idx = list(range(len(self.transforms)))
+ tfm_idx = state.choice(tfm_idx, p=self.weights)
+ one_hot = []
+ for i, t in enumerate(self.transforms):
+ mask = kwargs[t.name]["mask"]
+ if mask.item():
+ kwargs[t.name]["mask"] = tt(i == tfm_idx)
+ one_hot.append(kwargs[t.name]["mask"])
+ kwargs["one_hot"] = one_hot
+ return kwargs
+
+
+class Repeat(Compose):
+ """Repeatedly applies a given transform ``n_repeat`` times."
+
+ Parameters
+ ----------
+ transform : BaseTransform
+ Transform to repeat.
+ n_repeat : int, optional
+ Number of times to repeat transform, by default 1
+ """
+
+ def __init__(
+ self,
+ transform,
+ n_repeat: int = 1,
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ transforms = [copy.copy(transform) for _ in range(n_repeat)]
+ super().__init__(transforms, name=name, prob=prob)
+
+ self.n_repeat = n_repeat
+
+
+class RepeatUpTo(Choose):
+ """Repeatedly applies a given transform up to ``max_repeat`` times."
+
+ Parameters
+ ----------
+ transform : BaseTransform
+ Transform to repeat.
+ max_repeat : int, optional
+        Max number of times to repeat transform, by default 5
+ weights : list
+ Probability of choosing any specific number up to ``max_repeat``.
+ """
+
+ def __init__(
+ self,
+ transform,
+ max_repeat: int = 5,
+ weights: list = None,
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ transforms = []
+ for n in range(1, max_repeat):
+ transforms.append(Repeat(transform, n_repeat=n))
+ super().__init__(transforms, name=name, prob=prob, weights=weights)
+
+ self.max_repeat = max_repeat
+
+
+class ClippingDistortion(BaseTransform):
+ """Adds clipping distortion to signal. Corresponds
+ to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`.
+
+ Parameters
+ ----------
+ perc : tuple, optional
+        Clipping percentile. Values are between 0.0 and 1.0.
+ Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ perc: tuple = ("uniform", 0.0, 0.1),
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.perc = perc
+
+ def _instantiate(self, state: RandomState):
+ return {"perc": util.sample_from_dist(self.perc, state)}
+
+ def _transform(self, signal, perc):
+ return signal.clip_distortion(perc)
+
+
+class Equalizer(BaseTransform):
+ """Applies an equalization curve to the audio signal. Corresponds
+ to :py:func:`audiotools.core.effects.EffectMixin.equalizer`.
+
+ Parameters
+ ----------
+ eq_amount : tuple, optional
+ The maximum dB cut to apply to the audio in any band,
+ by default ("const", 1.0 dB)
+ n_bands : int, optional
+ Number of bands in EQ, by default 6
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ eq_amount: tuple = ("const", 1.0),
+ n_bands: int = 6,
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.eq_amount = eq_amount
+ self.n_bands = n_bands
+
+ def _instantiate(self, state: RandomState):
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
+ eq = -eq_amount * state.rand(self.n_bands)
+ return {"eq": eq}
+
+ def _transform(self, signal, eq):
+ return signal.equalizer(eq)
+
+
+class Quantization(BaseTransform):
+ """Applies quantization to the input waveform. Corresponds
+ to :py:func:`audiotools.core.effects.EffectMixin.quantization`.
+
+ Parameters
+ ----------
+ channels : tuple, optional
+ Number of evenly spaced quantization channels to quantize
+ to, by default ("choice", [8, 32, 128, 256, 1024])
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.channels = channels
+
+ def _instantiate(self, state: RandomState):
+ return {"channels": util.sample_from_dist(self.channels, state)}
+
+ def _transform(self, signal, channels):
+ return signal.quantization(channels)
+
+
+class MuLawQuantization(BaseTransform):
+ """Applies mu-law quantization to the input waveform. Corresponds
+ to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`.
+
+ Parameters
+ ----------
+ channels : tuple, optional
+ Number of mu-law spaced quantization channels to quantize
+ to, by default ("choice", [8, 32, 128, 256, 1024])
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.channels = channels
+
+ def _instantiate(self, state: RandomState):
+ return {"channels": util.sample_from_dist(self.channels, state)}
+
+ def _transform(self, signal, channels):
+ return signal.mulaw_quantization(channels)
+
+
+class NoiseFloor(BaseTransform):
+ """Adds a noise floor of Gaussian noise to the signal at a specified
+ dB.
+
+ Parameters
+ ----------
+ db : tuple, optional
+ Level of noise to add to signal, by default ("const", -50.0)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ db: tuple = ("const", -50.0),
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.db = db
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
+ db = util.sample_from_dist(self.db, state)
+ audio_data = state.randn(signal.num_channels, signal.signal_length)
+ nz_signal = AudioSignal(audio_data, signal.sample_rate)
+ nz_signal.normalize(db)
+ return {"nz_signal": nz_signal}
+
+ def _transform(self, signal, nz_signal):
+        # Add the instantiated noise floor to the signal.
+ return signal + nz_signal
+
+
+class BackgroundNoise(BaseTransform):
+ """Adds background noise from audio specified by a set of CSV files.
+ A valid CSV file looks like, and is typically generated by
+ :py:func:`audiotools.data.preprocess.create_csv`:
+
+ .. csv-table::
+ :header: path
+
+ room_tone/m6_script2_clean.wav
+ room_tone/m6_script2_cleanraw.wav
+ room_tone/m6_script2_ipad_balcony1.wav
+ room_tone/m6_script2_ipad_bedroom1.wav
+ room_tone/m6_script2_ipad_confroom1.wav
+ room_tone/m6_script2_ipad_confroom2.wav
+ room_tone/m6_script2_ipad_livingroom1.wav
+ room_tone/m6_script2_ipad_office1.wav
+
+ .. note::
+ All paths are relative to an environment variable called ``PATH_TO_DATA``,
+ so that CSV files are portable across machines where data may be
+ located in different places.
+
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
+ and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the
+ hood.
+
+ Parameters
+ ----------
+ snr : tuple, optional
+ Signal-to-noise ratio, by default ("uniform", 10.0, 30.0)
+ sources : List[str], optional
+ Sources containing folders, or CSVs with paths to audio files,
+ by default None
+ weights : List[float], optional
+ Weights to sample audio files from each source, by default None
+ eq_amount : tuple, optional
+ Amount of equalization to apply, by default ("const", 1.0)
+ n_bands : int, optional
+ Number of bands in equalizer, by default 3
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ loudness_cutoff : float, optional
+ Loudness cutoff when loading from audio files, by default None
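+
+    Examples
+    --------
+    A hedged sketch; the audio and CSV paths are illustrative:
+
+    >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav", offset=10, duration=2)
+    >>> transform = BackgroundNoise(sources=["tests/audio/noises.csv"])
+    >>> kwargs = transform.instantiate(state=0, signal=signal)
+    >>> output = transform(signal.clone(), **kwargs)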
+ """
+
+ def __init__(
+ self,
+ snr: tuple = ("uniform", 10.0, 30.0),
+ sources: List[str] = None,
+ weights: List[float] = None,
+ eq_amount: tuple = ("const", 1.0),
+ n_bands: int = 3,
+ name: str = None,
+ prob: float = 1.0,
+ loudness_cutoff: float = None,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.snr = snr
+ self.eq_amount = eq_amount
+ self.n_bands = n_bands
+ self.loader = AudioLoader(sources, weights)
+ self.loudness_cutoff = loudness_cutoff
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
+ eq = -eq_amount * state.rand(self.n_bands)
+ snr = util.sample_from_dist(self.snr, state)
+
+ bg_signal = self.loader(
+ state,
+ signal.sample_rate,
+ duration=signal.signal_duration,
+ loudness_cutoff=self.loudness_cutoff,
+ num_channels=signal.num_channels,
+ )["signal"]
+
+ return {"eq": eq, "bg_signal": bg_signal, "snr": snr}
+
+ def _transform(self, signal, bg_signal, snr, eq):
+ # Clone bg_signal so that transform can be repeatedly applied
+ # to different signals with the same effect.
+ return signal.mix(bg_signal.clone(), snr, eq)
+
+
+class CrossTalk(BaseTransform):
+ """Adds crosstalk between speakers, whose audio is drawn from a CSV file
+ that was produced via :py:func:`audiotools.data.preprocess.create_csv`.
+
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
+ under the hood.
+
+ Parameters
+ ----------
+ snr : tuple, optional
+        How loud the cross-talk speaker is relative to the original signal, in dB,
+ by default ("uniform", 0.0, 10.0)
+ sources : List[str], optional
+ Sources containing folders, or CSVs with paths to audio files,
+ by default None
+ weights : List[float], optional
+ Weights to sample audio files from each source, by default None
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ loudness_cutoff : float, optional
+ Loudness cutoff when loading from audio files, by default -40
+ """
+
+ def __init__(
+ self,
+ snr: tuple = ("uniform", 0.0, 10.0),
+ sources: List[str] = None,
+ weights: List[float] = None,
+ name: str = None,
+ prob: float = 1.0,
+ loudness_cutoff: float = -40,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.snr = snr
+ self.loader = AudioLoader(sources, weights)
+ self.loudness_cutoff = loudness_cutoff
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
+ snr = util.sample_from_dist(self.snr, state)
+ crosstalk_signal = self.loader(
+ state,
+ signal.sample_rate,
+ duration=signal.signal_duration,
+ loudness_cutoff=self.loudness_cutoff,
+ num_channels=signal.num_channels,
+ )["signal"]
+
+ return {"crosstalk_signal": crosstalk_signal, "snr": snr}
+
+ def _transform(self, signal, crosstalk_signal, snr):
+        # Clone crosstalk_signal so that the transform can be repeatedly
+        # applied to different signals with the same effect.
+ loudness = signal.loudness()
+ mix = signal.mix(crosstalk_signal.clone(), snr)
+ mix.normalize(loudness)
+ return mix
+
+
+class RoomImpulseResponse(BaseTransform):
+ """Convolves signal with a room impulse response, at a specified
+ direct-to-reverberant ratio, with equalization applied. Room impulse
+ response data is drawn from a CSV file that was produced via
+ :py:func:`audiotools.data.preprocess.create_csv`.
+
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir`
+ under the hood.
+
+ Parameters
+ ----------
+ drr : tuple, optional
+        Direct-to-reverberant ratio, in dB, by default ("uniform", 0.0, 30.0)
+ sources : List[str], optional
+ Sources containing folders, or CSVs with paths to audio files,
+ by default None
+ weights : List[float], optional
+ Weights to sample audio files from each source, by default None
+ eq_amount : tuple, optional
+ Amount of equalization to apply, by default ("const", 1.0)
+ n_bands : int, optional
+ Number of bands in equalizer, by default 6
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ use_original_phase : bool, optional
+ Whether or not to use the original phase, by default False
+ offset : float, optional
+ Offset from each impulse response file to use, by default 0.0
+ duration : float, optional
+ Duration of each impulse response, by default 1.0
+ """
+
+ def __init__(
+ self,
+ drr: tuple = ("uniform", 0.0, 30.0),
+ sources: List[str] = None,
+ weights: List[float] = None,
+ eq_amount: tuple = ("const", 1.0),
+ n_bands: int = 6,
+ name: str = None,
+ prob: float = 1.0,
+ use_original_phase: bool = False,
+ offset: float = 0.0,
+ duration: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.drr = drr
+ self.eq_amount = eq_amount
+ self.n_bands = n_bands
+ self.use_original_phase = use_original_phase
+
+ self.loader = AudioLoader(sources, weights)
+ self.offset = offset
+ self.duration = duration
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
+ eq = -eq_amount * state.rand(self.n_bands)
+ drr = util.sample_from_dist(self.drr, state)
+
+ ir_signal = self.loader(
+ state,
+ signal.sample_rate,
+ offset=self.offset,
+ duration=self.duration,
+ loudness_cutoff=None,
+ num_channels=signal.num_channels,
+ )["signal"]
+ ir_signal.zero_pad_to(signal.sample_rate)
+
+ return {"eq": eq, "ir_signal": ir_signal, "drr": drr}
+
+ def _transform(self, signal, ir_signal, drr, eq):
+ # Clone ir_signal so that transform can be repeatedly applied
+ # to different signals with the same effect.
+ return signal.apply_ir(
+ ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase
+ )
+
+
+class VolumeChange(BaseTransform):
+ """Changes the volume of the input signal.
+
+ Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
+
+ Parameters
+ ----------
+ db : tuple, optional
+ Change in volume in decibels, by default ("uniform", -12.0, 0.0)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
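+
+    Example
+    -------
+    The ``db`` argument follows the distribution-tuple convention used by
+    ``util.sample_from_dist`` (a sketch; the values are arbitrary):
+
+    >>> VolumeChange(db=("uniform", -12.0, 0.0))   # sample uniformly in [-12, 0]
+    >>> VolumeChange(db=("const", -6.0))           # always -6 dB
+    >>> VolumeChange(db=("choice", [-12, -6, 0]))  # pick one of the listed values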
+ """
+
+ def __init__(
+ self,
+ db: tuple = ("uniform", -12.0, 0.0),
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+ self.db = db
+
+ def _instantiate(self, state: RandomState):
+ return {"db": util.sample_from_dist(self.db, state)}
+
+ def _transform(self, signal, db):
+ return signal.volume_change(db)
+
+
+class VolumeNorm(BaseTransform):
+ """Normalizes the volume of the excerpt to a specified decibel.
+
+ Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`.
+
+ Parameters
+ ----------
+ db : tuple, optional
+ dB to normalize signal to, by default ("const", -24)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ db: tuple = ("const", -24),
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.db = db
+
+ def _instantiate(self, state: RandomState):
+ return {"db": util.sample_from_dist(self.db, state)}
+
+ def _transform(self, signal, db):
+ return signal.normalize(db)
+
+
+class GlobalVolumeNorm(BaseTransform):
+ """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this
+ transform also normalizes the volume of a signal, but it uses
+ the volume of the entire audio file the loaded excerpt comes from,
+ rather than the volume of just the excerpt. The volume of the
+ entire audio file is expected in ``signal.metadata["loudness"]``.
+    A CSV generated by :py:func:`audiotools.data.preprocess.create_csv`
+    with ``loudness = True`` looks like the following:
+
+ .. csv-table::
+ :header: path,loudness
+
+ daps/produced/f1_script1_produced.wav,-16.299999237060547
+ daps/produced/f1_script2_produced.wav,-16.600000381469727
+ daps/produced/f1_script3_produced.wav,-17.299999237060547
+ daps/produced/f1_script4_produced.wav,-16.100000381469727
+ daps/produced/f1_script5_produced.wav,-16.700000762939453
+ daps/produced/f3_script1_produced.wav,-16.5
+
+ The ``AudioLoader`` will automatically load the loudness column into
+ the metadata of the signal.
+
+ Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
+
+ Parameters
+ ----------
+ db : tuple, optional
+ dB to normalize signal to, by default ("const", -24)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
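+
+    Example
+    -------
+    A minimal sketch, with the ``loudness`` metadata entry set by hand for
+    illustration (normally the ``AudioLoader`` fills it in):
+
+    >>> signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    >>> signal.metadata["loudness"] = -30.0
+    >>> transform = GlobalVolumeNorm(db=("const", -24))
+    >>> kwargs = transform.instantiate(state=0, signal=signal)
+    >>> out = transform(signal.clone(), **kwargs)  # applies a +6 dB change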
+ """
+
+ def __init__(
+ self,
+ db: tuple = ("const", -24),
+ name: str = None,
+ prob: float = 1.0,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.db = db
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
+ if "loudness" not in signal.metadata:
+ db_change = 0.0
+ elif float(signal.metadata["loudness"]) == float("-inf"):
+ db_change = 0.0
+ else:
+ db = util.sample_from_dist(self.db, state)
+ db_change = db - float(signal.metadata["loudness"])
+
+ return {"db": db_change}
+
+ def _transform(self, signal, db):
+ return signal.volume_change(db)
+
+
+class Silence(BaseTransform):
+ """Zeros out the signal with some probability.
+
+ Parameters
+ ----------
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 0.1
+ """
+
+ def __init__(self, name: str = None, prob: float = 0.1):
+ super().__init__(name=name, prob=prob)
+
+ def _transform(self, signal):
+ _loudness = signal._loudness
+ signal = AudioSignal(
+ torch.zeros_like(signal.audio_data),
+ sample_rate=signal.sample_rate,
+ stft_params=signal.stft_params,
+ )
+        # So that the amount of noise added downstream is as if the signal weren't silenced.
+ # TODO: improve this hack
+ signal._loudness = _loudness
+
+ return signal
+
+
+class LowPass(BaseTransform):
+ """Applies a LowPass filter.
+
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
+
+ Parameters
+ ----------
+ cutoff : tuple, optional
+ Cutoff frequency distribution,
+ by default ``("choice", [4000, 8000, 16000])``
+ zeros : int, optional
+ Number of zero-crossings in filter, argument to
+ ``julius.LowPassFilters``, by default 51
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ cutoff: tuple = ("choice", [4000, 8000, 16000]),
+ zeros: int = 51,
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.cutoff = cutoff
+ self.zeros = zeros
+
+ def _instantiate(self, state: RandomState):
+ return {"cutoff": util.sample_from_dist(self.cutoff, state)}
+
+ def _transform(self, signal, cutoff):
+ return signal.low_pass(cutoff, zeros=self.zeros)
+
+
+class HighPass(BaseTransform):
+ """Applies a HighPass filter.
+
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
+
+ Parameters
+ ----------
+ cutoff : tuple, optional
+ Cutoff frequency distribution,
+ by default ``("choice", [50, 100, 250, 500, 1000])``
+ zeros : int, optional
+ Number of zero-crossings in filter, argument to
+ ``julius.LowPassFilters``, by default 51
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]),
+ zeros: int = 51,
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(name=name, prob=prob)
+
+ self.cutoff = cutoff
+ self.zeros = zeros
+
+ def _instantiate(self, state: RandomState):
+ return {"cutoff": util.sample_from_dist(self.cutoff, state)}
+
+ def _transform(self, signal, cutoff):
+ return signal.high_pass(cutoff, zeros=self.zeros)
+
+
+class RescaleAudio(BaseTransform):
+ """Rescales the audio so it is in between ``-val`` and ``val``
+ only if the original audio exceeds those bounds. Useful if
+ transforms have caused the audio to clip.
+
+ Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`.
+
+ Parameters
+ ----------
+ val : float, optional
+ Max absolute value of signal, by default 1.0
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(self, val: float = 1.0, name: str = None, prob: float = 1):
+ super().__init__(name=name, prob=prob)
+
+ self.val = val
+
+ def _transform(self, signal):
+ return signal.ensure_max_of_audio(self.val)
+
+
+class ShiftPhase(SpectralTransform):
+ """Shifts the phase of the audio.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
+
+ Parameters
+ ----------
+ shift : tuple, optional
+ How much to shift phase by, by default ("uniform", -np.pi, np.pi)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ shift: tuple = ("uniform", -np.pi, np.pi),
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(name=name, prob=prob)
+ self.shift = shift
+
+ def _instantiate(self, state: RandomState):
+ return {"shift": util.sample_from_dist(self.shift, state)}
+
+ def _transform(self, signal, shift):
+ return signal.shift_phase(shift)
+
+
+class InvertPhase(ShiftPhase):
+ """Inverts the phase of the audio.
+
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
+
+ Parameters
+ ----------
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(self, name: str = None, prob: float = 1):
+ super().__init__(shift=("const", np.pi), name=name, prob=prob)
+
+
+class CorruptPhase(SpectralTransform):
+ """Corrupts the phase of the audio.
+
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`.
+
+ Parameters
+ ----------
+ scale : tuple, optional
+ How much to corrupt phase by, by default ("uniform", 0, np.pi)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1
+ ):
+ super().__init__(name=name, prob=prob)
+ self.scale = scale
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+ scale = util.sample_from_dist(self.scale, state)
+ corruption = state.normal(scale=scale, size=signal.phase.shape[1:])
+ return {"corruption": corruption.astype("float32")}
+
+ def _transform(self, signal, corruption):
+ return signal.shift_phase(shift=corruption)
+
+
+class FrequencyMask(SpectralTransform):
+ """Masks a band of frequencies at a center frequency
+ from the audio.
+
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`.
+
+ Parameters
+ ----------
+ f_center : tuple, optional
+ Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
+ f_width : tuple, optional
+ Width of zero'd out band, by default ("const", 0.1)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
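+
+    Example
+    -------
+    ``f_center`` and ``f_width`` are normalized to Nyquist; for a 44.1 kHz
+    signal, ``f_center=0.5`` and ``f_width=0.1`` mask roughly 9.9 kHz to
+    12.1 kHz (a sketch; ``signal`` is assumed to be an existing AudioSignal):
+
+    >>> transform = FrequencyMask(f_center=("const", 0.5), f_width=("const", 0.1))
+    >>> kwargs = transform.instantiate(state=0, signal=signal)
+    >>> masked = transform(signal.clone(), **kwargs)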
+ """
+
+ def __init__(
+ self,
+ f_center: tuple = ("uniform", 0.0, 1.0),
+ f_width: tuple = ("const", 0.1),
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(name=name, prob=prob)
+ self.f_center = f_center
+ self.f_width = f_width
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
+ f_center = util.sample_from_dist(self.f_center, state)
+ f_width = util.sample_from_dist(self.f_width, state)
+
+ fmin = max(f_center - (f_width / 2), 0.0)
+ fmax = min(f_center + (f_width / 2), 1.0)
+
+ fmin_hz = (signal.sample_rate / 2) * fmin
+ fmax_hz = (signal.sample_rate / 2) * fmax
+
+ return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz}
+
+ def _transform(self, signal, fmin_hz: float, fmax_hz: float):
+ return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
+
+
+class TimeMask(SpectralTransform):
+ """Masks out contiguous time-steps from signal.
+
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`.
+
+ Parameters
+ ----------
+ t_center : tuple, optional
+ Center time in terms of 0.0 and 1.0 (duration of signal),
+ by default ("uniform", 0.0, 1.0)
+ t_width : tuple, optional
+ Width of dropped out portion, by default ("const", 0.025)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ t_center: tuple = ("uniform", 0.0, 1.0),
+ t_width: tuple = ("const", 0.025),
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(name=name, prob=prob)
+ self.t_center = t_center
+ self.t_width = t_width
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
+ t_center = util.sample_from_dist(self.t_center, state)
+ t_width = util.sample_from_dist(self.t_width, state)
+
+ tmin = max(t_center - (t_width / 2), 0.0)
+ tmax = min(t_center + (t_width / 2), 1.0)
+
+ tmin_s = signal.signal_duration * tmin
+ tmax_s = signal.signal_duration * tmax
+ return {"tmin_s": tmin_s, "tmax_s": tmax_s}
+
+ def _transform(self, signal, tmin_s: float, tmax_s: float):
+ return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
+
+
+class MaskLowMagnitudes(SpectralTransform):
+ """Masks low magnitude regions out of signal.
+
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`.
+
+ Parameters
+ ----------
+ db_cutoff : tuple, optional
+        Decibel cutoff; magnitude bins below this value are masked away,
+ by default ("uniform", -10, 10)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ db_cutoff: tuple = ("uniform", -10, 10),
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(name=name, prob=prob)
+ self.db_cutoff = db_cutoff
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+ return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)}
+
+ def _transform(self, signal, db_cutoff: float):
+ return signal.mask_low_magnitudes(db_cutoff)
+
+
+class Smoothing(BaseTransform):
+ """Convolves the signal with a smoothing window.
+
+ Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`.
+
+ Parameters
+ ----------
+ window_type : tuple, optional
+ Type of window to use, by default ("const", "average")
+ window_length : tuple, optional
+ Length of smoothing window, by
+ default ("choice", [8, 16, 32, 64, 128, 256, 512])
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ window_type: tuple = ("const", "average"),
+ window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]),
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(name=name, prob=prob)
+ self.window_type = window_type
+ self.window_length = window_length
+
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
+ window_type = util.sample_from_dist(self.window_type, state)
+ window_length = util.sample_from_dist(self.window_length, state)
+ window = signal.get_window(
+ window_type=window_type, window_length=window_length, device="cpu"
+ )
+ return {"window": AudioSignal(window, signal.sample_rate)}
+
+ def _transform(self, signal, window):
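+        # Record the per-item peak of the input so the smoothed output can be
+        # rescaled back to the original level after convolution.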
+ sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
+ sscale[sscale == 0.0] = 1.0
+
+ out = signal.convolve(window)
+
+ oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values
+ oscale[oscale == 0.0] = 1.0
+
+ out = out * (sscale / oscale)
+ return out
+
+
+class TimeNoise(TimeMask):
+ """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but
+ replaces with noise instead of zeros.
+
+ Parameters
+ ----------
+ t_center : tuple, optional
+ Center time in terms of 0.0 and 1.0 (duration of signal),
+ by default ("uniform", 0.0, 1.0)
+ t_width : tuple, optional
+ Width of dropped out portion, by default ("const", 0.025)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ t_center: tuple = ("uniform", 0.0, 1.0),
+ t_width: tuple = ("const", 0.025),
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob)
+
+ def _transform(self, signal, tmin_s: float, tmax_s: float):
+ signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0)
+ mag, phase = signal.magnitude, signal.phase
+
+ mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
+ mask = (mag == 0.0) * (phase == 0.0)
+
+ mag[mask] = mag_r[mask]
+ phase[mask] = phase_r[mask]
+
+ signal.magnitude = mag
+ signal.phase = phase
+ return signal
+
+
+class FrequencyNoise(FrequencyMask):
+ """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
+ replaces with noise instead of zeros.
+
+ Parameters
+ ----------
+ f_center : tuple, optional
+ Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
+ f_width : tuple, optional
+ Width of zero'd out band, by default ("const", 0.1)
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ f_center: tuple = ("uniform", 0.0, 1.0),
+ f_width: tuple = ("const", 0.1),
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob)
+
+ def _transform(self, signal, fmin_hz: float, fmax_hz: float):
+ signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
+ mag, phase = signal.magnitude, signal.phase
+
+ mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
+ mask = (mag == 0.0) * (phase == 0.0)
+
+ mag[mask] = mag_r[mask]
+ phase[mask] = phase_r[mask]
+
+ signal.magnitude = mag
+ signal.phase = phase
+ return signal
+
+
+class SpectralDenoising(Equalizer):
+ """Applies denoising algorithm detailed in
+ :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`,
+ using a randomly generated noise signal for denoising.
+
+ Parameters
+ ----------
+ eq_amount : tuple, optional
+ Amount of eq to apply to noise signal, by default ("const", 1.0)
+ denoise_amount : tuple, optional
+ Amount to denoise by, by default ("uniform", 0.8, 1.0)
+ nz_volume : float, optional
+ Volume of noise to denoise with, by default -40
+ n_bands : int, optional
+ Number of bands in equalizer, by default 6
+ n_freq : int, optional
+ Number of frequency bins to smooth by, by default 3
+ n_time : int, optional
+ Number of time bins to smooth by, by default 5
+ name : str, optional
+ Name of this transform, used to identify it in the dictionary
+ produced by ``self.instantiate``, by default None
+ prob : float, optional
+ Probability of applying this transform, by default 1.0
+ """
+
+ def __init__(
+ self,
+ eq_amount: tuple = ("const", 1.0),
+ denoise_amount: tuple = ("uniform", 0.8, 1.0),
+ nz_volume: float = -40,
+ n_bands: int = 6,
+ n_freq: int = 3,
+ n_time: int = 5,
+ name: str = None,
+ prob: float = 1,
+ ):
+ super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob)
+
+ self.nz_volume = nz_volume
+ self.denoise_amount = denoise_amount
+ self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time)
+
+ def _transform(self, signal, nz, eq, denoise_amount):
+ nz = nz.normalize(self.nz_volume).equalizer(eq)
+ self.spectral_gate = self.spectral_gate.to(signal.device)
+ signal = self.spectral_gate(signal, nz, denoise_amount)
+ return signal
+
+ def _instantiate(self, state: RandomState):
+ kwargs = super()._instantiate(state)
+ kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state)
+ kwargs["nz"] = AudioSignal(state.randn(22050), 44100)
+ return kwargs
diff --git a/audiotools/metrics/__init__.py b/audiotools/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9c8d2df61f94afae8e39e57abf156e8e4059a9e
--- /dev/null
+++ b/audiotools/metrics/__init__.py
@@ -0,0 +1,6 @@
+"""
+Functions for comparing AudioSignal objects to one another.
+""" # fmt: skip
+from . import distance
+from . import quality
+from . import spectral
diff --git a/audiotools/metrics/distance.py b/audiotools/metrics/distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce78739bfc29f9ddc39b23063b4243ddac10adaf
--- /dev/null
+++ b/audiotools/metrics/distance.py
@@ -0,0 +1,131 @@
+import torch
+from torch import nn
+
+from .. import AudioSignal
+
+
+class L1Loss(nn.L1Loss):
+ """L1 Loss between AudioSignals. Defaults
+ to comparing ``audio_data``, but any
+ attribute of an AudioSignal can be used.
+
+ Parameters
+ ----------
+ attribute : str, optional
+ Attribute of signal to compare, defaults to ``audio_data``.
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+ """
+
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+ self.attribute = attribute
+ self.weight = weight
+ super().__init__(**kwargs)
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate AudioSignal
+ y : AudioSignal
+ Reference AudioSignal
+
+ Returns
+ -------
+ torch.Tensor
+ L1 loss between AudioSignal attributes.
+ """
+ if isinstance(x, AudioSignal):
+ x = getattr(x, self.attribute)
+ y = getattr(y, self.attribute)
+ return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+ """
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+ of estimated and reference audio signals or aligned features.
+
+ Parameters
+ ----------
+    scaling : bool, optional
+        Whether to use scale-invariant (True) or plain
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+    zero_mean : bool, optional
+        Zero-mean the references and estimates before
+        computing the loss, by default True
+    clip_min : float, optional
+        The minimum possible loss value. Helps the network
+        not to focus on making already good examples better, by default None
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
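+
+    Example
+    -------
+    A minimal sketch (random signals, for illustration only):
+
+    >>> loss_fn = SISDRLoss()
+    >>> x = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    >>> y = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    >>> loss = loss_fn(x, y)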
+ """
+
+ def __init__(
+ self,
+        scaling: bool = True,
+        reduction: str = "mean",
+        zero_mean: bool = True,
+        clip_min: float = None,
+ weight: float = 1.0,
+ ):
+ self.scaling = scaling
+ self.reduction = reduction
+ self.zero_mean = zero_mean
+ self.clip_min = clip_min
+ self.weight = weight
+ super().__init__()
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ eps = 1e-8
+ # nb, nc, nt
+ if isinstance(x, AudioSignal):
+ references = x.audio_data
+ estimates = y.audio_data
+ else:
+ references = x
+ estimates = y
+
+ nb = references.shape[0]
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+ # samples now on axis 1
+ if self.zero_mean:
+ mean_reference = references.mean(dim=1, keepdim=True)
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
+ else:
+ mean_reference = 0
+ mean_estimate = 0
+
+ _references = references - mean_reference
+ _estimates = estimates - mean_estimate
+
+ references_projection = (_references**2).sum(dim=-2) + eps
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+ scale = (
+ (references_on_estimates / references_projection).unsqueeze(1)
+ if self.scaling
+ else 1
+ )
+
+ e_true = scale * _references
+ e_res = _estimates - e_true
+
+ signal = (e_true**2).sum(dim=1)
+ noise = (e_res**2).sum(dim=1)
+ sdr = -10 * torch.log10(signal / noise + eps)
+
+ if self.clip_min is not None:
+ sdr = torch.clamp(sdr, min=self.clip_min)
+
+ if self.reduction == "mean":
+ sdr = sdr.mean()
+ elif self.reduction == "sum":
+ sdr = sdr.sum()
+ return sdr
diff --git a/audiotools/metrics/quality.py b/audiotools/metrics/quality.py
new file mode 100644
index 0000000000000000000000000000000000000000..1608f25507082b49ccbf49289025a5a94a422808
--- /dev/null
+++ b/audiotools/metrics/quality.py
@@ -0,0 +1,159 @@
+import os
+
+import numpy as np
+import torch
+
+from .. import AudioSignal
+
+
+def stoi(
+ estimates: AudioSignal,
+ references: AudioSignal,
+ extended: int = False,
+):
+ """Short term objective intelligibility
+ Computes the STOI (See [1][2]) of a denoised signal compared to a clean
+ signal, The output is expected to have a monotonic relation with the
+ subjective speech-intelligibility, where a higher score denotes better
+ speech intelligibility. Uses pystoi under the hood.
+
+ Parameters
+ ----------
+ estimates : AudioSignal
+ Denoised speech
+ references : AudioSignal
+ Clean original speech
+ extended : int, optional
+ Boolean, whether to use the extended STOI described in [3], by default False
+
+ Returns
+ -------
+ Tensor[float]
+ Short time objective intelligibility measure between clean and
+ denoised speech
+
+ References
+ ----------
+ 1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
+ Objective Intelligibility Measure for Time-Frequency Weighted Noisy
+ Speech', ICASSP 2010, Texas, Dallas.
+ 2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
+ Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
+ IEEE Transactions on Audio, Speech, and Language Processing, 2011.
+ 3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
+ Intelligibility of Speech Masked by Modulated Noise Maskers',
+ IEEE Transactions on Audio, Speech and Language Processing, 2016.
+ """
+ import pystoi
+
+ estimates = estimates.clone().to_mono()
+ references = references.clone().to_mono()
+
+ stois = []
+ for i in range(estimates.batch_size):
+ _stoi = pystoi.stoi(
+ references.audio_data[i, 0].detach().cpu().numpy(),
+ estimates.audio_data[i, 0].detach().cpu().numpy(),
+ references.sample_rate,
+ extended=extended,
+ )
+ stois.append(_stoi)
+ return torch.from_numpy(np.array(stois))
+
+
+def pesq(
+ estimates: AudioSignal,
+ references: AudioSignal,
+ mode: str = "wb",
+ target_sr: float = 16000,
+):
+ """_summary_
+
+ Parameters
+ ----------
+ estimates : AudioSignal
+ Degraded AudioSignal
+ references : AudioSignal
+ Reference AudioSignal
+ mode : str, optional
+ 'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
+ target_sr : float, optional
+ Target sample rate, by default 16000
+
+ Returns
+ -------
+ Tensor[float]
+ PESQ score: P.862.2 Prediction (MOS-LQO)
+ """
+ from pesq import pesq as pesq_fn
+
+ estimates = estimates.clone().to_mono().resample(target_sr)
+ references = references.clone().to_mono().resample(target_sr)
+
+ pesqs = []
+ for i in range(estimates.batch_size):
+ _pesq = pesq_fn(
+ estimates.sample_rate,
+ references.audio_data[i, 0].detach().cpu().numpy(),
+ estimates.audio_data[i, 0].detach().cpu().numpy(),
+ mode,
+ )
+ pesqs.append(_pesq)
+ return torch.from_numpy(np.array(pesqs))
+
+
+def visqol(
+ estimates: AudioSignal,
+ references: AudioSignal,
+ mode: str = "audio",
+): # pragma: no cover
+ """ViSQOL score.
+
+ Parameters
+ ----------
+ estimates : AudioSignal
+ Degraded AudioSignal
+ references : AudioSignal
+ Reference AudioSignal
+ mode : str, optional
+ 'audio' or 'speech', by default 'audio'
+
+ Returns
+ -------
+ Tensor[float]
+ ViSQOL score (MOS-LQO)
+ """
+ from visqol import visqol_lib_py
+ from visqol.pb2 import visqol_config_pb2
+ from visqol.pb2 import similarity_result_pb2
+
+ config = visqol_config_pb2.VisqolConfig()
+ if mode == "audio":
+ target_sr = 48000
+ config.options.use_speech_scoring = False
+ svr_model_path = "libsvm_nu_svr_model.txt"
+ elif mode == "speech":
+ target_sr = 16000
+ config.options.use_speech_scoring = True
+ svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
+ else:
+ raise ValueError(f"Unrecognized mode: {mode}")
+ config.audio.sample_rate = target_sr
+ config.options.svr_model_path = os.path.join(
+ os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
+ )
+
+ api = visqol_lib_py.VisqolApi()
+ api.Create(config)
+
+ estimates = estimates.clone().to_mono().resample(target_sr)
+ references = references.clone().to_mono().resample(target_sr)
+
+ visqols = []
+ for i in range(estimates.batch_size):
+ _visqol = api.Measure(
+ references.audio_data[i, 0].detach().cpu().numpy().astype(float),
+ estimates.audio_data[i, 0].detach().cpu().numpy().astype(float),
+ )
+ visqols.append(_visqol.moslqo)
+ return torch.from_numpy(np.array(visqols))
diff --git a/audiotools/metrics/spectral.py b/audiotools/metrics/spectral.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce953882efa4e5b777a0348bee6c1be39279a6c
--- /dev/null
+++ b/audiotools/metrics/spectral.py
@@ -0,0 +1,247 @@
+import typing
+from typing import List
+
+import numpy as np
+from torch import nn
+
+from .. import AudioSignal
+from .. import STFTParams
+
+
+class MultiScaleSTFTLoss(nn.Module):
+ """Computes the multi-scale STFT loss from [1].
+
+ Parameters
+ ----------
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ How to compare each loss, by default nn.L1Loss()
+ clamp_eps : float, optional
+        Lower bound applied to the magnitude before taking the log, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ References
+ ----------
+
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
+ "DDSP: Differentiable Digital Signal Processing."
+ International Conference on Learning Representations. 2019.
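+
+    Example
+    -------
+    A minimal sketch (random signals, for illustration only):
+
+    >>> loss_fn = MultiScaleSTFTLoss(window_lengths=[2048, 512])
+    >>> x = AudioSignal(np.random.randn(44100), 44100)
+    >>> y = AudioSignal(np.random.randn(44100), 44100)
+    >>> loss = loss_fn(x, y)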
+ """
+
+ def __init__(
+ self,
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.loss_fn = loss_fn
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.clamp_eps = clamp_eps
+ self.weight = weight
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes multi-scale STFT between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Multi-scale STFT loss.
+ """
+ loss = 0.0
+ for s in self.stft_params:
+ x.stft(s.window_length, s.hop_length, s.window_type)
+ y.stft(s.window_length, s.hop_length, s.window_type)
+ loss += self.log_weight * self.loss_fn(
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+ return loss
+
+
+class MelSpectrogramLoss(nn.Module):
+ """Compute distance between mel spectrograms. Can be used
+ in a multi-scale way.
+
+ Parameters
+ ----------
+ n_mels : List[int]
+        Number of mels per STFT, by default [150, 80]
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ How to compare each loss, by default nn.L1Loss()
+ clamp_eps : float, optional
+        Lower bound applied to the magnitude before taking the log, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
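+
+    Example
+    -------
+    A minimal sketch; ``n_mels``, ``window_lengths``, ``mel_fmin`` and
+    ``mel_fmax`` are matched element-wise, one entry per scale:
+
+    >>> loss_fn = MelSpectrogramLoss(n_mels=[150, 80], window_lengths=[2048, 512])
+    >>> x = AudioSignal(np.random.randn(44100), 44100)
+    >>> y = AudioSignal(np.random.randn(44100), 44100)
+    >>> loss = loss_fn(x, y)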
+ """
+
+ def __init__(
+ self,
+ n_mels: List[int] = [150, 80],
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ mel_fmin: List[float] = [0.0, 0.0],
+ mel_fmax: List[float] = [None, None],
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.n_mels = n_mels
+ self.loss_fn = loss_fn
+ self.clamp_eps = clamp_eps
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.weight = weight
+ self.mel_fmin = mel_fmin
+ self.mel_fmax = mel_fmax
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes mel loss between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Mel loss.
+ """
+ loss = 0.0
+ for n_mels, fmin, fmax, s in zip(
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+ ):
+ kwargs = {
+ "window_length": s.window_length,
+ "hop_length": s.hop_length,
+ "window_type": s.window_type,
+ }
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+ loss += self.log_weight * self.loss_fn(
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+ return loss
+
+
+class PhaseLoss(nn.Module):
+ """Difference between phase spectrograms.
+
+ Parameters
+ ----------
+ window_length : int, optional
+ Length of STFT window, by default 2048
+ hop_length : int, optional
+ Hop length of STFT window, by default 512
+ weight : float, optional
+ Weight of loss, by default 1.0
+ """
+
+ def __init__(
+ self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
+ ):
+ super().__init__()
+
+ self.weight = weight
+ self.stft_params = STFTParams(window_length, hop_length)
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes phase loss between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Phase loss.
+ """
+ s = self.stft_params
+ x.stft(s.window_length, s.hop_length, s.window_type)
+ y.stft(s.window_length, s.hop_length, s.window_type)
+
+ # Take circular difference
+ diff = x.phase - y.phase
+ diff[diff < -np.pi] += 2 * np.pi
+        diff[diff > np.pi] -= 2 * np.pi
+
+ # Scale true magnitude to weights in [0, 1]
+ x_min, x_max = x.magnitude.min(), x.magnitude.max()
+ weights = (x.magnitude - x_min) / (x_max - x_min)
+
+ # Take weighted mean of all phase errors
+ loss = ((weights * diff) ** 2).mean()
+ return loss
diff --git a/audiotools/ml/__init__.py b/audiotools/ml/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ca69977bad57e1a92b7551d601d9224ee854ab
--- /dev/null
+++ b/audiotools/ml/__init__.py
@@ -0,0 +1,5 @@
+from . import decorators
+from . import layers
+from .accelerator import Accelerator
+from .experiment import Experiment
+from .layers import BaseModel
diff --git a/audiotools/ml/accelerator.py b/audiotools/ml/accelerator.py
new file mode 100644
index 0000000000000000000000000000000000000000..37c6e8d954f112b8b0aff257894e62add8874e30
--- /dev/null
+++ b/audiotools/ml/accelerator.py
@@ -0,0 +1,184 @@
+import os
+import typing
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DataParallel
+from torch.nn.parallel import DistributedDataParallel
+
+from ..data.datasets import ResumableDistributedSampler as DistributedSampler
+from ..data.datasets import ResumableSequentialSampler as SequentialSampler
+
+
+class Accelerator: # pragma: no cover
+ """This class is used to prepare models and dataloaders for
+ usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
+ prepare the respective objects. In the case of models, they are moved to
+ the appropriate GPU and SyncBatchNorm is applied to them. In the case of
+ dataloaders, a sampler is created and the dataloader is initialized with
+ that sampler.
+
+ If the world size is 1, prepare_model and prepare_dataloader are
+ no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the
+ script was launched without ``torchrun``, and ``DataParallel``
+ will be used instead of ``DistributedDataParallel`` (not recommended), if
+ the world size (number of GPUs) is greater than 1.
+
+ Parameters
+ ----------
+ amp : bool, optional
+ Whether or not to enable automatic mixed precision, by default False
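+
+    Example
+    -------
+    A rough sketch of the intended flow; ``model``, ``dataset``, and
+    ``optimizer`` are hypothetical:
+
+    >>> accel = Accelerator(amp=False)
+    >>> with accel:
+    >>>     model = accel.prepare_model(model)
+    >>>     loader = accel.prepare_dataloader(dataset, batch_size=16)
+    >>>     for batch in loader:
+    >>>         with accel.autocast():
+    >>>             loss = model(batch)
+    >>>         accel.backward(loss)
+    >>>         accel.step(optimizer)
+    >>>         accel.update()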
+ """
+
+ def __init__(self, amp: bool = False):
+ local_rank = os.getenv("LOCAL_RANK", None)
+ self.world_size = torch.cuda.device_count()
+
+ self.use_ddp = self.world_size > 1 and local_rank is not None
+ self.use_dp = self.world_size > 1 and local_rank is None
+ self.device = "cpu" if self.world_size == 0 else "cuda"
+
+ if self.use_ddp:
+ local_rank = int(local_rank)
+ dist.init_process_group(
+ "nccl",
+ init_method="env://",
+ world_size=self.world_size,
+ rank=local_rank,
+ )
+
+ self.local_rank = 0 if local_rank is None else local_rank
+ self.amp = amp
+
+ class DummyScaler:
+ def __init__(self):
+ pass
+
+ def step(self, optimizer):
+ optimizer.step()
+
+ def scale(self, loss):
+ return loss
+
+ def unscale_(self, optimizer):
+ return optimizer
+
+ def update(self):
+ pass
+
+ self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
+ self.device_ctx = (
+ torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
+ )
+
+ def __enter__(self):
+ if self.device_ctx is not None:
+ self.device_ctx.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if self.device_ctx is not None:
+ self.device_ctx.__exit__(exc_type, exc_value, traceback)
+
+ def prepare_model(self, model: torch.nn.Module, **kwargs):
+ """Prepares model for DDP or DP. The model is moved to
+ the device of the correct rank.
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ Model that is converted for DDP or DP.
+
+ Returns
+ -------
+ torch.nn.Module
+ Wrapped model, or original model if DDP and DP are turned off.
+ """
+ model = model.to(self.device)
+ if self.use_ddp:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = DistributedDataParallel(
+ model, device_ids=[self.local_rank], **kwargs
+ )
+ elif self.use_dp:
+ model = DataParallel(model, **kwargs)
+ return model
+
+ # Automatic mixed-precision utilities
+ def autocast(self, *args, **kwargs):
+ """Context manager for autocasting. Arguments
+ go to ``torch.cuda.amp.autocast``.
+ """
+ return torch.cuda.amp.autocast(self.amp, *args, **kwargs)
+
+ def backward(self, loss: torch.Tensor):
+ """Backwards pass, after scaling the loss if ``amp`` is
+ enabled.
+
+ Parameters
+ ----------
+ loss : torch.Tensor
+ Loss value.
+ """
+ self.scaler.scale(loss).backward()
+
+ def step(self, optimizer: torch.optim.Optimizer):
+ """Steps the optimizer, using a ``scaler`` if ``amp`` is
+ enabled.
+
+ Parameters
+ ----------
+ optimizer : torch.optim.Optimizer
+ Optimizer to step forward.
+ """
+ self.scaler.step(optimizer)
+
+ def update(self):
+ """Updates the scale factor."""
+ self.scaler.update()
+
+ def prepare_dataloader(
+ self, dataset: typing.Iterable, start_idx: int = None, **kwargs
+ ):
+ """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
+ enabled.
+
+ Parameters
+ ----------
+ dataset : typing.Iterable
+ Dataset to build Dataloader around.
+ start_idx : int, optional
+ Start index of sampler, useful if resuming from some epoch,
+ by default None
+
+ Returns
+ -------
+        torch.utils.data.DataLoader
+            DataLoader wrapped around the dataset, using the appropriate
+            sampler and per-process batch size.
+ """
+
+ if self.use_ddp:
+ sampler = DistributedSampler(
+ dataset,
+ start_idx,
+ num_replicas=self.world_size,
+ rank=self.local_rank,
+ )
+ if "num_workers" in kwargs:
+ kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
+ kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
+ else:
+ sampler = SequentialSampler(dataset, start_idx)
+
+ dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
+ return dataloader
+
+ @staticmethod
+ def unwrap(model):
+ """Unwraps the model if it was wrapped in DDP or DP, otherwise
+ just returns the model. Use this to unwrap the model returned by
+ :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
+ """
+ if hasattr(model, "module"):
+ return model.module
+ return model
diff --git a/audiotools/ml/decorators.py b/audiotools/ml/decorators.py
new file mode 100644
index 0000000000000000000000000000000000000000..834ec10270ff9e8e84a5fa99e13a686516a4af41
--- /dev/null
+++ b/audiotools/ml/decorators.py
@@ -0,0 +1,440 @@
+import math
+import os
+import time
+from collections import defaultdict
+from functools import wraps
+
+import torch
+import torch.distributed as dist
+from rich import box
+from rich.console import Console
+from rich.console import Group
+from rich.live import Live
+from rich.markdown import Markdown
+from rich.padding import Padding
+from rich.panel import Panel
+from rich.progress import BarColumn
+from rich.progress import Progress
+from rich.progress import SpinnerColumn
+from rich.progress import TimeElapsedColumn
+from rich.progress import TimeRemainingColumn
+from rich.rule import Rule
+from rich.table import Table
+from torch.utils.tensorboard import SummaryWriter
+
+
+# This is here so that the history can be pickled.
+def default_list():
+ return []
+
+
+class Mean:
+ """Keeps track of the running mean, along with the latest
+ value.
+ """
+
+ def __init__(self):
+ self.reset()
+
+ def __call__(self):
+ mean = self.total / max(self.count, 1)
+ return mean
+
+ def reset(self):
+ self.count = 0
+ self.total = 0
+
+ def update(self, val):
+ if math.isfinite(val):
+ self.count += 1
+ self.total += val
+
+
+def when(condition):
+ """Runs a function only when the condition is met. The condition is
+ a function that is run.
+
+ Parameters
+ ----------
+ condition : Callable
+ Function to run to check whether or not to run the decorated
+ function.
+
+ Example
+ -------
+ Checkpoint only runs every 100 iterations, and only if the
+ local rank is 0.
+
+ >>> i = 0
+ >>> rank = 0
+ >>>
+ >>> @when(lambda: i % 100 == 0 and rank == 0)
+ >>> def checkpoint():
+ >>> print("Saving to /runs/exp1")
+ >>>
+ >>> for i in range(1000):
+ >>> checkpoint()
+
+ """
+
+ def decorator(fn):
+ @wraps(fn)
+ def decorated(*args, **kwargs):
+ if condition():
+ return fn(*args, **kwargs)
+
+ return decorated
+
+ return decorator
+
+
+def timer(prefix: str = "time"):
+ """Adds execution time to the output dictionary of the decorated
+ function. The function decorated by this must output a dictionary.
+ The key added will follow the form "[prefix]/[name_of_function]"
+
+ Parameters
+ ----------
+ prefix : str, optional
+ The key added will follow the form "[prefix]/[name_of_function]",
+ by default "time".
+ """
+
+ def decorator(fn):
+ @wraps(fn)
+ def decorated(*args, **kwargs):
+ s = time.perf_counter()
+ output = fn(*args, **kwargs)
+ assert isinstance(output, dict)
+ e = time.perf_counter()
+ output[f"{prefix}/{fn.__name__}"] = e - s
+ return output
+
+ return decorated
+
+ return decorator
+
+
+class Tracker:
+ """
+    A tracker class that helps monitor training progress and log metrics.
+
+ Attributes
+ ----------
+ metrics : dict
+ A dictionary containing the metrics for each label.
+ history : dict
+ A dictionary containing the history of metrics for each label.
+ writer : SummaryWriter
+ A SummaryWriter object for logging the metrics.
+ rank : int
+ The rank of the current process.
+ step : int
+ The current step of the training.
+ tasks : dict
+ A dictionary containing the progress bars and tables for each label.
+ pbar : Progress
+ A progress bar object for displaying the progress.
+ consoles : list
+ A list of console objects for logging.
+ live : Live
+ A Live object for updating the display live.
+
+ Methods
+ -------
+ print(msg: str)
+ Prints the given message to all consoles.
+ update(label: str, fn_name: str)
+ Updates the progress bar and table for the given label.
+ done(label: str, title: str)
+ Resets the progress bar and table for the given label and prints the final result.
+ track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
+ A decorator for tracking the progress and metrics of a function.
+ log(label: str, value_type: str = "value", history: bool = True)
+ A decorator for logging the metrics of a function.
+ is_best(label: str, key: str) -> bool
+ Checks if the latest value of the given key in the label is the best so far.
+ state_dict() -> dict
+ Returns a dictionary containing the state of the tracker.
+ load_state_dict(state_dict: dict) -> Tracker
+ Loads the state of the tracker from the given state dictionary.
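+
+    Example
+    -------
+    A rough sketch of how the pieces fit together (the loop bounds and the
+    returned metrics are hypothetical):
+
+    >>> tracker = Tracker(rank=0)
+    >>>
+    >>> @tracker.log("train", "value")
+    >>> @tracker.track("train", length=100)
+    >>> def train_step():
+    >>>     return {"loss": 0.1}
+    >>>
+    >>> with tracker.live:
+    >>>     for _ in range(100):
+    >>>         train_step()
+    >>>         tracker.step += 1
+    >>>     tracker.done("train", "Epoch 1")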
+ """
+
+ def __init__(
+ self,
+ writer: SummaryWriter = None,
+ log_file: str = None,
+ rank: int = 0,
+ console_width: int = 100,
+ step: int = 0,
+ ):
+ """
+ Initializes the Tracker object.
+
+ Parameters
+ ----------
+ writer : SummaryWriter, optional
+ A SummaryWriter object for logging the metrics, by default None.
+ log_file : str, optional
+ The path to the log file, by default None.
+ rank : int, optional
+ The rank of the current process, by default 0.
+ console_width : int, optional
+ The width of the console, by default 100.
+ step : int, optional
+ The current step of the training, by default 0.
+ """
+ self.metrics = {}
+ self.history = {}
+ self.writer = writer
+ self.rank = rank
+ self.step = step
+
+ # Create progress bars etc.
+ self.tasks = {}
+ self.pbar = Progress(
+ SpinnerColumn(),
+ "[progress.description]{task.description}",
+ "{task.completed}/{task.total}",
+ BarColumn(),
+ TimeElapsedColumn(),
+ "/",
+ TimeRemainingColumn(),
+ )
+ self.consoles = [Console(width=console_width)]
+ self.live = Live(console=self.consoles[0], refresh_per_second=10)
+ if log_file is not None:
+ self.consoles.append(Console(width=console_width, file=open(log_file, "a")))
+
+ def print(self, msg):
+ """
+ Prints the given message to all consoles.
+
+ Parameters
+ ----------
+ msg : str
+ The message to be printed.
+ """
+ if self.rank == 0:
+ for c in self.consoles:
+ c.log(msg)
+
+ def update(self, label, fn_name):
+ """
+ Updates the progress bar and table for the given label.
+
+ Parameters
+ ----------
+ label : str
+ The label of the progress bar and table to be updated.
+ fn_name : str
+ The name of the function associated with the label.
+ """
+ if self.rank == 0:
+ self.pbar.advance(self.tasks[label]["pbar"])
+
+ # Create table
+ table = Table(title=label, expand=True, box=box.MINIMAL)
+ table.add_column("key", style="cyan")
+ table.add_column("value", style="bright_blue")
+ table.add_column("mean", style="bright_green")
+
+ keys = self.metrics[label]["value"].keys()
+ for k in keys:
+ value = self.metrics[label]["value"][k]
+ mean = self.metrics[label]["mean"][k]()
+ table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
+
+ self.tasks[label]["table"] = table
+ tables = [t["table"] for t in self.tasks.values()]
+ group = Group(*tables, self.pbar)
+ self.live.update(
+ Group(
+ Padding("", (0, 0)),
+ Rule(f"[italic]{fn_name}()", style="white"),
+ Padding("", (0, 0)),
+ Panel.fit(
+ group, padding=(0, 5), title="[b]Progress", border_style="blue"
+ ),
+ )
+ )
+
+ def done(self, label: str, title: str):
+ """
+ Resets the progress bar and table for the given label and prints the final result.
+
+ Parameters
+ ----------
+ label : str
+ The label of the progress bar and table to be reset.
+ title : str
+ The title to be displayed when printing the final result.
+ """
+        for lbl in self.metrics:
+            for v in self.metrics[lbl]["mean"].values():
+                v.reset()
+
+ if self.rank == 0:
+ self.pbar.reset(self.tasks[label]["pbar"])
+ tables = [t["table"] for t in self.tasks.values()]
+ group = Group(Markdown(f"# {title}"), *tables, self.pbar)
+ self.print(group)
+
+ def track(
+ self,
+ label: str,
+ length: int,
+ completed: int = 0,
+ op: dist.ReduceOp = dist.ReduceOp.AVG,
+ ddp_active: bool = "LOCAL_RANK" in os.environ,
+ ):
+ """
+ A decorator for tracking the progress and metrics of a function.
+
+ Parameters
+ ----------
+ label : str
+ The label to be associated with the progress and metrics.
+ length : int
+ The total number of iterations to be completed.
+ completed : int, optional
+ The number of iterations already completed, by default 0.
+ op : dist.ReduceOp, optional
+ The reduce operation to be used, by default dist.ReduceOp.AVG.
+ ddp_active : bool, optional
+ Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
+ """
+ self.tasks[label] = {
+ "pbar": self.pbar.add_task(
+ f"[white]Iteration ({label})", total=length, completed=completed
+ ),
+ "table": Table(),
+ }
+ self.metrics[label] = {
+ "value": defaultdict(),
+ "mean": defaultdict(lambda: Mean()),
+ }
+
+ def decorator(fn):
+ @wraps(fn)
+ def decorated(*args, **kwargs):
+ output = fn(*args, **kwargs)
+ if not isinstance(output, dict):
+ self.update(label, fn.__name__)
+ return output
+ # Collect across all DDP processes
+ scalar_keys = []
+ for k, v in output.items():
+ if isinstance(v, (int, float)):
+ v = torch.tensor([v])
+ if not torch.is_tensor(v):
+ continue
+ if ddp_active and v.is_cuda: # pragma: no cover
+ dist.all_reduce(v, op=op)
+ output[k] = v.detach()
+ if torch.numel(v) == 1:
+ scalar_keys.append(k)
+ output[k] = v.item()
+
+ # Save the outputs to tracker
+ for k, v in output.items():
+ if k not in scalar_keys:
+ continue
+ self.metrics[label]["value"][k] = v
+ # Update the running mean
+ self.metrics[label]["mean"][k].update(v)
+
+ self.update(label, fn.__name__)
+ return output
+
+ return decorated
+
+ return decorator
+
+ def log(self, label: str, value_type: str = "value", history: bool = True):
+ """
+ A decorator for logging the metrics of a function.
+
+ Parameters
+ ----------
+ label : str
+ The label to be associated with the logging.
+ value_type : str, optional
+ The type of value to be logged, by default "value".
+ history : bool, optional
+ Whether to save the history of the metrics, by default True.
+ """
+ assert value_type in ["mean", "value"]
+ if history:
+ if label not in self.history:
+ self.history[label] = defaultdict(default_list)
+
+ def decorator(fn):
+ @wraps(fn)
+ def decorated(*args, **kwargs):
+ output = fn(*args, **kwargs)
+ if self.rank == 0:
+ nonlocal value_type, label
+ metrics = self.metrics[label][value_type]
+ for k, v in metrics.items():
+ v = v() if isinstance(v, Mean) else v
+ if self.writer is not None:
+ self.writer.add_scalar(f"{k}/{label}", v, self.step)
+ if label in self.history:
+ self.history[label][k].append(v)
+
+ if label in self.history:
+ self.history[label]["step"].append(self.step)
+
+ return output
+
+ return decorated
+
+ return decorator
+
+ def is_best(self, label, key):
+ """
+ Checks if the latest value of the given key in the label is the best so far.
+
+ Parameters
+ ----------
+ label : str
+ The label of the metrics to be checked.
+ key : str
+ The key of the metric to be checked.
+
+ Returns
+ -------
+ bool
+ True if the latest value is the best so far, otherwise False.
+ """
+ return self.history[label][key][-1] == min(self.history[label][key])
+
+ def state_dict(self):
+ """
+ Returns a dictionary containing the state of the tracker.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the history and step of the tracker.
+ """
+ return {"history": self.history, "step": self.step}
+
+ def load_state_dict(self, state_dict):
+ """
+ Loads the state of the tracker from the given state dictionary.
+
+ Parameters
+ ----------
+ state_dict : dict
+ A dictionary containing the history and step of the tracker.
+
+ Returns
+ -------
+ Tracker
+ The tracker object with the loaded state.
+ """
+ self.history = state_dict["history"]
+ self.step = state_dict["step"]
+ return self
diff --git a/audiotools/ml/experiment.py b/audiotools/ml/experiment.py
new file mode 100644
index 0000000000000000000000000000000000000000..62833d0f8f80dcdf496a1a5d2785ef666e0a15b6
--- /dev/null
+++ b/audiotools/ml/experiment.py
@@ -0,0 +1,90 @@
+"""
+Useful class for Experiment tracking, and ensuring code is
+saved alongside files.
+""" # fmt: skip
+import datetime
+import os
+import shlex
+import shutil
+import subprocess
+import typing
+from pathlib import Path
+
+import randomname
+
+
+class Experiment:
+ """This class contains utilities for managing experiments.
+ It is a context manager, that when you enter it, changes
+ your directory to a specified experiment folder (which
+ optionally can have an automatically generated experiment
+ name, or a specified one), and changes the CUDA device used
+ to the specified device (or devices).
+
+ Parameters
+ ----------
+ exp_directory : str
+ Folder where all experiments are saved, by default "runs/".
+ exp_name : str, optional
+ Name of the experiment; by default a name is generated from the
+ current date and a random adjective-noun pair.
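+
+ Examples
+ --------
+ An illustrative sketch (the folder and name are placeholders):
+
+ >>> exp = Experiment("runs/", "my-experiment")
+ >>> with exp:
+ >>> exp.snapshot()
+ >>> # training code runs here; outputs land in runs/my-experiment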
+ """
+
+ def __init__(
+ self,
+ exp_directory: str = "runs/",
+ exp_name: str = None,
+ ):
+ if exp_name is None:
+ exp_name = self.generate_exp_name()
+ exp_dir = Path(exp_directory) / exp_name
+ exp_dir.mkdir(parents=True, exist_ok=True)
+
+ self.exp_dir = exp_dir
+ self.exp_name = exp_name
+ self.git_tracked_files = (
+ subprocess.check_output(
+ shlex.split("git ls-tree --full-tree --name-only -r HEAD")
+ )
+ .decode("utf-8")
+ .splitlines()
+ )
+ self.parent_directory = Path(".").absolute()
+
+ def __enter__(self):
+ self.prev_dir = os.getcwd()
+ os.chdir(self.exp_dir)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ os.chdir(self.prev_dir)
+
+ @staticmethod
+ def generate_exp_name():
+ """Generates a random experiment name based on the date
+ and a randomly generated adjective-noun tuple.
+
+ Returns
+ -------
+ str
+ Randomly generated experiment name.
+ """
+ date = datetime.datetime.now().strftime("%y%m%d")
+ name = f"{date}-{randomname.get_name()}"
+ return name
+
+ def snapshot(self, filter_fn: typing.Callable = lambda f: True):
+ """Captures a full snapshot of all the files tracked by git at the time
+ the experiment is run. It also captures the diff against the committed
+ code as a separate file.
+
+ Parameters
+ ----------
+ filter_fn : typing.Callable, optional
+ Function that can be used to exclude some files
+ from the snapshot, by default accepts all files
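+
+ Examples
+ --------
+ An illustrative call that skips test files (``exp`` is an
+ ``Experiment`` entered as a context manager):
+
+ >>> exp.snapshot(lambda f: not f.startswith("tests/"))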
+ """
+ for f in self.git_tracked_files:
+ if filter_fn(f):
+ Path(f).parent.mkdir(parents=True, exist_ok=True)
+ shutil.copyfile(self.parent_directory / f, f)
diff --git a/audiotools/ml/layers/__init__.py b/audiotools/ml/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a016cab2ddf06bf5dadfae241b7e5d9def4878
--- /dev/null
+++ b/audiotools/ml/layers/__init__.py
@@ -0,0 +1,2 @@
+from .base import BaseModel
+from .spectral_gate import SpectralGate
diff --git a/audiotools/ml/layers/base.py b/audiotools/ml/layers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b82c96cdd7336ca6b8ed6fc7f0192d69a8e998dd
--- /dev/null
+++ b/audiotools/ml/layers/base.py
@@ -0,0 +1,328 @@
+import inspect
+import shutil
+import tempfile
+import typing
+from pathlib import Path
+
+import torch
+from torch import nn
+
+
+class BaseModel(nn.Module):
+ """This is a class that adds useful save/load functionality to a
+ ``torch.nn.Module`` object. ``BaseModel`` objects can be saved
+ as ``torch.package`` easily, making them super easy to port between
+ machines without requiring a ton of dependencies. Files can also be
+ saved as just weights, in the standard way.
+
+ >>> class Model(ml.BaseModel):
+ >>> def __init__(self, arg1: float = 1.0):
+ >>> super().__init__()
+ >>> self.arg1 = arg1
+ >>> self.linear = nn.Linear(1, 1)
+ >>>
+ >>> def forward(self, x):
+ >>> return self.linear(x)
+ >>>
+ >>> model1 = Model()
+ >>> x = torch.randn(1, 1)
+ >>> out1 = model1(x)
+ >>>
+ >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
+ >>> model1.save(f.name)
+ >>> model2 = Model.load(f.name)
+ >>> out2 = model2(x)
+ >>> assert torch.allclose(out1, out2)
+ >>>
+ >>> model1.save(f.name, package=True)
+ >>> model2 = Model.load(f.name)
+ >>> model2.save(f.name, package=False)
+ >>> model3 = Model.load(f.name)
+ >>> out3 = model3(x)
+ >>> assert torch.allclose(out1, out3)
+ >>>
+ >>> with tempfile.TemporaryDirectory() as d:
+ >>> model1.save_to_folder(d, {"data": 1.0})
+ >>> Model.load_from_folder(d)
+
+ """
+
+ EXTERN = [
+ "audiotools.**",
+ "tqdm",
+ "__main__",
+ "numpy.**",
+ "julius.**",
+ "torchaudio.**",
+ "scipy.**",
+ "einops",
+ ]
+ """Names of libraries that are external to the torch.package saving mechanism.
+ Source code from these libraries will not be packaged into the model. This can
+ be edited by the user of this class by editing ``model.EXTERN``."""
+ INTERN = []
+ """Names of libraries that are internal to the torch.package saving mechanism.
+ Source code from these libraries will be saved alongside the model."""
+
+ def save(
+ self,
+ path: str,
+ metadata: dict = None,
+ package: bool = True,
+ intern: list = [],
+ extern: list = [],
+ mock: list = [],
+ ):
+ """Saves the model, either as a torch package, or just as
+ weights, alongside some specified metadata.
+
+ Parameters
+ ----------
+ path : str
+ Path to save model to.
+ metadata : dict, optional
+ Any metadata to save alongside the model,
+ by default None
+ package : bool, optional
+ Whether to use ``torch.package`` to save the model in
+ a format that is portable, by default True
+ intern : list, optional
+ List of additional libraries that are internal
+ to the model, used with torch.package, by default []
+ extern : list, optional
+ List of additional libraries that are external to
+ the model, used with torch.package, by default []
+ mock : list, optional
+ List of libraries to mock, used with torch.package,
+ by default []
+
+ Returns
+ -------
+ str
+ Path to saved model.
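+
+ Examples
+ --------
+ Illustrative calls (paths are placeholders):
+
+ >>> model.save("model_package.pth") # portable torch.package archive
+ >>> model.save("model_weights.pth", package=False) # weights + metadata only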
+ """
+ sig = inspect.signature(self.__class__)
+ args = {}
+
+ for key, val in sig.parameters.items():
+ arg_val = val.default
+ if arg_val is not inspect.Parameter.empty:
+ args[key] = arg_val
+
+ # Look up attributes in self, and if any of them are in args,
+ # overwrite them in args.
+ for attribute in dir(self):
+ if attribute in args:
+ args[attribute] = getattr(self, attribute)
+
+ metadata = {} if metadata is None else metadata
+ metadata["kwargs"] = args
+ if not hasattr(self, "metadata"):
+ self.metadata = {}
+ self.metadata.update(metadata)
+
+ if not package:
+ state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
+ torch.save(state_dict, path)
+ else:
+ self._save_package(path, intern=intern, extern=extern, mock=mock)
+
+ return path
+
+ @property
+ def device(self):
+ """Gets the device the model is on by looking at the device of
+ the first parameter. May not be valid if model is split across
+ multiple devices.
+ """
+ return next(self.parameters()).device
+
+ @classmethod
+ def load(
+ cls,
+ location: str,
+ *args,
+ package_name: str = None,
+ strict: bool = False,
+ **kwargs,
+ ):
+ """Load model from a path. Tries first to load as a package, and if
+ that fails, tries to load as weights. The arguments to the class are
+ specified inside the model weights file.
+
+ Parameters
+ ----------
+ location : str
+ Path to file.
+ package_name : str, optional
+ Name of package, by default ``cls.__name__``.
+ strict : bool, optional
+ Whether to strictly enforce that the keys in the saved state dict
+ match the model's keys (passed to ``load_state_dict``), by default False
+ kwargs : dict
+ Additional keyword arguments to the model instantiation, if
+ not loading from package.
+
+ Returns
+ -------
+ BaseModel
+ A model that inherits from BaseModel.
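+
+ Examples
+ --------
+ Illustrative calls (``Model`` subclasses ``BaseModel``; paths are
+ placeholders):
+
+ >>> model = Model.load("model_package.pth")
+ >>> model = Model.load("model_weights.pth", strict=True)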
+ """
+ try:
+ model = cls._load_package(location, package_name=package_name)
+ except Exception:
+ model_dict = torch.load(location, "cpu")
+ metadata = model_dict["metadata"]
+ metadata["kwargs"].update(kwargs)
+
+ sig = inspect.signature(cls)
+ class_keys = list(sig.parameters.keys())
+ for k in list(metadata["kwargs"].keys()):
+ if k not in class_keys:
+ metadata["kwargs"].pop(k)
+
+ model = cls(*args, **metadata["kwargs"])
+ model.load_state_dict(model_dict["state_dict"], strict=strict)
+ model.metadata = metadata
+
+ return model
+
+ def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
+ package_name = type(self).__name__
+ resource_name = f"{type(self).__name__}.pth"
+
+ # Below is for loading and re-saving a package.
+ if hasattr(self, "importer"):
+ kwargs["importer"] = (self.importer, torch.package.sys_importer)
+ del self.importer
+
+ # Why do we use a tempfile, you ask?
+ # It's so we can load a packaged model and then re-save
+ # it to the same location. torch.package throws an
+ # error if it's loading and writing to the same
+ # file (this is undocumented).
+ with tempfile.NamedTemporaryFile(suffix=".pth") as f:
+ with torch.package.PackageExporter(f.name, **kwargs) as exp:
+ exp.intern(self.INTERN + intern)
+ exp.mock(mock)
+ exp.extern(self.EXTERN + extern)
+ exp.save_pickle(package_name, resource_name, self)
+
+ if hasattr(self, "metadata"):
+ exp.save_pickle(
+ package_name, f"{package_name}.metadata", self.metadata
+ )
+
+ shutil.copyfile(f.name, path)
+
+ # Must reset the importer back to `self` if it existed
+ # so that you can save the model again!
+ if "importer" in kwargs:
+ self.importer = kwargs["importer"][0]
+ return path
+
+ @classmethod
+ def _load_package(cls, path, package_name=None):
+ package_name = cls.__name__ if package_name is None else package_name
+ resource_name = f"{package_name}.pth"
+
+ imp = torch.package.PackageImporter(path)
+ model = imp.load_pickle(package_name, resource_name, "cpu")
+ try:
+ model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
+ except Exception: # pragma: no cover
+ pass
+ model.importer = imp
+
+ return model
+
+ def save_to_folder(
+ self,
+ folder: typing.Union[str, Path],
+ extra_data: dict = None,
+ package: bool = True,
+ ):
+ """Dumps a model into a folder, as both a package
+ and as weights, as well as anything specified in
+ ``extra_data``. ``extra_data`` is a dictionary of other
+ pickleable files, with the keys being the paths
+ to save them in. The model is saved under a subfolder
+ specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
+ if the model name was ``Generator``).
+
+ >>> with tempfile.TemporaryDirectory() as d:
+ >>> extra_data = {
+ >>> "optimizer.pth": optimizer.state_dict()
+ >>> }
+ >>> model.save_to_folder(d, extra_data)
+ >>> Model.load_from_folder(d)
+
+ Parameters
+ ----------
+ folder : typing.Union[str, Path]
+ Folder to save the model (and any extra data) into.
+ extra_data : dict, optional
+ Dictionary mapping relative file names to picklable objects that are
+ saved alongside the model, by default None
+ package : bool, optional
+ Whether to also save the model as a ``torch.package`` archive
+ (``package.pth``), by default True
+
+ Returns
+ -------
+ str
+ Path to folder
+ """
+ extra_data = {} if extra_data is None else extra_data
+ model_name = type(self).__name__.lower()
+ target_base = Path(f"{folder}/{model_name}/")
+ target_base.mkdir(exist_ok=True, parents=True)
+
+ if package:
+ package_path = target_base / f"package.pth"
+ self.save(package_path)
+
+ weights_path = target_base / f"weights.pth"
+ self.save(weights_path, package=False)
+
+ for path, obj in extra_data.items():
+ torch.save(obj, target_base / path)
+
+ return target_base
+
+ @classmethod
+ def load_from_folder(
+ cls,
+ folder: typing.Union[str, Path],
+ package: bool = True,
+ strict: bool = False,
+ **kwargs,
+ ):
+ """Loads the model from a folder generated by
+ :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
+ Like that function, this one looks for a subfolder that has
+ the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
+ model name was ``Generator``).
+
+ Parameters
+ ----------
+ folder : typing.Union[str, Path]
+ Folder containing the output of ``save_to_folder``.
+ package : bool, optional
+ Whether to use ``torch.package`` to load the model,
+ loading the model from ``package.pth``, by default True
+ strict : bool, optional
+ Whether to strictly enforce that the keys in the saved state dict
+ match the model's keys, by default False
+ kwargs : dict
+ Additional keyword arguments passed to ``torch.load`` when loading
+ the extra data files (e.g. ``map_location``).
+
+ Returns
+ -------
+ tuple
+ tuple of model and extra data as saved by
+ :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
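+
+ Examples
+ --------
+ Illustrative, mirroring the ``save_to_folder`` example above
+ (``Model`` and ``optimizer`` are assumed to exist; the path is a
+ placeholder):
+
+ >>> model, extra_data = Model.load_from_folder("checkpoints/")
+ >>> optimizer.load_state_dict(extra_data["optimizer.pth"])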
+ """
+ folder = Path(folder) / cls.__name__.lower()
+ model_pth = "package.pth" if package else "weights.pth"
+ model_pth = folder / model_pth
+
+ model = cls.load(model_pth, strict=strict)
+ extra_data = {}
+ excluded = ["package.pth", "weights.pth"]
+ files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded]
+ for f in files:
+ extra_data[f.name] = torch.load(f, **kwargs)
+
+ return model, extra_data
diff --git a/audiotools/ml/layers/spectral_gate.py b/audiotools/ml/layers/spectral_gate.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4ae8b5eab2e56ce13541695f52a11a454759dae
--- /dev/null
+++ b/audiotools/ml/layers/spectral_gate.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...core import AudioSignal
+from ...core import STFTParams
+from ...core import util
+
+
+class SpectralGate(nn.Module):
+ """Spectral gating algorithm for noise reduction,
+ as in Audacity/Ocenaudio. The steps are as follows:
+
+ 1. An FFT is calculated over the noise audio clip
+ 2. Statistics are calculated over the FFT of the noise
+ (in frequency)
+ 3. A threshold is calculated based upon the statistics
+ of the noise (and the desired sensitivity of the algorithm)
+ 4. An FFT is calculated over the signal
+ 5. A mask is determined by comparing the signal FFT to the
+ threshold
+ 6. The mask is smoothed with a filter over frequency and time
+ 7. The mask is applied to the FFT of the signal, and the result
+ is inverted back to audio
+
+ Implementation inspired by Tim Sainburg's noisereduce:
+
+ https://timsainburg.com/noise-reduction-python.html
+
+ Parameters
+ ----------
+ n_freq : int, optional
+ Number of frequency bins to smooth by, by default 3
+ n_time : int, optional
+ Number of time bins to smooth by, by default 5
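+
+ Examples
+ --------
+ A minimal sketch (file paths are placeholders):
+
+ >>> gate = SpectralGate(n_freq=3, n_time=5)
+ >>> signal = AudioSignal("noisy_speech.wav")
+ >>> noise = AudioSignal("noise_only.wav")
+ >>> denoised = gate(signal, noise, denoise_amount=0.8)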
+ """
+
+ def __init__(self, n_freq: int = 3, n_time: int = 5):
+ super().__init__()
+
+ smoothing_filter = torch.outer(
+ torch.cat(
+ [
+ torch.linspace(0, 1, n_freq + 2)[:-1],
+ torch.linspace(1, 0, n_freq + 2),
+ ]
+ )[..., 1:-1],
+ torch.cat(
+ [
+ torch.linspace(0, 1, n_time + 2)[:-1],
+ torch.linspace(1, 0, n_time + 2),
+ ]
+ )[..., 1:-1],
+ )
+ smoothing_filter = smoothing_filter / smoothing_filter.sum()
+ smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
+ self.register_buffer("smoothing_filter", smoothing_filter)
+
+ def forward(
+ self,
+ audio_signal: AudioSignal,
+ nz_signal: AudioSignal,
+ denoise_amount: float = 1.0,
+ n_std: float = 3.0,
+ win_length: int = 2048,
+ hop_length: int = 512,
+ ):
+ """Perform noise reduction.
+
+ Parameters
+ ----------
+ audio_signal : AudioSignal
+ Audio signal that noise will be removed from.
+ nz_signal : AudioSignal
+ Noise signal to compute noise statistics from.
+ denoise_amount : float, optional
+ Amount to denoise by, by default 1.0
+ n_std : float, optional
+ Number of standard deviations above which to consider
+ noise, by default 3.0
+ win_length : int, optional
+ Length of window for STFT, by default 2048
+ hop_length : int, optional
+ Hop length for STFT, by default 512
+
+ Returns
+ -------
+ AudioSignal
+ Denoised audio signal.
+ """
+ stft_params = STFTParams(win_length, hop_length, "sqrt_hann")
+
+ audio_signal = audio_signal.clone()
+ audio_signal.stft_data = None
+ audio_signal.stft_params = stft_params
+
+ nz_signal = nz_signal.clone()
+ nz_signal.stft_params = stft_params
+
+ nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
+ nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
+ nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)
+
+ nz_thresh = nz_freq_mean + nz_freq_std * n_std
+
+ stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
+ nb, nac, nf, nt = stft_db.shape
+ db_thresh = nz_thresh.expand(nb, nac, -1, nt)
+
+ stft_mask = (stft_db < db_thresh).float()
+ shape = stft_mask.shape
+
+ stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
+ pad_tuple = (
+ self.smoothing_filter.shape[-2] // 2,
+ self.smoothing_filter.shape[-1] // 2,
+ )
+ stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
+ stft_mask = stft_mask.reshape(*shape)
+ stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
+ audio_signal.device
+ )
+ stft_mask = 1 - stft_mask
+
+ audio_signal.stft_data *= stft_mask
+ audio_signal.istft()
+
+ return audio_signal
diff --git a/audiotools/post.py b/audiotools/post.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ced2d1e66a4ffda3269685bd45593b01038739f
--- /dev/null
+++ b/audiotools/post.py
@@ -0,0 +1,140 @@
+import tempfile
+import typing
+import zipfile
+from pathlib import Path
+
+import markdown2 as md
+import matplotlib.pyplot as plt
+import torch
+from IPython.display import HTML
+
+
+def audio_table(
+ audio_dict: dict,
+ first_column: str = None,
+ format_fn: typing.Callable = None,
+ **kwargs,
+): # pragma: no cover
+ """Embeds an audio table into HTML, or as the output cell
+ in a notebook.
+
+ Parameters
+ ----------
+ audio_dict : dict
+ Dictionary of data to embed.
+ first_column : str, optional
+ The label for the first column of the table, by default None
+ format_fn : typing.Callable, optional
+ How to format the data, by default None
+
+ Returns
+ -------
+ str
+ Table as a string
+
+ Examples
+ --------
+
+ >>> audio_dict = {}
+ >>> for i in range(signal_batch.batch_size):
+ >>> audio_dict[i] = {
+ >>> "input": signal_batch[i],
+ >>> "output": output_batch[i]
+ >>> }
+ >>> audiotools.post.audio_table(audio_dict)
+
+ """
+ from audiotools import AudioSignal
+
+ output = []
+ columns = None
+
+ def _default_format_fn(label, x, **kwargs):
+ if torch.is_tensor(x):
+ x = x.tolist()
+
+ if x is None:
+ return "."
+ elif isinstance(x, AudioSignal):
+ return x.embed(display=False, return_html=True, **kwargs)
+ else:
+ return str(x)
+
+ if format_fn is None:
+ format_fn = _default_format_fn
+
+ if first_column is None:
+ first_column = "."
+
+ for k, v in audio_dict.items():
+ if not isinstance(v, dict):
+ v = {"Audio": v}
+
+ v_keys = list(v.keys())
+ if columns is None:
+ columns = [first_column] + v_keys
+ output.append(" | ".join(columns))
+
+ layout = "|---" + len(v_keys) * "|:-:"
+ output.append(layout)
+
+ formatted_audio = []
+ for col in columns[1:]:
+ formatted_audio.append(format_fn(col, v[col], **kwargs))
+
+ row = f"| {k} | "
+ row += " | ".join(formatted_audio)
+ output.append(row)
+
+ output = "\n" + "\n".join(output)
+ return output
+
+
+def in_notebook(): # pragma: no cover
+ """Determines if code is running in a notebook.
+
+ Returns
+ -------
+ bool
+ Whether or not this is running in a notebook.
+ """
+ try:
+ from IPython import get_ipython
+
+ if "IPKernelApp" not in get_ipython().config: # pragma: no cover
+ return False
+ except ImportError:
+ return False
+ except AttributeError:
+ return False
+ return True
+
+
+def disp(obj, **kwargs): # pragma: no cover
+ """Displays an object, depending on if its in a notebook
+ or not.
+
+ Parameters
+ ----------
+ obj : typing.Any
+ Any object to display.
+
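+ Examples
+ --------
+ Illustrative calls (the path is a placeholder):
+
+ >>> disp(AudioSignal("audio.wav")) # embeds an audio player
+ >>> disp({"row": {"clean": AudioSignal("audio.wav")}}) # renders an audio table
+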
+ """
+ from audiotools import AudioSignal
+
+ IN_NOTEBOOK = in_notebook()
+
+ if isinstance(obj, AudioSignal):
+ audio_elem = obj.embed(display=False, return_html=True)
+ if IN_NOTEBOOK:
+ return HTML(audio_elem)
+ else:
+ print(audio_elem)
+ if isinstance(obj, dict):
+ table = audio_table(obj, **kwargs)
+ if IN_NOTEBOOK:
+ return HTML(md.markdown(table, extras=["tables"]))
+ else:
+ print(table)
+ if isinstance(obj, plt.Figure):
+ plt.show()
diff --git a/audiotools/preference.py b/audiotools/preference.py
new file mode 100644
index 0000000000000000000000000000000000000000..800a852e8119dd18ea65784cf95182de2470fbc4
--- /dev/null
+++ b/audiotools/preference.py
@@ -0,0 +1,600 @@
+##############################################################
+### Tools for creating preference tests (MUSHRA, ABX, etc) ###
+##############################################################
+import copy
+import csv
+import random
+import sys
+import traceback
+from collections import defaultdict
+from pathlib import Path
+from typing import List
+
+import gradio as gr
+
+from audiotools.core.util import find_audio
+
+################################################################
+### Logic for audio player, and adding audio / play buttons. ###
+################################################################
+
+WAVESURFER = """"""
+
+CUSTOM_CSS = """
+.gradio-container {
+ max-width: 840px !important;
+}
+region.wavesurfer-region:before {
+ content: attr(data-region-label);
+}
+
+block {
+ min-width: 0 !important;
+}
+
+#wave-timeline {
+ background-color: rgba(0, 0, 0, 0.8);
+}
+
+.head.svelte-1cl284s {
+ display: none;
+}
+"""
+
+load_wavesurfer_js = """
+function load_wavesurfer() {
+ function load_script(url) {
+ const script = document.createElement('script');
+ script.src = url;
+ document.body.appendChild(script);
+
+ return new Promise((res, rej) => {
+ script.onload = function() {
+ res();
+ }
+ script.onerror = function () {
+ rej();
+ }
+ });
+ }
+
+ function create_wavesurfer() {
+ var options = {
+ container: '#waveform',
+ waveColor: '#F2F2F2', // Light gray wave color
+ progressColor: 'white', // White progress color
+ loaderColor: 'white', // White loader color
+ cursorColor: 'black', // Black cursor color
+ backgroundColor: '#00AAFF', // Blue background color
+ barWidth: 4,
+ barRadius: 3,
+ barHeight: 1, // the height of the wave
+ plugins: [
+ WaveSurfer.regions.create({
+ regionsMinLength: 0.0,
+ dragSelection: {
+ slop: 5
+ },
+ color: 'hsla(200, 50%, 70%, 0.4)',
+ }),
+ WaveSurfer.timeline.create({
+ container: "#wave-timeline",
+ primaryLabelInterval: 5.0,
+ secondaryLabelInterval: 1.0,
+ primaryFontColor: '#F2F2F2',
+ secondaryFontColor: '#F2F2F2',
+ }),
+ ]
+ };
+ wavesurfer = WaveSurfer.create(options);
+ wavesurfer.on('region-created', region => {
+ wavesurfer.regions.clear();
+ });
+ wavesurfer.on('finish', function () {
+ var loop = document.getElementById("loop-button").textContent.includes("ON");
+ if (loop) {
+ wavesurfer.play();
+ }
+ else {
+ var button_elements = document.getElementsByClassName('playpause')
+ var buttons = Array.from(button_elements);
+
+ for (let j = 0; j < buttons.length; j++) {
+ buttons[j].classList.remove("primary");
+ buttons[j].classList.add("secondary");
+ buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
+ }
+ }
+ });
+
+ wavesurfer.on('region-out', function () {
+ var loop = document.getElementById("loop-button").textContent.includes("ON");
+ if (!loop) {
+ var button_elements = document.getElementsByClassName('playpause')
+ var buttons = Array.from(button_elements);
+
+ for (let j = 0; j < buttons.length; j++) {
+ buttons[j].classList.remove("primary");
+ buttons[j].classList.add("secondary");
+ buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
+ }
+ wavesurfer.pause();
+ }
+ });
+
+ console.log("Created WaveSurfer object.")
+ }
+
+ load_script('https://unpkg.com/wavesurfer.js@6.6.4')
+ .then(() => {
+ load_script("https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.timeline.min.js")
+ .then(() => {
+ load_script('https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.regions.min.js')
+ .then(() => {
+ console.log("Loaded regions");
+ create_wavesurfer();
+ document.getElementById("start-survey").click();
+ })
+ })
+ });
+}
+"""
+
+play = lambda i: """
+function play() {
+ var audio_elements = document.getElementsByTagName('audio');
+ var button_elements = document.getElementsByClassName('playpause')
+
+ var audio_array = Array.from(audio_elements);
+ var buttons = Array.from(button_elements);
+
+ var src_link = audio_array[{i}].getAttribute("src");
+ console.log(src_link);
+
+ var loop = document.getElementById("loop-button").textContent.includes("ON");
+ var playing = buttons[{i}].textContent.includes("Stop");
+
+ for (let j = 0; j < buttons.length; j++) {
+ if (j != {i} || playing) {
+ buttons[j].classList.remove("primary");
+ buttons[j].classList.add("secondary");
+ buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
+ }
+ else {
+ buttons[j].classList.remove("secondary");
+ buttons[j].classList.add("primary");
+ buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop")
+ }
+ }
+
+ if (playing) {
+ wavesurfer.pause();
+ wavesurfer.seekTo(0.0);
+ }
+ else {
+ wavesurfer.load(src_link);
+ wavesurfer.on('ready', function () {
+ var region = Object.values(wavesurfer.regions.list)[0];
+
+ if (region != null) {
+ region.loop = loop;
+ region.play();
+ } else {
+ wavesurfer.play();
+ }
+ });
+ }
+}
+""".replace(
+ "{i}", str(i)
+)
+
+clear_regions = """
+function clear_regions() {
+ wavesurfer.clearRegions();
+}
+"""
+
+reset_player = """
+function reset_player() {
+ wavesurfer.clearRegions();
+ wavesurfer.pause();
+ wavesurfer.seekTo(0.0);
+
+ var button_elements = document.getElementsByClassName('playpause')
+ var buttons = Array.from(button_elements);
+
+ for (let j = 0; j < buttons.length; j++) {
+ buttons[j].classList.remove("primary");
+ buttons[j].classList.add("secondary");
+ buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
+ }
+}
+"""
+
+loop_region = """
+function loop_region() {
+ var element = document.getElementById("loop-button");
+ var loop = element.textContent.includes("OFF");
+ console.log(loop);
+
+ try {
+ var region = Object.values(wavesurfer.regions.list)[0];
+ region.loop = loop;
+ } catch {}
+
+ if (loop) {
+ element.classList.remove("secondary");
+ element.classList.add("primary");
+ element.textContent = "Looping ON";
+ } else {
+ element.classList.remove("primary");
+ element.classList.add("secondary");
+ element.textContent = "Looping OFF";
+ }
+}
+"""
+
+
+class Player:
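+ """Wraps a Gradio app with a single shared wavesurfer.js player.
+ ``create()`` renders the waveform, region, and loop controls;
+ ``add()`` registers a hidden ``gr.Audio`` plus a play/stop button wired
+ to the player; and ``to_list()`` returns the audio components in the
+ order they were added."""
+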
+ def __init__(self, app):
+ self.app = app
+
+ self.app.load(_js=load_wavesurfer_js)
+ self.app.css = CUSTOM_CSS
+
+ self.wavs = []
+ self.position = 0
+
+ def create(self):
+ gr.HTML(WAVESURFER)
+ gr.Markdown(
+ "Click and drag on the waveform above to select a region for playback. "
+ "Once created, the region can be moved around and resized. "
+ "Clear the regions using the button below. Hit play on one of the buttons below to start!"
+ )
+
+ with gr.Row():
+ clear = gr.Button("Clear region")
+ loop = gr.Button("Looping OFF", elem_id="loop-button")
+
+ loop.click(None, _js=loop_region)
+ clear.click(None, _js=clear_regions)
+
+ gr.HTML("")
+
+ def add(self, name: str = "Play"):
+ i = self.position
+ self.wavs.append(
+ {
+ "audio": gr.Audio(visible=False),
+ "button": gr.Button(name, elem_classes=["playpause"]),
+ "position": i,
+ }
+ )
+ self.wavs[-1]["button"].click(None, _js=play(i))
+ self.position += 1
+ return self.wavs[-1]
+
+ def to_list(self):
+ return [x["audio"] for x in self.wavs]
+
+
+############################################################
+### Keeping track of users, and CSS for the progress bar ###
+############################################################
+
+load_tracker = lambda name: """
+function load_name() {
+ function setCookie(name, value, exp_days) {
+ var d = new Date();
+ d.setTime(d.getTime() + (exp_days*24*60*60*1000));
+ var expires = "expires=" + d.toGMTString();
+ document.cookie = name + "=" + value + ";" + expires + ";path=/";
+ }
+
+ function getCookie(name) {
+ var cname = name + "=";
+ var decodedCookie = decodeURIComponent(document.cookie);
+ var ca = decodedCookie.split(';');
+ for(var i = 0; i < ca.length; i++){
+ var c = ca[i];
+ while(c.charAt(0) == ' '){
+ c = c.substring(1);
+ }
+ if(c.indexOf(cname) == 0){
+ return c.substring(cname.length, c.length);
+ }
+ }
+ return "";
+ }
+
+ name = getCookie("{name}");
+ if (name == "") {
+ name = Math.random().toString(36).slice(2);
+ console.log(name);
+ setCookie("name", name, 30);
+ }
+ name = getCookie("{name}");
+ return name;
+}
+""".replace(
+ "{name}", name
+)
+
+# Progress bar
+
+progress_template = """
+
+
+
+ Progress Bar
+
+
+
+
+
+
{TEXT}
+
+
+
+"""
+
+
+def create_tracker(app, cookie_name="name"):
+ user = gr.Text(label="user", interactive=True, visible=False, elem_id="user")
+ app.load(_js=load_tracker(cookie_name), outputs=user)
+ return user
+
+
+#################################################################
+### CSS and HTML for labeling sliders for both ABX and MUSHRA ###
+#################################################################
+
+slider_abx = """
+
+
+
+
+ Labels Example
+
+
+
+