''' Modified from the vit_pytorch library: https://github.com/lucidrains/vit-pytorch '''

import json
import math
import os
import shutil

import torch
import torchaudio
import torchaudio.transforms as T
from einops import rearrange
from einops.layers.torch import Rearrange
from nnAudio.features.mel import MelSpectrogram
from torch import nn

# for uploading to huggingface hub
from huggingface_hub import HfApi, PyTorchModelHubMixin
from transformers import PretrainedConfig, PreTrainedModel


def pair(t):
    return t if isinstance(t, (tuple, list)) else (t, t)


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    # fixed 2D sin-cos positional embeddings of shape (h * w, dim)
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


def load_model(model: nn.Module, checkpoint_path: str, device: str = 'cpu', ignore_layers: list = ['linear_head'], verbose: bool = False):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    filtered_state_dict = {
        k: v for k, v in checkpoint.items()
        if not any(k.startswith(layer) for layer in ignore_layers)
    }
    model.load_state_dict(filtered_state_dict, strict=False)
    if ignore_layers and verbose:
        print(f'==> Loaded model from {checkpoint_path}, ignoring layers: {", ".join(ignore_layers)}')


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)


class MynaPreprocessor:
    def __init__(self, target_sr: int = 16000, n_mels: int = 128):
        self.target_sr = target_sr
        self.n_mels = n_mels
        self.mel_spec = MelSpectrogram(sr=target_sr, n_mels=n_mels, verbose=False)

    def __call__(self, filename: str, n_frames: int = None):
        # loads audio from file and returns a mel spectrogram of shape
        # (1, n_mels, T), or (num_chunks, 1, n_mels, n_frames) if n_frames is given
        signal, sr = torchaudio.load(filename)
        if signal.shape[0] > 1:
            # downmix multi-channel audio to mono
            signal = signal.mean(dim=0, keepdim=True)
        if sr != self.target_sr:
            resampler = T.Resample(orig_freq=sr, new_freq=self.target_sr)
            signal = resampler(signal)
        ms = self.mel_spec(signal)
        if n_frames:
            ms = self._batch_spectrogram(ms, n_frames)
        return ms

    def _batch_spectrogram(self, ms: torch.Tensor, n_frames: int):
        # sanity check: expects a single unbatched spectrogram (1, n_mels, T)
        assert ms.dim() == 3 and ms.shape[0] == 1

        # discard excess frames so the time axis divides evenly into chunks
        num_chunks = ms.shape[-1] // n_frames
        assert num_chunks > 0, f'Spectrogram is shorter than n_frames={n_frames}'
        ms = ms[:, :, :num_chunks * n_frames]

        # split the tensor into chunks and stack them: (num_chunks, 1, n_mels, n_frames)
        chunks = torch.chunk(ms, num_chunks, dim=2)
        batch = torch.stack(chunks)
        return batch
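
# Usage sketch for MynaPreprocessor (hedged: 'song.wav' is a hypothetical local
# file; shapes assume the defaults of 16 kHz audio and 128 mel bins):
#
#     preprocessor = MynaPreprocessor()
#     spec = preprocessor('song.wav')                 # (1, 128, T) mel spectrogram
#     batch = preprocessor('song.wav', n_frames=256)  # (num_chunks, 1, 128, 256)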
class MynaConfig(PretrainedConfig):
    model_type = 'myna'

    def __init__(
        self,
        spec_size=(128, 4096),
        patch_size=16,
        dim=384,
        depth=12,
        heads=6,
        mlp_dim=1536,
        dim_head=64,
        arch=None,
        additional_patch_size=None,
        hybrid_mode: bool = False,
        n_samples=50000,
        sr=16000,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.spec_size = spec_size
        self.patch_size = patch_size
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dim_head = dim_head
        self.arch = arch
        self.additional_patch_size = additional_patch_size
        self.hybrid_mode = hybrid_mode
        self.n_samples = n_samples  # number of samples for inference
        self.sr = sr                # for preprocessing
        self.n_frames = self._get_n_frames(n_samples)

        # load architecture if provided
        if arch:
            arch = self._get_arch(arch)
            self.dim = arch['dim']
            self.depth = arch['depth']
            self.heads = arch['heads']
            self.mlp_dim = arch['mlp_dim']

    def _get_arch(self, arch: str):
        if arch.lower() in ['vit-s-16', 'vit-s-32']:
            # dim 384, depth 12, MLP 1536, 6 heads, 22M parameters
            return {'dim': 384, 'depth': 12, 'mlp_dim': 1536, 'heads': 6}
        if arch.lower() == 'vit-b-16':
            # dim 768, depth 12, MLP 3072, 12 heads, 87M parameters
            return {'dim': 768, 'depth': 12, 'mlp_dim': 3072, 'heads': 12}
        if arch.lower() == 'vit-l-16':
            # dim 1024, depth 24, MLP 4096, 16 heads, 303M parameters
            return {'dim': 1024, 'depth': 24, 'mlp_dim': 4096, 'heads': 16}
        raise ValueError(f'Architecture {arch} not implemented')

    def _get_n_frames(self, n_samples: int):
        ''' How many frames is n_samples samples? '''
        mel_spectrogram = MelSpectrogram(sr=self.sr, n_mels=self.spec_size[0], verbose=False)
        patch_size_time = self.patch_size if isinstance(self.patch_size, int) else self.patch_size[1]
        mel_frames = mel_spectrogram(torch.randn(1, 1, n_samples)).shape[-1]
        # round down to a multiple of the patch width along the time axis
        mel_frames = math.floor(mel_frames / patch_size_time) * patch_size_time
        return mel_frames
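
# Config sketch: passing one of the preset names above overrides dim / depth /
# heads / mlp_dim with the preset's values (patch_size and spec_size are kept):
#
#     cfg = MynaConfig(arch='vit-s-16', patch_size=16)
#     assert (cfg.dim, cfg.depth, cfg.heads, cfg.mlp_dim) == (384, 12, 6, 1536)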
class Myna(PreTrainedModel, PyTorchModelHubMixin):
    config_class = MynaConfig

    def __init__(self, config: MynaConfig):
        super().__init__(config)
        self.preprocessor = MynaPreprocessor()
        self.hybrid_mode = config.hybrid_mode

        spec_height, spec_width = pair(config.spec_size)
        patch_height, patch_width = pair(config.patch_size)
        assert spec_height % patch_height == 0 and spec_width % patch_width == 0, 'Spectrogram dimensions must be divisible by the patch size.'

        # optional second patch size (used by hybrid mode / toggle_embeddings)
        self.additional_patch_size = config.additional_patch_size
        if config.additional_patch_size:
            patch_height_b, patch_width_b = pair(config.additional_patch_size)
            patch_dim_b = patch_height_b * patch_width_b
            self.to_patch_embedding_b, self.pos_embedding_b = self._make_embeddings(
                patch_height_b, patch_width_b, patch_dim_b, config.dim, spec_height, spec_width
            )

        patch_dim = patch_height * patch_width
        self.to_patch_embedding, self.pos_embedding = self._make_embeddings(
            patch_height, patch_width, patch_dim, config.dim, spec_height, spec_width
        )

        self.transformer = Transformer(config.dim, config.depth, config.heads, config.dim_head, config.mlp_dim)
        self.pool = 'mean'
        self.to_latent = nn.Identity()
        self.linear_head = nn.Identity()

    def forward(self, spec, recurse=True):
        if self.hybrid_mode and recurse:
            # hybrid mode: embed the input with both patch sizes and
            # concatenate the pooled features
            a = self(spec, recurse=False)
            self.toggle_embeddings()
            b = self(spec, recurse=False)
            self.toggle_embeddings()
            return torch.cat((a, b), dim=-1)

        # if input shape is not 4d, make it 4d:
        if spec.dim() == 2:
            # unbatched: n_mels, n_frames
            spec = spec.unsqueeze(0).unsqueeze(0)
        elif spec.dim() == 3:
            # batched but without channels: B, n_mels, n_frames
            spec = spec.unsqueeze(1)
        assert spec.dim() == 4

        device = spec.device
        x = self.to_patch_embedding(spec)
        n_patches = x.shape[1]  # x is of shape (B, n_patches, dim)

        # pos_embedding is a plain tensor (not a registered buffer), so move it
        # to the input's device here; truncate it for shorter inputs
        x += self.pos_embedding[:n_patches].to(device, dtype=x.dtype)
        x = self.transformer(x)

        # mean-pool over patch tokens (self.pool is informational; only mean
        # pooling is implemented)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

    def toggle_embeddings(self):
        if not self.additional_patch_size:
            print('toggle_embeddings() called but no additional patch size provided! Ignoring call.')
            return
        self.to_patch_embedding, self.to_patch_embedding_b = self.to_patch_embedding_b, self.to_patch_embedding
        self.pos_embedding, self.pos_embedding_b = self.pos_embedding_b, self.pos_embedding

    def _make_embeddings(self, patch_height, patch_width, patch_dim, dim, image_height, image_width):
        to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )
        return to_patch_embedding, pos_embedding

    def from_file(self, filename: str, n_samples: int = None):
        n_frames = self.config.n_frames
        if n_samples and n_samples != self.config.n_samples:
            n_frames = self.config._get_n_frames(n_samples)
        spec = self.preprocessor(filename, n_frames).to(self.device)
        return self(spec)

    @property
    def n_params(self):
        return sum(p.numel() for p in self.parameters())


def save_model_and_push(model, repo_name, save_dir='myna-temp', to_hub=False):
    model.save_pretrained(save_dir)
    # include the modeling code so the repo can be loaded with trust_remote_code
    shutil.copy('myna.py', save_dir)

    config = model.config.to_dict()
    config.update({
        '_name_or_path': repo_name,
        'architectures': ['Myna'],
        'auto_map': {
            'AutoConfig': 'myna.MynaConfig',
            'AutoModel': 'myna.Myna'
        },
        'model_type': 'myna'
    })
    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)

    print(f'Model saved locally to {save_dir}')

    if to_hub:
        api = HfApi()
        api.create_repo(repo_name, exist_ok=True)
        api.upload_folder(folder_path=save_dir, repo_id=repo_name)
        print(f"Model pushed to: https://huggingface.co/{repo_name}")
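
# Inference sketch (hedged: checkpoint and audio paths are hypothetical; the
# output is (num_chunks, dim), or (num_chunks, 2 * dim) with hybrid_mode on):
#
#     config = MynaConfig(arch='vit-s-16', patch_size=16)
#     model = Myna(config).eval()
#     load_model(model, 'checkpoints/myna.pth')
#     with torch.no_grad():
#         embeddings = model.from_file('song.wav')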
if __name__ == '__main__':
    config = MynaConfig(
        arch='vit-b-16',
        # arch='vit-s-16',
        patch_size=16,
        additional_patch_size=(128, 2),
        hybrid_mode=True
    )

    model = Myna(config)
    load_model(model, 'checkpoints/myna-85m.pth', verbose=True)
    print(f'Model contains {model.n_params:,} parameters')

    save_model_and_push(
        model,
        repo_name='oriyonay/myna-85m',
        save_dir='myna-85m-hybrid',
        to_hub=True
    )
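
# Loading sketch from the Hub after pushing (the auto_map written by
# save_model_and_push points at this file, so trust_remote_code is required):
#
#     from transformers import AutoModel
#     model = AutoModel.from_pretrained('oriyonay/myna-85m', trust_remote_code=True)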