import torch
import typing as tp
from itertools import chain
from pathlib import Path
from torch import nn

from .conditioners import (ConditioningAttributes, BaseConditioner, ConditionType,
                           ConditioningProvider, JascoCondConst,
                           WaveformConditioner, WavCondition, SymbolicCondition)
from ..data.audio import audio_read
from ..data.audio_utils import convert_audio
from ..utils.autocast import TorchAutocast
from ..utils.cache import EmbeddingCache


class MelodyConditioner(BaseConditioner):
    """
    A conditioner that handles melody conditioning from a pre-computed salience matrix.

    Attributes:
        card (int): The cardinality of the melody matrix.
        out_dim (int): The dimensionality of the output projection.
        device (Union[torch.device, str]): The device on which the embeddings are stored.
    """
    def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs):
        super().__init__(dim=card, output_dim=out_dim)
        self.device = device

    def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
        return SymbolicCondition(melody=x.melody.to(self.device))  # type: ignore

    def forward(self, x: SymbolicCondition) -> ConditionType:
        # (B, card, T) -> (B, T, card), so the projection maps card -> out_dim per frame
        embeds = self.output_proj(x.melody.permute(0, 2, 1))  # type: ignore
        mask = torch.ones_like(embeds[..., 0])
        return embeds, mask
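
# Illustrative shapes for MelodyConditioner (a minimal sketch; the batch size and
# the `card`/`out_dim` values below are hypothetical, not defaults of this module):
#
#   cond = MelodyConditioner(card=53, out_dim=512)
#   sym = cond.tokenize(SymbolicCondition(melody=torch.rand(2, 53, 500)))  # (B, card, T) salience
#   embeds, mask = cond(sym)
#   # embeds: (B, T, out_dim) == (2, 500, 512); mask: (B, T) of ones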


class ChordsEmbConditioner(BaseConditioner):
    """
    A conditioner that embeds chord symbols into a continuous vector space.

    Attributes:
        card (int): The cardinality of the chord vocabulary.
        out_dim (int): The dimensionality of the output embeddings.
        device (Union[torch.device, str]): The device on which the embeddings are stored.
    """
    def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs):
        vocab_size = card + 1  # card + 1 - to reserve an index for the null chord used during dropout
        super().__init__(dim=vocab_size, output_dim=-1)  # output_dim=-1 to avoid another projection
        self.emb = nn.Embedding(vocab_size, out_dim, device=device)
        self.device = device

    def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
        return SymbolicCondition(frame_chords=x.frame_chords.to(self.device))  # type: ignore

    def forward(self, x: SymbolicCondition) -> ConditionType:
        embeds = self.emb(x.frame_chords)  # (B, T) int indices -> (B, T, out_dim)
        mask = torch.ones_like(embeds[..., 0])
        return embeds, mask
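
# Illustrative usage for ChordsEmbConditioner (a minimal sketch; the vocabulary
# size and shapes below are hypothetical):
#
#   cond = ChordsEmbConditioner(card=194, out_dim=512)
#   chords = torch.randint(0, 195, (2, 500))  # (B, T) chord indices; index 194 is the null chord
#   embeds, mask = cond(cond.tokenize(SymbolicCondition(frame_chords=chords)))
#   # embeds: (2, 500, 512)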


class DrumsConditioner(WaveformConditioner):
    def __init__(self, out_dim: int, sample_rate: int, blurring_factor: int = 3,
                 cache_path: tp.Optional[tp.Union[str, Path]] = None,
                 compression_model_latent_dim: int = 128,
                 compression_model_framerate: float = 50,
                 segment_duration: float = 10.0,
                 device: tp.Union[torch.device, str] = 'cpu',
                 **kwargs):
        """Drums conditioner.

        Args:
            out_dim (int): Dimensionality of the output projection.
            sample_rate (int): Sample rate of the conditioning audio.
            blurring_factor (int, optional): Number of adjacent frames averaged together
                by temporal blurring. Defaults to 3.
            cache_path (tp.Optional[tp.Union[str, Path]], optional): Path to precomputed cache. Defaults to None.
            compression_model_latent_dim (int, optional): Latent dimension. Defaults to 128.
            compression_model_framerate (float, optional): Frame rate of the representation model. Defaults to 50.
            segment_duration (float, optional): Duration in seconds of each audio segment. Defaults to 10.0.
            device (tp.Union[torch.device, str], optional): Device. Defaults to 'cpu'.
        """
        from demucs import pretrained
        self.sample_rate = sample_rate
        # assigned through __dict__ so demucs is not registered as a submodule of this conditioner
        self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
        stem_sources: list = self.demucs.sources  # type: ignore
        self.stem_idx = stem_sources.index('drums')
        self.compression_model = None
        self.latent_dim = compression_model_latent_dim
        super().__init__(dim=self.latent_dim, output_dim=out_dim, device=device)
        self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
        self._use_masking = False
        self.blurring_factor = blurring_factor
        self.seq_len = int(segment_duration * compression_model_framerate)
        self.cache = None  # to train with an EmbeddingCache, call self.create_embedding_cache(cache_path)

    def create_embedding_cache(self, cache_path):
        if cache_path is not None:
            self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
                                        compute_embed_fn=self._calc_coarse_drum_codes_for_cache,
                                        extract_embed_fn=self._load_drum_codes_chunk)
    def _get_drums_stem(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get the part of the wav that holds the drums by extracting the main stems from the wav."""
        from demucs.apply import apply_model
        from demucs.audio import convert_audio
        with self.autocast:
            wav = convert_audio(
                wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels)  # type: ignore
            stems = apply_model(self.demucs, wav, device=self.device)
            drum_stem = stems[:, self.stem_idx]  # extract the relevant stem for drums conditioning
            return convert_audio(drum_stem, self.demucs.samplerate, self.sample_rate, 1)  # type: ignore
    def _temporal_blur(self, z: torch.Tensor):
        # z: (B, T, C)
        B, T, C = z.shape
        if T % self.blurring_factor != 0:
            # reflect-pad on the right in dim=1 so that T is divisible by self.blurring_factor
            pad_val = self.blurring_factor - T % self.blurring_factor
            z = torch.nn.functional.pad(z, (0, 0, 0, pad_val), mode='reflect')
        # average each group of blurring_factor frames, then repeat each mean over its group
        z = z.reshape(B, -1, self.blurring_factor, C).sum(dim=2) / self.blurring_factor
        z = z.unsqueeze(2).repeat(1, 1, self.blurring_factor, 1).reshape(B, -1, C)
        z = z[:, :T]
        assert z.shape == (B, T, C)
        return z
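
    # Worked example of the blurring (hypothetical values): with blurring_factor=3
    # and a single channel z = [1, 2, 3, 4, 5] (T=5), reflect padding gives
    # [1, 2, 3, 4, 5, 4]; the per-group means are [2.0, 13/3]; repeating each mean
    # over its group and cropping back to T yields [2.0, 2.0, 2.0, 4.33, 4.33].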
    def _extract_coarse_drum_codes(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        assert self.compression_model is not None
        # stem separation of drums
        drums = self._get_drums_stem(wav, sample_rate)
        # continuous encoding with the compression model
        latents = self.compression_model.model.encoder(drums)
        # quantization to the coarsest codebook
        coarsest_quantizer = self.compression_model.model.quantizer.layers[0]
        drums = coarsest_quantizer.encode(latents).to(torch.int16)
        return drums
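
    # Shape sketch (illustrative, assuming an EnCodec-style compression model at
    # 50 Hz): wav (B, 1, S) -> drums stem (B, 1, S) -> encoder latents
    # (B, latent_dim, T) -> coarsest-codebook indices (B, T) stored as int16.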
    def _calc_coarse_drum_codes_for_cache(self, path: tp.Union[str, Path],
                                          x: WavCondition, idx: int,
                                          max_duration_to_process: float = 600) -> torch.Tensor:
        """Extract coarse drum codes from the whole audio waveform at the given path."""
        wav, sr = audio_read(path)
        wav = wav[None].to(self.device)
        wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
        max_frames_to_process = int(max_duration_to_process * self.sample_rate)
        if wav.shape[-1] > max_frames_to_process:
            # process very long tracks in chunks
            start = 0
            codes = []
            while start < wav.shape[-1] - 1:
                wav_chunk = wav[..., start: start + max_frames_to_process]
                codes.append(self._extract_coarse_drum_codes(wav_chunk, self.sample_rate)[0])
                start += max_frames_to_process
            return torch.cat(codes)
        return self._extract_coarse_drum_codes(wav, self.sample_rate)[0]
    def _load_drum_codes_chunk(self, full_coarse_drum_codes: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
        """Extract a chunk of coarse drum codes from the full sequence derived from the complete waveform."""
        wav_length = x.wav.shape[-1]
        seek_time = x.seek_time[idx]
        assert seek_time is not None, (
            "WavCondition seek_time is required "
            "when extracting chunks from pre-computed drum codes.")
        assert self.compression_model is not None
        frame_rate = self.compression_model.frame_rate
        target_length = int(frame_rate * wav_length / self.sample_rate)
        target_length = max(target_length, self.seq_len)
        index = int(frame_rate * seek_time)
        out = full_coarse_drum_codes[index: index + target_length]
        # zero-pad on the right up to target_length
        out = torch.cat((out, torch.zeros(target_length - out.shape[0], dtype=out.dtype, device=out.device)))
        return out.to(self.device)
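
    # Example with hypothetical numbers: with frame_rate=50 and sample_rate=48000,
    # a 10 s chunk starting at seek_time=12.0 slices frames [600, 1100); chunks
    # that run past the end of the pre-computed codes are zero-padded on the right.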
    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
        bs = x.wav.shape[0]
        if x.wav.shape[-1] <= 1:
            # null condition
            return torch.zeros((bs, self.seq_len, self.latent_dim), device=x.wav.device, dtype=x.wav.dtype)

        # extract coarse drum codes
        no_undefined_paths = all(p is not None for p in x.path)
        no_nullified_cond = x.wav.shape[-1] > 1
        if self.cache is not None and no_undefined_paths and no_nullified_cond:
            paths = [Path(p) for p in x.path if p is not None]
            codes = self.cache.get_embed_from_cache(paths, x)
        else:
            assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
            codes = self._extract_coarse_drum_codes(x.wav, x.sample_rate[0])

        assert self.compression_model is not None
        # decode back to the continuous representation of the compression model
        codes = codes.unsqueeze(1).permute(1, 0, 2)  # (B, T) -> (1, B, T), single-codebook layout
        codes = codes.to(torch.int64)
        latents = self.compression_model.model.quantizer.decode(codes)
        latents = latents.permute(0, 2, 1)  # (B, C, T) -> (B, T, C)

        # temporal blurring
        return self._temporal_blur(latents)
    def tokenize(self, x: WavCondition) -> WavCondition:
        """Apply WaveformConditioner tokenization and populate the cache if needed."""
        x = super().tokenize(x)
        no_undefined_paths = all(p is not None for p in x.path)
        if self.cache is not None and no_undefined_paths:
            paths = [Path(p) for p in x.path if p is not None]
            self.cache.populate_embed_cache(paths, x)
        return x
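
# Illustrative flow for DrumsConditioner (a minimal sketch; `my_compression_model`
# is a hypothetical stand-in for the compression model the surrounding code is
# expected to attach before embeddings are computed, and all shapes are examples):
#
#   cond = DrumsConditioner(out_dim=512, sample_rate=48000, device='cuda')
#   cond.compression_model = my_compression_model
#   wav_cond = cond.tokenize(WavCondition(wav=torch.randn(2, 1, 480000),
#                                         length=torch.tensor([480000, 480000]),
#                                         sample_rate=[48000, 48000],
#                                         path=[None, None], seek_time=[0.0, 0.0]))
#   embeds, mask = cond(wav_cond)  # embeds: (2, 500, 512) after blurring and projection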


class JascoConditioningProvider(ConditioningProvider):
    """
    A conditioning provider that manages and tokenizes various types of conditioning attributes for JASCO models.

    Attributes:
        chords_card (int): The cardinality of the chord vocabulary.
        sequence_length (int): The length of the sequence for padding purposes.
        melody_dim (int): The dimensionality of the melody matrix.
    """
    def __init__(self, *args,
                 chords_card: int = 194,
                 sequence_length: int = 500,
                 melody_dim: int = 53, **kwargs):
        self.null_chord = chords_card  # index `card` is reserved for the null chord (see ChordsEmbConditioner)
        self.sequence_len = sequence_length
        self.melody_dim = melody_dim
        super().__init__(*args, **kwargs)
    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/wavs with existing conditioners in self, and tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and wav conditions.
        """
        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
            "Got unexpected input types for conditioner! Should be tp.List[ConditioningAttributes] "
            f"but types were {set([type(x) for x in inputs])}"
        )
        output = {}
        text = self._collate_text(inputs)
        wavs = self._collate_wavs(inputs)
        symbolic = self._collate_symbolic(inputs, self.conditioners.keys())

        assert set(text.keys() | wavs.keys() | symbolic.keys()).issubset(set(self.conditioners.keys())), (
            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, "
            f"got {text.keys(), wavs.keys(), symbolic.keys()}"
        )

        for attribute, batch in chain(text.items(), wavs.items(), symbolic.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output
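
    # Sketch of the returned mapping (key names are illustrative and depend on the
    # registered conditioners): {'chords': SymbolicCondition(...),
    # 'melody': SymbolicCondition(...), 'self_wav': WavCondition(...)}, each value
    # already tokenized for its matching conditioner's forward().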
    def _collate_symbolic(self, samples: tp.List[ConditioningAttributes],
                          conditioner_keys: tp.Set) -> tp.Dict[str, SymbolicCondition]:
        output = {}
        # collate only if a symbolic conditioner exists
        if any(x in conditioner_keys for x in JascoCondConst.SYM.value):
            for s in samples:
                # hydrate with the null chord when chords are absent - for inference support
                if (s.symbolic == {} or
                        s.symbolic[JascoCondConst.CRD.value].frame_chords is None or
                        s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1):  # type: ignore
                    # no chords conditioning - fill with the null chord token
                    s.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(
                        frame_chords=torch.ones(self.sequence_len, dtype=torch.int32) * self.null_chord)

                if (s.symbolic == {} or
                        s.symbolic[JascoCondConst.MLD.value].melody is None or
                        s.symbolic[JascoCondConst.MLD.value].melody.shape[-1] <= 1):  # type: ignore
                    # no melody conditioning - fill with an all-zero melody matrix
                    s.symbolic[JascoCondConst.MLD.value] = SymbolicCondition(
                        melody=torch.zeros((self.melody_dim, self.sequence_len)))

            if JascoCondConst.CRD.value in conditioner_keys:
                # pad all chord sequences to the longest one in the batch, using the null chord
                max_seq_len = max(
                    [s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] for s in samples])  # type: ignore
                padded_chords = [
                    torch.cat((x.symbolic[JascoCondConst.CRD.value].frame_chords,  # type: ignore
                               torch.ones(max_seq_len -
                                          x.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1],  # type: ignore
                                          dtype=torch.int32) * self.null_chord))
                    for x in samples
                ]
                output[JascoCondConst.CRD.value] = SymbolicCondition(frame_chords=torch.stack(padded_chords))
            if JascoCondConst.MLD.value in conditioner_keys:
                melodies = torch.stack([x.symbolic[JascoCondConst.MLD.value].melody for x in samples])  # type: ignore
                output[JascoCondConst.MLD.value] = SymbolicCondition(melody=melodies)
        return output
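

# End-to-end sketch (a minimal, hypothetical wiring; the conditioner names and
# constructor arguments below are examples, not prescribed by this module):
#
#   provider = JascoConditioningProvider(
#       conditioners={'chords': ChordsEmbConditioner(card=194, out_dim=512),
#                     'melody': MelodyConditioner(card=53, out_dim=512)})
#   tokenized = provider.tokenize(attributes)  # attributes: tp.List[ConditioningAttributes]
#   # `tokenized` maps each conditioner name to its tokenized condition, ready to
#   # be fed to the corresponding conditioner's forward().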