oriyonay committed on
Commit feb8c25 · verified · 1 Parent(s): 9a41a79

Upload folder using huggingface_hub

Files changed (3):
  1. config.json +90 -0
  2. model.safetensors +3 -0
  3. myna.py +340 -0
config.json ADDED
@@ -0,0 +1,90 @@
+ {
+     "return_dict": true,
+     "output_hidden_states": false,
+     "output_attentions": false,
+     "torchscript": false,
+     "torch_dtype": "float32",
+     "use_bfloat16": false,
+     "tf_legacy_loss": false,
+     "pruned_heads": {},
+     "tie_word_embeddings": true,
+     "chunk_size_feed_forward": 0,
+     "is_encoder_decoder": false,
+     "is_decoder": false,
+     "cross_attention_hidden_size": null,
+     "add_cross_attention": false,
+     "tie_encoder_decoder": false,
+     "max_length": 20,
+     "min_length": 0,
+     "do_sample": false,
+     "early_stopping": false,
+     "num_beams": 1,
+     "num_beam_groups": 1,
+     "diversity_penalty": 0.0,
+     "temperature": 1.0,
+     "top_k": 50,
+     "top_p": 1.0,
+     "typical_p": 1.0,
+     "repetition_penalty": 1.0,
+     "length_penalty": 1.0,
+     "no_repeat_ngram_size": 0,
+     "encoder_no_repeat_ngram_size": 0,
+     "bad_words_ids": null,
+     "num_return_sequences": 1,
+     "output_scores": false,
+     "return_dict_in_generate": false,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "remove_invalid_values": false,
+     "exponential_decay_length_penalty": null,
+     "suppress_tokens": null,
+     "begin_suppress_tokens": null,
+     "architectures": [
+         "Myna"
+     ],
+     "finetuning_task": null,
+     "id2label": {
+         "0": "LABEL_0",
+         "1": "LABEL_1"
+     },
+     "label2id": {
+         "LABEL_0": 0,
+         "LABEL_1": 1
+     },
+     "tokenizer_class": null,
+     "prefix": null,
+     "bos_token_id": null,
+     "pad_token_id": null,
+     "eos_token_id": null,
+     "sep_token_id": null,
+     "decoder_start_token_id": null,
+     "task_specific_params": null,
+     "problem_type": null,
+     "_name_or_path": "oriyonay/myna-85m",
+     "_attn_implementation_autoset": false,
+     "transformers_version": "4.48.0",
+     "spec_size": [
+         128,
+         4096
+     ],
+     "patch_size": 16,
+     "dim": 768,
+     "depth": 12,
+     "heads": 12,
+     "mlp_dim": 3072,
+     "dim_head": 64,
+     "arch": "vit-b-16",
+     "additional_patch_size": [
+         128,
+         2
+     ],
+     "hybrid_mode": true,
+     "n_samples": 50000,
+     "sr": 16000,
+     "n_frames": 96,
+     "model_type": "myna",
+     "auto_map": {
+         "AutoConfig": "myna.MynaConfig",
+         "AutoModel": "myna.Myna"
+     }
+ }
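
Since the config declares an auto_map pointing at myna.MynaConfig and myna.Myna, the checkpoint can be loaded through the transformers Auto classes. A minimal loading sketch (trust_remote_code=True is required because the architecture lives in this repo's myna.py rather than in transformers itself):

    from transformers import AutoModel

    # Fetches config.json, model.safetensors, and the custom myna.py from the repo.
    model = AutoModel.from_pretrained('oriyonay/myna-85m', trust_remote_code=True)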
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd4b05fe43c9234e7637101ba007a2525cc14f504c5682192c7c4e1e866e4127
+ size 341685936
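
The weights file is stored as a Git LFS pointer, so only its oid and size appear in the diff; the actual tensors are resolved at download time. A small sketch to check a downloaded copy against the pointer (assumes the file was already fetched, e.g. with huggingface_hub.hf_hub_download; verify_lfs_object is a hypothetical helper name):

    import hashlib

    def verify_lfs_object(path, expected_oid, expected_size):
        # Hash in 1 MiB chunks to avoid holding the ~342 MB file in memory.
        h, size = hashlib.sha256(), 0
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                h.update(chunk)
                size += len(chunk)
        return h.hexdigest() == expected_oid and size == expected_size

    ok = verify_lfs_object(
        'model.safetensors',  # local path to the downloaded weights
        'dd4b05fe43c9234e7637101ba007a2525cc14f504c5682192c7c4e1e866e4127',
        341685936,
    )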
myna.py ADDED
@@ -0,0 +1,340 @@
+ '''
+ Modified from the vit_pytorch library: https://github.com/lucidrains/vit-pytorch
+ '''
+
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ import json
+ import math
+ from nnAudio.features.mel import MelSpectrogram
+ import os
+ import torch
+ from torch import nn
+ import torchaudio
+ import torchaudio.transforms as T
+
+ # for uploading to huggingface hub
+ from huggingface_hub import HfApi, PyTorchModelHubMixin
+ from transformers import PretrainedConfig, PreTrainedModel
+ import shutil
+
+
+ def pair(t):
+     return t if isinstance(t, (tuple, list)) else (t, t)
+
+
+ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
+     y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
+     assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
+     omega = torch.arange(dim // 4) / (dim // 4 - 1)
+     omega = 1.0 / (temperature ** omega)
+
+     y = y.flatten()[:, None] * omega[None, :]
+     x = x.flatten()[:, None] * omega[None, :]
+     pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
+     return pe.type(dtype)
+
+
+ def load_model(model: nn.Module, checkpoint_path: str, device: str = 'cpu', ignore_layers: list = ['linear_head'], verbose: bool = False):
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     filtered_state_dict = {
+         k: v for k, v in checkpoint.items()
+         if not any(k.startswith(layer) for layer in ignore_layers)
+     }
+
+     model.load_state_dict(filtered_state_dict, strict=False)
+
+     if ignore_layers and verbose:
+         print(f'==> Loaded model from {checkpoint_path}, ignoring layers: {", ".join(ignore_layers)}')
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.LayerNorm(dim),
+             nn.Linear(dim, hidden_dim),
+             nn.GELU(),
+             nn.Linear(hidden_dim, dim),
+         )
+     def forward(self, x):
+         return self.net(x)
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, heads = 8, dim_head = 64):
+         super().__init__()
+         inner_dim = dim_head * heads
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+         self.norm = nn.LayerNorm(dim)
+
+         self.attend = nn.Softmax(dim = -1)
+
+         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+         self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+     def forward(self, x):
+         x = self.norm(x)
+
+         qkv = self.to_qkv(x).chunk(3, dim = -1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+         attn = self.attend(dots)
+
+         out = torch.matmul(attn, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+
+
+ class Transformer(nn.Module):
+     def __init__(self, dim, depth, heads, dim_head, mlp_dim):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 Attention(dim, heads = heads, dim_head = dim_head),
+                 FeedForward(dim, mlp_dim)
+             ]))
+     def forward(self, x):
+         for attn, ff in self.layers:
+             x = attn(x) + x
+             x = ff(x) + x
+         return self.norm(x)
+
+
+ class MynaPreprocessor:
+     def __init__(self, target_sr: int = 16000, n_mels: int = 128):
+         self.target_sr = target_sr
+         self.n_mels = n_mels
+         self.mel_spec = MelSpectrogram(sr=target_sr, n_mels=n_mels, verbose=False)
+
+     def __call__(self, filename: str, n_frames: int = None):
+         # loads audio from file and returns a 3D tensor (B, n_mels, n_frames)
+         signal, sr = torchaudio.load(filename)
+         if signal.shape[0] > 1:
+             signal = signal.mean(dim=0, keepdim=True)
+         if sr != self.target_sr:
+             resampler = T.Resample(orig_freq=sr, new_freq=self.target_sr)
+             signal = resampler(signal)
+         ms = self.mel_spec(signal)
+
+         if n_frames:
+             ms = self._batch_spectrogram(ms, n_frames)
+
+         return ms
+
+     def _batch_spectrogram(self, ms: torch.Tensor, n_frames: int):
+         # sanity check
+         assert ms.dim() == 3 and ms.shape[0] == 1
+
+         # discard excess frames
+         num_chunks = ms.shape[-1] // n_frames
+         ms = ms[:, :, :num_chunks * n_frames]
+
+         # split the tensor into chunks and stack them
+         chunks = torch.chunk(ms, num_chunks, dim=2)
+         batch = torch.stack(chunks)
+
+         return batch
+
+
+ class MynaConfig(PretrainedConfig):
+     model_type = 'myna'
+     def __init__(
+         self, spec_size=(128, 4096), patch_size=16, dim=384, depth=12,
+         heads=6, mlp_dim=1536, dim_head = 64, arch=None, additional_patch_size = None,
+         hybrid_mode: bool = False, n_samples = 50000, sr = 16000, **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.spec_size = spec_size
+         self.patch_size = patch_size
+         self.dim = dim
+         self.depth = depth
+         self.heads = heads
+         self.mlp_dim = mlp_dim
+         self.dim_head = dim_head
+         self.arch = arch
+         self.additional_patch_size = additional_patch_size
+         self.hybrid_mode = hybrid_mode
+
+         self.n_samples = n_samples  # number of samples for inference
+         self.sr = sr  # for preprocessing
+         self.n_frames = self._get_n_frames(n_samples)
+
+         # load architecture if provided
+         if arch:
+             arch = self._get_arch(arch)
+             self.dim = arch['dim']
+             self.depth = arch['depth']
+             self.heads = arch['heads']
+             self.mlp_dim = arch['mlp_dim']
+
+     def _get_arch(self, arch: str):
+         if arch.lower() in ['vit-s-16', 'vit-s-32']:
+             # dim 384, depth 12, MLP 1536, 6 heads, 22M parameters
+             return {'dim': 384, 'depth': 12, 'mlp_dim': 1536, 'heads': 6}
+         if arch.lower() == 'vit-b-16':
+             # dim 768, depth 12, MLP 3072, 12 heads, 87M parameters
+             return {'dim': 768, 'depth': 12, 'mlp_dim': 3072, 'heads': 12}
+         if arch.lower() == 'vit-l-16':
+             # dim 1024, depth 24, MLP 4096, 16 heads, 303M parameters
+             return {'dim': 1024, 'depth': 24, 'mlp_dim': 4096, 'heads': 16}
+
+         raise ValueError(f'Architecture {arch} not implemented')
+
+     def _get_n_frames(self, n_samples: int):
+         ''' How many frames is n_samples samples? '''
+         mel_spectrogram = MelSpectrogram(sr=self.sr, n_mels=self.spec_size[0], verbose=False)
+         patch_size_time = self.patch_size if isinstance(self.patch_size, int) else self.patch_size[1]
+         mel_frames = mel_spectrogram(torch.randn(1, 1, n_samples)).shape[-1]
+         mel_frames = math.floor(mel_frames / patch_size_time) * patch_size_time
+         return mel_frames
+
+
+ class Myna(PreTrainedModel, PyTorchModelHubMixin):
+     config_class = MynaConfig
+     def __init__(self, config: MynaConfig):
+         super().__init__(config)
+
+         self.preprocessor = MynaPreprocessor()
+         self.hybrid_mode = config.hybrid_mode
+         spec_height, spec_width = pair(config.spec_size)
+         patch_height, patch_width = pair(config.patch_size)
+
+         assert spec_height % patch_height == 0 and spec_width % patch_width == 0, 'Spectrogram dimensions must be divisible by the patch size.'
+
+         self.additional_patch_size = config.additional_patch_size
+         if config.additional_patch_size:
+             patch_height_b, patch_width_b = pair(config.additional_patch_size)
+             patch_dim_b = patch_height_b * patch_width_b
+
+             self.to_patch_embedding_b, self.pos_embedding_b = self._make_embeddings(
+                 patch_height_b, patch_width_b, patch_dim_b, config.dim, spec_height, spec_width
+             )
+
+         patch_dim = patch_height * patch_width
+
+         self.to_patch_embedding, self.pos_embedding = self._make_embeddings(
+             patch_height, patch_width, patch_dim, config.dim, spec_height, spec_width
+         )
+
+         self.transformer = Transformer(config.dim, config.depth, config.heads, config.dim_head, config.mlp_dim)
+
+         self.pool = 'mean'
+         self.to_latent = nn.Identity()
+
+         self.linear_head = nn.Identity()
+
+     def forward(self, spec, recurse=True):
+         if self.hybrid_mode and recurse:
+             a = self(spec, recurse=False)
+             self.toggle_embeddings()
+             b = self(spec, recurse=False)
+             self.toggle_embeddings()
+             return torch.cat((a, b), dim=-1)
+
+         # if input shape is not 4d, make it 4d:
+         if spec.dim() == 2:
+             # unbatched: n_mels, n_frames
+             spec = spec.unsqueeze(0).unsqueeze(0)
+         elif spec.dim() == 3:
+             # batched but without channels: B, n_mels, n_frames
+             spec = spec.unsqueeze(1)
+         assert spec.dim() == 4
+
+         device = spec.device
+
+         x = self.to_patch_embedding(spec)
+         n_patches = x.shape[1]  # x is of shape (B, n_patches, dim)
+         x += self.pos_embedding[:n_patches].to(device, dtype=x.dtype)
+
+         x = self.transformer(x)
+         x = x.mean(dim = 1)
+
+         x = self.to_latent(x)
+         return self.linear_head(x)
+
+     def toggle_embeddings(self):
+         if not self.additional_patch_size:
+             print('toggle_embeddings() called but no additional patch size provided! Ignoring call.')
+             return
+         self.to_patch_embedding, self.to_patch_embedding_b = self.to_patch_embedding_b, self.to_patch_embedding
+         self.pos_embedding, self.pos_embedding_b = self.pos_embedding_b, self.pos_embedding
+
+     def _make_embeddings(self, patch_height, patch_width, patch_dim, dim, image_height, image_width):
+         to_patch_embedding = nn.Sequential(
+             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+             nn.LayerNorm(patch_dim),
+             nn.Linear(patch_dim, dim),
+             nn.LayerNorm(dim),
+         )
+
+         pos_embedding = posemb_sincos_2d(
+             h = image_height // patch_height,
+             w = image_width // patch_width,
+             dim = dim,
+         )
+
+         return to_patch_embedding, pos_embedding
+
+     def from_file(self, filename: str, n_samples: int = None):
+         n_frames = self.config.n_frames
+         if n_samples and n_samples != self.config.n_samples:
+             n_frames = self.config._get_n_frames(n_samples)
+         spec = self.preprocessor(filename, n_frames).to(self.device)
+         return self(spec)
+
+     @property
+     def n_params(self):
+         return sum(p.numel() for p in self.parameters())
+
+
+ def save_model_and_push(model, repo_name, save_dir='myna-temp', to_hub=False):
+     model.save_pretrained(save_dir)
+     shutil.copy('myna.py', save_dir)
+
+     config = model.config.to_dict()
+     config.update({
+         '_name_or_path': repo_name,
+         'architectures': ['Myna'],
+         'auto_map': {
+             'AutoConfig': 'myna.MynaConfig',
+             'AutoModel': 'myna.Myna'
+         },
+         'model_type': 'myna'
+     })
+
+     with open(os.path.join(save_dir, 'config.json'), 'w') as f:
+         json.dump(config, f, indent=4)
+
+     print(f'Model saved locally to {save_dir}')
+
+     if to_hub:
+         api = HfApi()
+         api.create_repo(repo_name, exist_ok=True)
+         api.upload_folder(folder_path=save_dir, repo_id=repo_name)
+         print(f"Model pushed to: https://huggingface.co/{repo_name}")
+
+
+ if __name__ == '__main__':
+     config = MynaConfig(
+         arch='vit-b-16',  # arch='vit-s-16',
+         patch_size=16,
+         additional_patch_size=(128, 2),
+         hybrid_mode=True
+     )
+     model = Myna(config)
+     load_model(model, 'checkpoints/myna-85m.pth', verbose=True)
+     print(f'Model contains {model.n_params:,} parameters')
+
+     save_model_and_push(
+         model,
+         repo_name='oriyonay/myna-85m',
+         save_dir='myna-85m-hybrid',
+         to_hub=True
+     )
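
For reference, a minimal end-to-end inference sketch using the classes above ('song.wav' is a placeholder path). from_file resamples the audio to 16 kHz, computes a 128-bin mel spectrogram, splits it into 96-frame chunks, and embeds each chunk; with hybrid_mode=True the forward pass concatenates the outputs of the two patch-embedding views, so each chunk yields a 2 * 768 = 1536-dimensional vector:

    import torch
    from transformers import AutoModel

    # Loads the weights and custom code uploaded in this commit.
    model = AutoModel.from_pretrained('oriyonay/myna-85m', trust_remote_code=True).eval()

    with torch.no_grad():
        embeddings = model.from_file('song.wav')  # shape: (n_chunks, 1536)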