# Myna-Base

From "Myna: Masking-Based Contrastive Learning of Musical Representations"

## Model Overview
Myna is a self-supervised contrastive model for musical representation learning. It uses a Vision Transformer (ViT) backbone on mel-spectrograms and introduces token masking as its primary augmentation. Because it avoids traditional contrastive augmentations such as pitch shifting, Myna retains pitch sensitivity, which improves performance on key detection tasks.
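As a rough illustration of the masking augmentation, the sketch below keeps a random 10% of spectrogram patch tokens and drops the rest before they enter the ViT encoder. The function name, tensor sizes, and the choice to pair two independently masked subsets of the same spectrogram are assumptions for illustration, not the reference implementation.

```python
import torch

def random_token_subset(patch_tokens: torch.Tensor, keep_ratio: float = 0.1) -> torch.Tensor:
    """Keep a random `keep_ratio` fraction of patch tokens and drop the rest.

    patch_tokens: (batch, num_tokens, dim) embedded spectrogram patches.
    Returns: (batch, num_kept, dim) tensor of the surviving tokens.
    """
    b, n, d = patch_tokens.shape
    num_keep = max(1, int(n * keep_ratio))
    # Per-example random permutation of token indices; keep the first `num_keep`.
    idx = torch.rand(b, n, device=patch_tokens.device).argsort(dim=1)[:, :num_keep]
    return torch.gather(patch_tokens, 1, idx.unsqueeze(-1).expand(-1, -1, d))

# Two independently masked subsets of the same spectrogram can serve as the
# positive pair for a SimCLR-style contrastive objective.
tokens = torch.randn(4, 384, 192)     # (batch, num_tokens, embed_dim) -- illustrative sizes
view_a = random_token_subset(tokens)  # (4, 38, 192)
view_b = random_token_subset(tokens)  # (4, 38, 192)
```

Since only about 10% of tokens pass through the encoder, each example is cheap to process, which is what makes the large per-GPU batch sizes described in the abstract below feasible.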
## Abstract
In this paper, we present Myna, a simple yet effective approach for self-supervised musical representation learning. Built on a contrastive learning framework, Myna introduces two key innovations:
- The use of a Vision Transformer (ViT) on mel-spectrograms as the backbone, replacing SampleCNN on raw audio.
- A novel token masking strategy that masks 90% of spectrogram tokens (e.g., 16x16 patches).
These innovations deliver both effectiveness and efficiency:
- Token masking enables a significant increase in per-GPU batch size, from 48 or 120 in traditional contrastive methods (e.g., CLMR, MULE) to 4096.
- Avoiding traditional augmentations (e.g., pitch shifts) retains pitch sensitivity, enhancing performance in tasks like key detection.
- The use of vertical patches (128x2 instead of 16x16) allows the model to better capture critical features for key detection.
Our hybrid model, Myna-22M-Hybrid, processes both 16x16 and 128x2 patches, achieving state-of-the-art results. Trained on a single GPU, it outperforms MULE (62M) and rivals MERT-95M, which were trained on 16 and 64 GPUs, respectively. Additionally, it surpasses MERT-95M-public, making it the best-performing model trained on publicly available data.
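To make the two patch shapes concrete, the sketch below tiles a 128x96 mel-spectrogram (the input size used in the usage examples below) into non-overlapping 16x16 and 128x2 patches. The `patchify` helper is a hypothetical illustration, not Myna's actual patch-embedding code.

```python
import torch

def patchify(spec: torch.Tensor, patch_h: int, patch_w: int) -> torch.Tensor:
    """Split a (freq, time) spectrogram into flattened non-overlapping patches."""
    patches = spec.unfold(0, patch_h, patch_h).unfold(1, patch_w, patch_w)
    return patches.reshape(-1, patch_h * patch_w)  # (num_tokens, patch_h * patch_w)

spec = torch.randn(128, 96)               # 128 mel bins x 96 time frames

square_tokens = patchify(spec, 16, 16)    # (48, 256): an 8 x 6 grid of 16x16 patches
vertical_tokens = patchify(spec, 128, 2)  # (48, 256): 48 slices spanning all 128 mel bins
```

The vertical 128x2 patches cover the full frequency axis in every token, which is the property the abstract credits for better key detection.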
## Installation

To use Myna, install the necessary dependencies:

```bash
pip3 install -q nnAudio transformers torch
```
## Usage

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('oriyonay/myna-85m')

# Myna supports unbatched (2D) and batched (3D or 4D) inputs:
output = model(torch.randn(128, 96))        # shape (1, 1536)
output = model(torch.randn(2, 128, 96))     # shape (2, 1536)
output = model(torch.randn(2, 1, 128, 96))  # shape (2, 1536)

# Additionally, you can load audio directly from a file:
output = model.from_file('your_file.wav')   # shape (n_chunks, 1536)
```
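If you want to feed your own mel-spectrograms instead of using `from_file`, the sketch below shows one way to produce a 128-bin input with nnAudio (installed above) and crop it to the 96-frame shape used in the examples. The spectrogram settings (sample rate, FFT size, hop length, log scaling) are assumptions for illustration and may not match the preprocessing Myna was trained with; `model.from_file` remains the safer path.

```python
import torch
from nnAudio import features
from transformers import AutoModel

model = AutoModel.from_pretrained('oriyonay/myna-85m')

# NOTE: sr, n_fft, hop_length, and the log scaling below are illustrative
# assumptions, not necessarily the preprocessing Myna was trained with.
to_mel = features.MelSpectrogram(sr=22050, n_fft=2048, hop_length=512, n_mels=128)

waveform = torch.randn(22050 * 3)  # ~3 seconds of dummy mono audio at 22.05 kHz
spec = to_mel(waveform)            # (1, 128, num_frames)
spec = torch.log(spec + 1e-6)      # log-compress magnitudes
spec = spec[..., :96]              # crop to a 96-frame chunk, matching the examples above
embedding = model(spec)            # (1, 1536)
```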